Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ type systemDatabase interface {
}

type sysDB struct {
pool *pgxpool.Pool
notificationLoopDone chan struct{}
notificationsMap *sync.Map
logger *slog.Logger
schema string
launched bool
pool *pgxpool.Pool
notificationLoopDone chan struct{}
workflowNotificationsMap *sync.Map
workflowEventsMap *sync.Map
logger *slog.Logger
schema string
launched bool
}

/*******************************/
Expand Down Expand Up @@ -330,14 +331,16 @@ func newSystemDatabase(ctx context.Context, inputs newSystemDatabaseInput) (syst
}

// Create a map of notification payloads to channels
notificationsMap := &sync.Map{}
workflowNotificationsMap := &sync.Map{}
workflowEventsMap := &sync.Map{}

return &sysDB{
pool: pool,
notificationsMap: notificationsMap,
notificationLoopDone: make(chan struct{}),
logger: logger.With("service", "system_database"),
schema: databaseSchema,
pool: pool,
workflowNotificationsMap: workflowNotificationsMap,
workflowEventsMap: workflowEventsMap,
notificationLoopDone: make(chan struct{}),
logger: logger.With("service", "system_database"),
schema: databaseSchema,
}, nil
}

Expand Down Expand Up @@ -374,7 +377,8 @@ func (s *sysDB) shutdown(ctx context.Context, timeout time.Duration) {
}
}

s.notificationsMap.Clear()
s.workflowNotificationsMap.Clear()
s.workflowEventsMap.Clear()

s.launched = false
}
Expand Down Expand Up @@ -1672,9 +1676,13 @@ func (s *sysDB) notificationListenerLoop(ctx context.Context) {
retryAttempt--
}

// Handle notifications
if n.Channel == _DBOS_NOTIFICATIONS_CHANNEL || n.Channel == _DBOS_WORKFLOW_EVENTS_CHANNEL {
if cond, ok := s.notificationsMap.Load(n.Payload); ok {
switch n.Channel {
case _DBOS_NOTIFICATIONS_CHANNEL:
if cond, ok := s.workflowNotificationsMap.Load(n.Payload); ok {
cond.(*sync.Cond).Broadcast()
}
case _DBOS_WORKFLOW_EVENTS_CHANNEL:
if cond, ok := s.workflowEventsMap.Load(n.Payload); ok {
cond.(*sync.Cond).Broadcast()
}
}
Expand Down Expand Up @@ -1820,15 +1828,15 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {
// First check if there's already a receiver for this workflow/topic to avoid unnecessary database load
payload := fmt.Sprintf("%s::%s", destinationID, topic)
cond := sync.NewCond(&sync.Mutex{})
_, loaded := s.notificationsMap.LoadOrStore(payload, cond)
_, loaded := s.workflowNotificationsMap.LoadOrStore(payload, cond)
if loaded {
s.logger.Error("Receive already called for workflow", "destination_id", destinationID)
return nil, newWorkflowConflictIDError(destinationID)
}
defer func() {
// Clean up the condition variable after we're done and broadcast to wake up any waiting goroutines
cond.Broadcast()
s.notificationsMap.Delete(payload)
s.workflowNotificationsMap.Delete(payload)
}()

// Now check if there is already a message available in the database.
Expand Down Expand Up @@ -2048,7 +2056,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
// Create notification payload and condition variable
payload := fmt.Sprintf("%s::%s", input.TargetWorkflowID, input.Key)
cond := sync.NewCond(&sync.Mutex{})
existingCond, loaded := s.notificationsMap.LoadOrStore(payload, cond)
existingCond, loaded := s.workflowEventsMap.LoadOrStore(payload, cond)
if loaded {
// Reuse the existing condition variable
cond = existingCond.(*sync.Cond)
Expand All @@ -2059,7 +2067,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
cond.Broadcast()
// Clean up the condition variable after we're done, if we created it
if !loaded {
s.notificationsMap.Delete(payload)
s.workflowEventsMap.Delete(payload)
}
}()

Expand Down
Loading