Skip to content

Commit

Permalink
Account for racecondition in deleting/closing update channel
Browse files Browse the repository at this point in the history
This commit tries to address the possible raceondition  that can happen
if a client closes its connection after we have fetched it from the
syncmap before sending the message.

To try to avoid introducing new dead lock conditions, all messages sent
to updateChannel has been moved into a function, which handles the
locking (instead of calling it all over the place)

The same lock is used around the delete/close function.
  • Loading branch information
kradalby committed Aug 20, 2021
1 parent 1f422af commit 88d7ac0
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 28 deletions.
3 changes: 2 additions & 1 deletion app.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ type Headscale struct {
aclPolicy *ACLPolicy
aclRules *[]tailcfg.FilterRule

clientsUpdateChannels sync.Map
clientsUpdateChannels sync.Map
clientsUpdateChannelMutex sync.Mutex

lastStateChange sync.Map
}
Expand Down
42 changes: 40 additions & 2 deletions machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) {
Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
err := h.requestUpdate(p)
err := h.sendRequestOnUpdateChannel(p)
if err != nil {
log.Info().
Str("func", "notifyChangesToPeers").
Expand All @@ -283,7 +283,45 @@ func (h *Headscale) notifyChangesToPeers(m *Machine) {
}
}

func (h *Headscale) requestUpdate(m *tailcfg.Node) error {
func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} {
var updateChan chan struct{}
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
updateChan = unwrapped
} else {
log.Error().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Failed to convert update channel to struct{}")
}
} else {
log.Debug().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Update channel not found, creating")

updateChan = make(chan struct{})
h.clientsUpdateChannels.Store(m.ID, updateChan)
}
return updateChan
}

func (h *Headscale) closeUpdateChannel(m *Machine) {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()

if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
close(unwrapped)
}
}
h.clientsUpdateChannels.Delete(m.ID)
}

func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()

pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID))
if ok {
log.Info().
Expand Down
29 changes: 4 additions & 25 deletions poll.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,27 +134,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Loading or creating update channel")
var updateChan chan struct{}
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if wrapped, ok := storedChan.(chan struct{}); ok {
updateChan = wrapped
} else {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Failed to convert update channel to struct{}")
}
} else {
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Update channel not found, creating")

updateChan = make(chan struct{})
h.clientsUpdateChannels.Store(m.ID, updateChan)
}
updateChan := h.getOrOpenUpdateChannel(&m)

pollDataChan := make(chan []byte)
// defer close(pollData)
Expand Down Expand Up @@ -215,7 +195,7 @@ func (h *Headscale) PollNetMapStream(
mKey wgkey.Key,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan chan struct{},
updateChan <-chan struct{},
cancelKeepAlive chan struct{},
) {
go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m)
Expand Down Expand Up @@ -364,8 +344,7 @@ func (h *Headscale) PollNetMapStream(

cancelKeepAlive <- struct{}{}

h.clientsUpdateChannels.Delete(m.ID)
// close(updateChan)
h.closeUpdateChannel(&m)

close(pollDataChan)

Expand Down Expand Up @@ -411,7 +390,7 @@ func (h *Headscale) scheduledPollWorker(
// Send an update request regardless of outdated or not, if data is sent
// to the node is determined in the updateChan consumer block
n, _ := m.toNode()
err := h.requestUpdate(n)
err := h.sendRequestOnUpdateChannel(n)
if err != nil {
log.Error().
Str("func", "keepAlive").
Expand Down

0 comments on commit 88d7ac0

Please sign in to comment.