Skip to content

Commit 2b7904d

Browse files
committed
Centralize worker stream cache
Instead of each worker keeping its own cache of node ack streams, have the parent node hold a single shared cache so the same stream is not cached more than once.
1 parent a19129a commit 2b7904d

File tree

3 files changed

+70
-11
lines changed

3 files changed

+70
-11
lines changed

pool/node.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ type (
5353

5454
localWorkers sync.Map // workers created by this node
5555
workerStreams sync.Map // worker streams indexed by ID
56+
workerAckStreams sync.Map // streams for worker acks indexed by ID
5657
pendingJobChannels sync.Map // channels used to send DispatchJob results, nil if event is requeued
5758
pendingEvents sync.Map // pending events indexed by sender and event IDs
5859

@@ -218,6 +219,7 @@ func AddNode(ctx context.Context, poolName string, rdb *redis.Client, opts ...No
218219
shutdownMap: wsm,
219220
tickerMap: tm,
220221
workerStreams: sync.Map{},
222+
workerAckStreams: sync.Map{},
221223
pendingJobChannels: sync.Map{},
222224
pendingEvents: sync.Map{},
223225
poolStream: poolStream,
@@ -682,7 +684,7 @@ func (node *Node) ackWorkerEvent(ctx context.Context, ev *streaming.Event) {
682684
// dispatched the job.
683685
if pending.EventName == evStartJob {
684686
_, nodeID := unmarshalJobKeyAndNodeID(pending.Payload)
685-
stream, err := streaming.NewStream(nodeStreamName(node.PoolName, nodeID), node.rdb, soptions.WithStreamLogger(node.logger))
687+
stream, err := node.getOrCreateWorkerAckStream(ctx, nodeID)
686688
if err != nil {
687689
node.logger.Error(fmt.Errorf("ackWorkerEvent: failed to create node event stream %q: %w", nodeStreamName(node.PoolName, nodeID), err))
688690
return
@@ -1111,6 +1113,29 @@ func (node *Node) workerStream(_ context.Context, id string) (*streaming.Stream,
11111113
return val.(*streaming.Stream), nil
11121114
}
11131115

1116+
// getOrCreateWorkerAckStream gets or creates a stream for worker acks
1117+
func (node *Node) getOrCreateWorkerAckStream(ctx context.Context, nodeID string) (*streaming.Stream, error) {
1118+
if val, ok := node.workerAckStreams.Load(nodeID); ok {
1119+
return val.(*streaming.Stream), nil
1120+
}
1121+
1122+
stream, err := streaming.NewStream(
1123+
nodeStreamName(node.PoolName, nodeID),
1124+
node.rdb,
1125+
soptions.WithStreamLogger(node.logger),
1126+
)
1127+
if err != nil {
1128+
return nil, err
1129+
}
1130+
1131+
actual, loaded := node.workerAckStreams.LoadOrStore(nodeID, stream)
1132+
if loaded {
1133+
// Another goroutine created the stream first, just discard our local reference
1134+
return actual.(*streaming.Stream), nil
1135+
}
1136+
return stream, nil
1137+
}
1138+
11141139
// cleanupWorker removes the worker from all pool maps.
11151140
func (node *Node) cleanupWorker(ctx context.Context, id string) {
11161141
if _, err := node.workerMap.Delete(ctx, id); err != nil {

pool/node_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,44 @@ func TestShutdownStopsAllJobs(t *testing.T) {
834834
assert.Empty(t, worker2.Jobs(), "Worker2 should have no remaining jobs")
835835
}
836836

837+
func TestWorkerAckStreams(t *testing.T) {
838+
testName := strings.Replace(t.Name(), "/", "_", -1)
839+
ctx := ptesting.NewTestContext(t)
840+
rdb := ptesting.NewRedisClient(t)
841+
node := newTestNode(t, ctx, rdb, testName)
842+
defer ptesting.CleanupRedis(t, rdb, true, testName)
843+
844+
// Create a worker and dispatch a job
845+
worker := newTestWorker(t, ctx, node)
846+
require.NoError(t, node.DispatchJob(ctx, testName, []byte("payload")))
847+
848+
// Wait for the job to start and be acknowledged
849+
require.Eventually(t, func() bool {
850+
return len(worker.Jobs()) == 1
851+
}, max, delay)
852+
853+
// Verify stream is created and cached
854+
stream1, err := node.getOrCreateWorkerAckStream(ctx, node.ID)
855+
require.NoError(t, err)
856+
stream2, err := node.getOrCreateWorkerAckStream(ctx, node.ID)
857+
require.NoError(t, err)
858+
assert.Same(t, stream1, stream2, "Expected same stream instance to be returned")
859+
860+
// Verify stream exists before shutdown
861+
streamKey := "pulse:stream:" + nodeStreamName(testName, node.ID)
862+
exists, err := rdb.Exists(ctx, streamKey).Result()
863+
assert.NoError(t, err)
864+
assert.Equal(t, int64(1), exists, "Expected stream to exist before shutdown")
865+
866+
// Shutdown node
867+
assert.NoError(t, node.Shutdown(ctx))
868+
869+
// Verify stream is destroyed in Redis
870+
exists, err = rdb.Exists(ctx, streamKey).Result()
871+
assert.NoError(t, err)
872+
assert.Equal(t, int64(0), exists, "Expected stream to be destroyed after shutdown")
873+
}
874+
837875
// mockAcker is a test helper carrying a pluggable XAckFunc callback that takes
// a context, a stream key, a sink name and a list of IDs and returns a
// *redis.IntCmd.
// NOTE(review): presumably an XAck method defined elsewhere in this file
// delegates to XAckFunc — confirm.
type mockAcker struct {
838876
XAckFunc func(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd
839877
}

pool/worker.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -291,22 +291,18 @@ func (w *Worker) notify(_ context.Context, key string, payload []byte) error {
291291
// ackPoolEvent acknowledges the pool event that originated from the node with
292292
// the given ID.
293293
func (w *Worker) ackPoolEvent(ctx context.Context, nodeID, eventID string, ackerr error) {
294-
stream, ok := w.nodeStreams.Load(nodeID)
295-
if !ok {
296-
var err error
297-
stream, err = streaming.NewStream(nodeStreamName(w.node.PoolName, nodeID), w.node.rdb, soptions.WithStreamLogger(w.logger))
298-
if err != nil {
299-
w.logger.Error(fmt.Errorf("failed to create stream for node %q: %w", nodeID, err))
300-
return
301-
}
302-
w.nodeStreams.Store(nodeID, stream)
294+
stream, err := w.node.getOrCreateWorkerAckStream(ctx, nodeID)
295+
if err != nil {
296+
w.logger.Error(fmt.Errorf("failed to get ack stream for node %q: %w", nodeID, err))
297+
return
303298
}
299+
304300
var msg string
305301
if ackerr != nil {
306302
msg = ackerr.Error()
307303
}
308304
ack := &ack{EventID: eventID, Error: msg}
309-
if _, err := stream.(*streaming.Stream).Add(ctx, evAck, marshalEnvelope(w.ID, marshalAck(ack))); err != nil {
305+
if _, err := stream.Add(ctx, evAck, marshalEnvelope(w.ID, marshalAck(ack))); err != nil {
310306
w.logger.Error(fmt.Errorf("failed to ack event %q from node %q: %w", eventID, nodeID, err))
311307
}
312308
}

0 commit comments

Comments
 (0)