diff --git a/pkg/kv/kvserver/liveness/liveness.go b/pkg/kv/kvserver/liveness/liveness.go index f158f8e56675..7dfc92c866b0 100644 --- a/pkg/kv/kvserver/liveness/liveness.go +++ b/pkg/kv/kvserver/liveness/liveness.go @@ -280,9 +280,12 @@ type NodeLiveness struct { nodeDialer *nodedialer.Dialer engineSyncs *singleflight.Group - // onIsLive is a callback registered by stores prior to starting liveness. - // It fires when a node transitions from not live to live. - onIsLive []IsLiveCallback // see RegisterCallback + // onIsLiveMu holds the callbacks registered by stores. + // They fire when a node transitions from not live to live. + onIsLiveMu struct { + syncutil.Mutex + callbacks []IsLiveCallback + } // see RegisterCallback // onSelfHeartbeat is invoked after every successful heartbeat // of the local liveness instance's heartbeat loop. @@ -548,15 +551,8 @@ func (nl *NodeLiveness) cacheUpdated(old livenesspb.Liveness, new livenesspb.Liv // Need to use a different signal to determine if liveness changed. now := nl.clock.Now() if !old.IsLive(now) && new.IsLive(now) { - // NB: If we are not started, we don't use the onIsLive callbacks since they - // can still change. This is a bit of a tangled mess since the startup of - // liveness requires the stores to be started, but stores can't start until - // liveness can run. Ideally we could cache all these updates and call - // onIsLive as part of start. - if nl.started.Get() { - for _, fn := range nl.onIsLive { - fn(new) - } + for _, fn := range nl.callbacks() { + fn(new) } } if !old.Membership.Decommissioned() && new.Membership.Decommissioned() && nl.onNodeDecommissioned != nil { @@ -639,15 +635,6 @@ func (nl *NodeLiveness) Start(ctx context.Context) { retryOpts.Closer = nl.stopper.ShouldQuiesce() nl.started.Set(true) - // We may have received some liveness records from Gossip prior to Start being - // called. We need to go through and notify all the callers of them now. 
- for _, entry := range nl.ScanNodeVitalityFromCache() { - if entry.IsLive(livenesspb.IsAliveNotification) { - for _, fn := range nl.onIsLive { - fn(entry.GetInternalLiveness()) - } - } - } _ = nl.stopper.RunAsyncTaskEx(ctx, stop.TaskOpts{TaskName: "liveness-hb", SpanOpt: stop.SterileRootSpan}, func(context.Context) { ambient := nl.ambientCtx @@ -746,6 +733,22 @@ func (nl *NodeLiveness) Heartbeat(ctx context.Context, liveness livenesspb.Liven return nl.heartbeatInternal(ctx, liveness, false /* increment epoch */) } +func (nl *NodeLiveness) callbacks() []IsLiveCallback { + nl.onIsLiveMu.Lock() + defer nl.onIsLiveMu.Unlock() + return append([]IsLiveCallback(nil), nl.onIsLiveMu.callbacks...) +} + +func (nl *NodeLiveness) notifyIsAliveCallbacks(fns []IsLiveCallback) { + for _, entry := range nl.ScanNodeVitalityFromCache() { + if entry.IsLive(livenesspb.IsAliveNotification) { + for _, fn := range fns { + fn(entry.GetInternalLiveness()) + } + } + } +} + func (nl *NodeLiveness) heartbeatInternal( ctx context.Context, oldLiveness livenesspb.Liveness, incrementEpoch bool, ) (err error) { @@ -1077,13 +1080,15 @@ func (nl *NodeLiveness) Metrics() Metrics { return nl.metrics } -// RegisterCallback registers a callback to be invoked any time a -// node's IsLive() state changes to true. This must be called before Start. +// RegisterCallback registers a callback to be invoked any time a node's +// IsLive() state changes to true. The provided callback is also invoked +// synchronously from RegisterCallback for each node that is currently live. func (nl *NodeLiveness) RegisterCallback(cb IsLiveCallback) { - if nl.started.Get() { - log.Fatalf(context.TODO(), "RegisterCallback called after Start") - } - nl.onIsLive = append(nl.onIsLive, cb) + nl.onIsLiveMu.Lock() + nl.onIsLiveMu.callbacks = append(nl.onIsLiveMu.callbacks, cb) + nl.onIsLiveMu.Unlock() + + nl.notifyIsAliveCallbacks([]IsLiveCallback{cb}) } // updateLiveness does a conditional put on the node liveness record for the