diff --git a/dbos/system_database.go b/dbos/system_database.go index 88b78cc..6cf659f 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -3,6 +3,7 @@ package dbos import ( "context" _ "embed" + "encoding/json" "errors" "fmt" "io" @@ -485,6 +486,14 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt var result insertWorkflowResult var timeoutMSResult *int64 var workflowDeadlineEpochMS *int64 + + // Marshal authenticated roles (slice of strings) to JSON for TEXT column + authenticatedRoles, err := json.Marshal(input.status.AuthenticatedRoles) + + if err != nil { + return nil, fmt.Errorf("failed to marshal the authenticated roles: %w", err) + } + err = input.tx.QueryRow(ctx, query, input.status.ID, input.status.Status, @@ -492,7 +501,7 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt input.status.QueueName, input.status.AuthenticatedUser, input.status.AssumedRole, - input.status.AuthenticatedRoles, + authenticatedRoles, input.status.ExecutorID, applicationVersion, input.status.ApplicationID, @@ -705,11 +714,12 @@ func (s *sysDB) listWorkflows(ctx context.Context, input listWorkflowsDBInput) ( var deduplicationID *string var applicationVersion *string var executorID *string + var authenticatedRoles *string // Build scan arguments dynamically based on loaded columns scanArgs := []any{ &wf.ID, &wf.Status, &wf.Name, &wf.AuthenticatedUser, &wf.AssumedRole, - &wf.AuthenticatedRoles, &executorID, &createdAtMs, + &authenticatedRoles, &executorID, &createdAtMs, &updatedAtMs, &applicationVersion, &wf.ApplicationID, &wf.Attempts, &queueName, &timeoutMs, &deadlineMs, &startedAtMs, &deduplicationID, &wf.Priority, @@ -727,6 +737,12 @@ func (s *sysDB) listWorkflows(ctx context.Context, input listWorkflowsDBInput) ( return nil, fmt.Errorf("failed to scan workflow row: %w", err) } + if authenticatedRoles != nil && *authenticatedRoles != "" { + if err := json.Unmarshal([]byte(*authenticatedRoles), &wf.AuthenticatedRoles); err != nil { + return nil, fmt.Errorf("failed to unmarshal authenticated_roles: %w", err) + } + } + if queueName != nil && len(*queueName) > 0 { wf.QueueName = *queueName } @@ -1088,13 +1104,20 @@ func (s *sysDB) forkWorkflow(ctx context.Context, input forkWorkflowDBInput) (st return "", fmt.Errorf("failed to serialize input: %w", err) } + // Marshal authenticated roles (slice of strings) to JSON for TEXT column + authenticatedRoles, err := json.Marshal(originalWorkflow.AuthenticatedRoles) + + if err != nil { + return "", fmt.Errorf("failed to marshal the authenticated roles: %w", err) + } + _, err = tx.Exec(ctx, insertQuery, forkedWorkflowID, WorkflowStatusEnqueued, originalWorkflow.Name, originalWorkflow.AuthenticatedUser, originalWorkflow.AssumedRole, - originalWorkflow.AuthenticatedRoles, + authenticatedRoles, &appVersion, originalWorkflow.ApplicationID, _DBOS_INTERNAL_QUEUE_NAME, diff --git a/dbos/workflow.go b/dbos/workflow.go index 475d66f..a7d6013 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -488,13 +488,16 @@ type Workflow[P any, R any] func(ctx DBOSContext, input P) (R, error) type WorkflowFunc func(ctx DBOSContext, input any) (any, error) type workflowOptions struct { - workflowName string - workflowID string - queueName string - applicationVersion string - maxRetries int - deduplicationID string - priority uint + workflowName string + workflowID string + queueName string + applicationVersion string + maxRetries int + deduplicationID string + priority uint + authenticated_user string + assumed_role string + authenticated_roles []string } // WorkflowOption is a functional option for configuring workflow execution parameters. @@ -544,6 +547,27 @@ func withWorkflowName(name string) WorkflowOption { } } +// Sets the authenticated user for the workflow +func WithAuthenticatedUser(user string) WorkflowOption { + return func(p *workflowOptions) { + p.authenticated_user = user + } +} + +// Sets the assumed role for the workflow +func WithAssumedRole(role string) WorkflowOption { + return func(p *workflowOptions) { + p.assumed_role = role + } +} + +// Sets the authenticated role for the workflow +func WithAuthenticatedRoles(roles []string) WorkflowOption { + return func(p *workflowOptions) { + p.authenticated_roles = roles + } +} + // RunWorkflow executes a workflow function with type safety and durability guarantees. // The workflow can be executed immediately or enqueued for later execution based on options. // Returns a typed handle that can be used to wait for completion and retrieve results. @@ -730,6 +754,9 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt QueueName: params.queueName, DeduplicationID: params.deduplicationID, Priority: int(params.priority), + AuthenticatedUser: params.authenticated_user, + AssumedRole: params.assumed_role, + AuthenticatedRoles: params.authenticated_roles, } var earlyReturnPollingHandle *workflowPollingHandle[any] diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 9d77014..dcc0d7e 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -4131,3 +4131,32 @@ func TestSpecialSteps(t *testing.T) { require.Equal(t, "success", result, "workflow should return success") }) } +func TestWorkflowIdentity(t *testing.T) { + dbosCtx := setupDBOS(t, true, true) + RegisterWorkflow(dbosCtx, simpleWorkflow) + handle, err := RunWorkflow( + dbosCtx, + simpleWorkflow, + "test", + WithWorkflowID("my-workflow-id"), + WithAuthenticatedUser("user123"), + WithAssumedRole("admin"), + WithAuthenticatedRoles([]string{"reader", "writer"})) + require.NoError(t, err, "failed to start workflow") + + // Retrieve the workflow's status. + status, err := handle.GetStatus() + require.NoError(t, err) + + t.Run("CheckAuthenticatedUser", func(t *testing.T) { + assert.Equal(t, "user123", status.AuthenticatedUser) + }) + + t.Run("CheckAssumedRole", func(t *testing.T) { + assert.Equal(t, "admin", status.AssumedRole) + }) + + t.Run("CheckAuthenticatedRoles", func(t *testing.T) { + assert.Equal(t, []string{"reader", "writer"}, status.AuthenticatedRoles) + }) +}