feat: add agent acks to in-memory coordinator (#12786)
When an agent receives a node, it responds with an ACK which is relayed
to the client. After the client receives the ACK, it's allowed to begin
pinging.
coadler committed Apr 10, 2024
1 parent 9cf2358 commit e801e87
Showing 13 changed files with 878 additions and 122 deletions.
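
The ordering described in the commit message can be illustrated with a small, self-contained sketch: the agent acks the node it receives, the ack is relayed back through the coordinator, and the client only begins pinging once the ack arrives or a fallback timeout fires. Plain channels stand in for the coordinator here; none of these names come from the codebase.

package main

import (
	"fmt"
	"time"
)

func main() {
	nodeUpdates := make(chan string) // client -> agent, via the coordinator
	acks := make(chan string)        // agent -> client, via the coordinator

	// Agent side: on receiving a node, respond with an ack.
	go func() {
		node := <-nodeUpdates
		fmt.Println("agent: received node", node, "- sending ack")
		acks <- node
	}()

	// Client side: publish our node, then wait for the relayed ack
	// (or a fallback timeout) before starting to ping.
	nodeUpdates <- "client-node"
	select {
	case <-acks:
		fmt.Println("client: ack received, starting pings")
	case <-time.After(5 * time.Second):
		fmt.Println("client: no ack, proceeding after timeout")
	}
}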
4 changes: 3 additions & 1 deletion codersdk/workspacesdk/connector.go
@@ -86,9 +86,11 @@ func runTailnetAPIConnector(
func (tac *tailnetAPIConnector) manageGracefulTimeout() {
defer tac.cancelGracefulCtx()
<-tac.ctx.Done()
timer := time.NewTimer(time.Second)
defer timer.Stop()
select {
case <-tac.closed:
case <-time.After(time.Second):
case <-timer.C:
}
}

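The connector hunk above swaps time.After for an explicit timer that is stopped when the function returns, so the timer's resources are released as soon as the other case wins rather than only when the timer eventually fires. A minimal sketch of the same pattern, with illustrative names rather than the connector code:

package main

import (
	"fmt"
	"time"
)

// waitOrTimeout shows the explicit-timer pattern: the deferred Stop releases
// the timer promptly if done fires first.
func waitOrTimeout(done <-chan struct{}, d time.Duration) bool {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-done:
		return true
	case <-timer.C:
		return false
	}
}

func main() {
	done := make(chan struct{})
	close(done)
	fmt.Println(waitOrTimeout(done, time.Second)) // true: done is already closed
}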
@@ -102,6 +102,8 @@ func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}

func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}

func (*fakeTailnetConn) SetTunnelDestination(uuid.UUID) {}

func newFakeTailnetConn() *fakeTailnetConn {
return &fakeTailnetConn{}
}
2 changes: 1 addition & 1 deletion enterprise/coderd/workspaceproxy.go
@@ -658,7 +658,7 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request)
if err != nil {
return xerrors.Errorf("insert replica: %w", err)
}
} else if err != nil {
} else {
return xerrors.Errorf("get replica: %w", err)
}

200 changes: 160 additions & 40 deletions tailnet/configmaps.go
@@ -186,7 +186,7 @@ func (c *configMaps) close() {
c.L.Lock()
defer c.L.Unlock()
for _, lc := range c.peers {
lc.resetTimer()
lc.resetLostTimer()
}
c.closing = true
c.Broadcast()
@@ -216,6 +216,12 @@ func (c *configMaps) netMapLocked() *netmap.NetworkMap {
func (c *configMaps) peerConfigLocked() []*tailcfg.Node {
out := make([]*tailcfg.Node, 0, len(c.peers))
for _, p := range c.peers {
// Don't add nodes that we haven't received a READY_FOR_HANDSHAKE for
// yet, if they're a destination. If we receive a READY_FOR_HANDSHAKE
// for a peer before we receive their node, the node will be nil.
if (!p.readyForHandshake && p.isDestination) || p.node == nil {
continue
}
n := p.node.Clone()
if c.blockEndpoints {
n.Endpoints = nil
@@ -225,6 +231,19 @@ func (c *configMaps) peerConfigLocked() []*tailcfg.Node {
return out
}

func (c *configMaps) setTunnelDestination(id uuid.UUID) {
c.L.Lock()
defer c.L.Unlock()
lc, ok := c.peers[id]
if !ok {
lc = &peerLifecycle{
peerID: id,
}
c.peers[id] = lc
}
lc.isDestination = true
}

// setAddresses sets the addresses belonging to this node to the given slice. It
// triggers configuration of the engine if the addresses have changed.
// c.L MUST NOT be held.
@@ -331,8 +350,10 @@ func (c *configMaps) updatePeers(updates []*proto.CoordinateResponse_PeerUpdate)
// worry about them being up-to-date when handling updates below, and it covers
// all peers, not just the ones we got updates about.
for _, lc := range c.peers {
if peerStatus, ok := status.Peer[lc.node.Key]; ok {
lc.lastHandshake = peerStatus.LastHandshake
if lc.node != nil {
if peerStatus, ok := status.Peer[lc.node.Key]; ok {
lc.lastHandshake = peerStatus.LastHandshake
}
}
}

@@ -363,7 +384,7 @@ func (c *configMaps) updatePeerLocked(update *proto.CoordinateResponse_PeerUpdat
return false
}
logger := c.logger.With(slog.F("peer_id", id))
lc, ok := c.peers[id]
lc, peerOk := c.peers[id]
var node *tailcfg.Node
if update.Kind == proto.CoordinateResponse_PeerUpdate_NODE {
// If no preferred DERP is provided, we can't reach the node.
@@ -377,48 +398,76 @@ return false
return false
}
logger = logger.With(slog.F("key_id", node.Key.ShortString()), slog.F("node", node))
peerStatus, ok := status.Peer[node.Key]
// Starting KeepAlive messages at the initialization of a connection
// causes a race condition. If we send the handshake before the peer has
// our node, we'll have to wait for 5 seconds before trying again.
// Ideally, the first handshake starts when the user first initiates a
// connection to the peer. After a successful connection we enable
// keep alives to persist the connection and keep it from becoming idle.
// SSH connections don't send packets while idle, so we use keep alives
// to avoid random hangs while we set up the connection again after
// inactivity.
node.KeepAlive = ok && peerStatus.Active
node.KeepAlive = c.nodeKeepalive(lc, status, node)
}
switch {
case !ok && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
case !peerOk && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
// new!
var lastHandshake time.Time
if ps, ok := status.Peer[node.Key]; ok {
lastHandshake = ps.LastHandshake
}
c.peers[id] = &peerLifecycle{
lc = &peerLifecycle{
peerID: id,
node: node,
lastHandshake: lastHandshake,
lost: false,
}
c.peers[id] = lc
logger.Debug(context.Background(), "adding new peer")
return true
case ok && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
return lc.validForWireguard()
case peerOk && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
// update
node.Created = lc.node.Created
if lc.node != nil {
node.Created = lc.node.Created
}
dirty = !lc.node.Equal(node)
lc.node = node
// validForWireguard checks that the node is non-nil, so should be
// called after we update the node.
dirty = dirty && lc.validForWireguard()
lc.lost = false
lc.resetTimer()
lc.resetLostTimer()
if lc.isDestination && !lc.readyForHandshake {
// We received the node of a destination peer before we've received
// their READY_FOR_HANDSHAKE. Set a timer so we don't wait indefinitely
// before programming the peer into wireguard.
lc.setReadyForHandshakeTimer(c)
logger.Debug(context.Background(), "setting ready for handshake timeout")
}
logger.Debug(context.Background(), "node update to existing peer", slog.F("dirty", dirty))
return dirty
case !ok:
case peerOk && update.Kind == proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
dirty := !lc.readyForHandshake
lc.readyForHandshake = true
if lc.readyForHandshakeTimer != nil {
lc.readyForHandshakeTimer.Stop()
}
if lc.node != nil {
old := lc.node.KeepAlive
lc.node.KeepAlive = c.nodeKeepalive(lc, status, lc.node)
dirty = dirty || (old != lc.node.KeepAlive)
}
logger.Debug(context.Background(), "peer ready for handshake")
// only force a reconfig if the node is populated
return dirty && lc.node != nil
case !peerOk && update.Kind == proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
// When we receive a READY_FOR_HANDSHAKE for a peer we don't know about,
// we create a peerLifecycle with the peerID and set readyForHandshake
// to true. Eventually we should receive a NODE update for this peer,
// and it'll be programmed into wireguard.
logger.Debug(context.Background(), "got peer ready for handshake for unknown peer")
lc = &peerLifecycle{
peerID: id,
readyForHandshake: true,
}
c.peers[id] = lc
return false
case !peerOk:
// disconnected or lost, but we don't have the node. No op
logger.Debug(context.Background(), "skipping update for peer we don't recognize")
return false
case update.Kind == proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
lc.resetTimer()
lc.resetLostTimer()
delete(c.peers, id)
logger.Debug(context.Background(), "disconnected peer")
return true
@@ -476,10 +525,12 @@ func (c *configMaps) peerLostTimeout(id uuid.UUID) {
"timeout triggered for peer that is removed from the map")
return
}
if peerStatus, ok := status.Peer[lc.node.Key]; ok {
lc.lastHandshake = peerStatus.LastHandshake
if lc.node != nil {
if peerStatus, ok := status.Peer[lc.node.Key]; ok {
lc.lastHandshake = peerStatus.LastHandshake
}
logger = logger.With(slog.F("key_id", lc.node.Key.ShortString()))
}
logger = logger.With(slog.F("key_id", lc.node.Key.ShortString()))
if !lc.lost {
logger.Debug(context.Background(),
"timeout triggered for peer that is no longer lost")
@@ -522,7 +573,7 @@ func (c *configMaps) nodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bo
c.L.Lock()
defer c.L.Unlock()
for _, lc := range c.peers {
if lc.node.Key == publicKey {
if lc.node != nil && lc.node.Key == publicKey {
return lc.node.Addresses, true
}
}
@@ -539,9 +590,10 @@ func (c *configMaps) fillPeerDiagnostics(d *PeerDiagnostics, peerID uuid.UUID) {
}
}
lc, ok := c.peers[peerID]
if !ok {
if !ok || lc.node == nil {
return
}

d.ReceivedNode = lc.node
ps, ok := status.Peer[lc.node.Key]
if !ok {
Expand All @@ -550,34 +602,102 @@ func (c *configMaps) fillPeerDiagnostics(d *PeerDiagnostics, peerID uuid.UUID) {
d.LastWireguardHandshake = ps.LastHandshake
}

func (c *configMaps) peerReadyForHandshakeTimeout(peerID uuid.UUID) {
logger := c.logger.With(slog.F("peer_id", peerID))
logger.Debug(context.Background(), "peer ready for handshake timeout")
c.L.Lock()
defer c.L.Unlock()
lc, ok := c.peers[peerID]
if !ok {
logger.Debug(context.Background(),
"ready for handshake timeout triggered for peer that is removed from the map")
return
}

wasReady := lc.readyForHandshake
lc.readyForHandshake = true
if !wasReady {
logger.Info(context.Background(), "setting peer ready for handshake after timeout")
c.netmapDirty = true
c.Broadcast()
}
}

func (*configMaps) nodeKeepalive(lc *peerLifecycle, status *ipnstate.Status, node *tailcfg.Node) bool {
// If the peer is already active, keepalives should be enabled.
if peerStatus, statusOk := status.Peer[node.Key]; statusOk && peerStatus.Active {
return true
}
// If the peer is a destination, we should only enable keepalives if we've
// received the READY_FOR_HANDSHAKE.
if lc != nil && lc.isDestination && lc.readyForHandshake {
return true
}

// If none of the above are true, keepalives should not be enabled.
return false
}

type peerLifecycle struct {
peerID uuid.UUID
node *tailcfg.Node
lost bool
lastHandshake time.Time
timer *clock.Timer
peerID uuid.UUID
// isDestination specifies if the peer is a destination, meaning we
// initiated a tunnel to the peer. When the peer is a destination, we do not
// respond to node updates with `READY_FOR_HANDSHAKE`s, and we wait to
// program the peer into wireguard until we receive a READY_FOR_HANDSHAKE
// from the peer or the timeout is reached.
isDestination bool
// node is the tailcfg.Node for the peer. It may be nil until we receive a
// NODE update for it.
node *tailcfg.Node
lost bool
lastHandshake time.Time
lostTimer *clock.Timer
readyForHandshake bool
readyForHandshakeTimer *clock.Timer
}

func (l *peerLifecycle) resetTimer() {
if l.timer != nil {
l.timer.Stop()
l.timer = nil
func (l *peerLifecycle) resetLostTimer() {
if l.lostTimer != nil {
l.lostTimer.Stop()
l.lostTimer = nil
}
}

func (l *peerLifecycle) setLostTimer(c *configMaps) {
if l.timer != nil {
l.timer.Stop()
if l.lostTimer != nil {
l.lostTimer.Stop()
}
ttl := lostTimeout - c.clock.Since(l.lastHandshake)
if ttl <= 0 {
ttl = time.Nanosecond
}
l.timer = c.clock.AfterFunc(ttl, func() {
l.lostTimer = c.clock.AfterFunc(ttl, func() {
c.peerLostTimeout(l.peerID)
})
}

const readyForHandshakeTimeout = 5 * time.Second

func (l *peerLifecycle) setReadyForHandshakeTimer(c *configMaps) {
if l.readyForHandshakeTimer != nil {
l.readyForHandshakeTimer.Stop()
}
l.readyForHandshakeTimer = c.clock.AfterFunc(readyForHandshakeTimeout, func() {
c.logger.Debug(context.Background(), "ready for handshake timeout", slog.F("peer_id", l.peerID))
c.peerReadyForHandshakeTimeout(l.peerID)
})
}

// validForWireguard returns true if the peer is ready to be programmed into
// wireguard.
func (l *peerLifecycle) validForWireguard() bool {
valid := l.node != nil
if l.isDestination {
return valid && l.readyForHandshake
}
return valid
}

// prefixesDifferent returns true if the two slices contain different prefixes
// where order doesn't matter.
func prefixesDifferent(a, b []netip.Prefix) bool {
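Taken together, the new validForWireguard, nodeKeepalive, and ready-for-handshake bookkeeping in configmaps.go gate when a destination peer is actually programmed into wireguard and when keepalives are enabled. A minimal standalone sketch of that decision logic, using simplified illustrative types rather than the coder structs:

package main

import "fmt"

// peer is a simplified stand-in for peerLifecycle, keeping only the fields
// that drive the gating decisions shown in the diff above.
type peer struct {
	hasNode           bool
	isDestination     bool
	readyForHandshake bool
	active            bool // wireguard reports recent activity for the peer
}

// programmable mirrors validForWireguard: a destination peer also needs the
// READY_FOR_HANDSHAKE before it is programmed into wireguard.
func (p peer) programmable() bool {
	if p.isDestination {
		return p.hasNode && p.readyForHandshake
	}
	return p.hasNode
}

// keepalive mirrors nodeKeepalive: keepalives are enabled for peers that are
// already active, or for destinations whose ack has arrived.
func (p peer) keepalive() bool {
	return p.active || (p.isDestination && p.readyForHandshake)
}

func main() {
	dst := peer{hasNode: true, isDestination: true}
	fmt.Println(dst.programmable(), dst.keepalive()) // false false: still waiting for the ack
	dst.readyForHandshake = true
	fmt.Println(dst.programmable(), dst.keepalive()) // true true: ack received
}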
