Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dbos
import (
"context"
_ "embed"
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -485,14 +486,22 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt
var result insertWorkflowResult
var timeoutMSResult *int64
var workflowDeadlineEpochMS *int64

// Marshal authenticated roles (slice of strings) to JSON for TEXT column
authenticatedRoles, err := json.Marshal(input.status.AuthenticatedRoles)

if err != nil {
return nil, fmt.Errorf("failed to marshal the authenticated roles: %w", err)
}

err = input.tx.QueryRow(ctx, query,
input.status.ID,
input.status.Status,
input.status.Name,
input.status.QueueName,
input.status.AuthenticatedUser,
input.status.AssumedRole,
input.status.AuthenticatedRoles,
authenticatedRoles,
input.status.ExecutorID,
applicationVersion,
input.status.ApplicationID,
Expand Down Expand Up @@ -705,11 +714,12 @@ func (s *sysDB) listWorkflows(ctx context.Context, input listWorkflowsDBInput) (
var deduplicationID *string
var applicationVersion *string
var executorID *string
var authenticatedRoles *string

// Build scan arguments dynamically based on loaded columns
scanArgs := []any{
&wf.ID, &wf.Status, &wf.Name, &wf.AuthenticatedUser, &wf.AssumedRole,
&wf.AuthenticatedRoles, &executorID, &createdAtMs,
&authenticatedRoles, &executorID, &createdAtMs,
&updatedAtMs, &applicationVersion, &wf.ApplicationID,
&wf.Attempts, &queueName, &timeoutMs,
&deadlineMs, &startedAtMs, &deduplicationID, &wf.Priority,
Expand All @@ -727,6 +737,12 @@ func (s *sysDB) listWorkflows(ctx context.Context, input listWorkflowsDBInput) (
return nil, fmt.Errorf("failed to scan workflow row: %w", err)
}

if authenticatedRoles != nil && *authenticatedRoles != "" {
if err := json.Unmarshal([]byte(*authenticatedRoles), &wf.AuthenticatedRoles); err != nil {
return nil, fmt.Errorf("failed to unmarshal authenticated_roles: %w", err)
}
}

if queueName != nil && len(*queueName) > 0 {
wf.QueueName = *queueName
}
Expand Down Expand Up @@ -1088,13 +1104,20 @@ func (s *sysDB) forkWorkflow(ctx context.Context, input forkWorkflowDBInput) (st
return "", fmt.Errorf("failed to serialize input: %w", err)
}

// Marshal authenticated roles (slice of strings) to JSON for TEXT column
authenticatedRoles, err := json.Marshal(originalWorkflow.AuthenticatedRoles)

if err != nil {
return "", fmt.Errorf("failed to marshal the authenticated roles: %w", err)
}

_, err = tx.Exec(ctx, insertQuery,
forkedWorkflowID,
WorkflowStatusEnqueued,
originalWorkflow.Name,
originalWorkflow.AuthenticatedUser,
originalWorkflow.AssumedRole,
originalWorkflow.AuthenticatedRoles,
authenticatedRoles,
&appVersion,
originalWorkflow.ApplicationID,
_DBOS_INTERNAL_QUEUE_NAME,
Expand Down
41 changes: 34 additions & 7 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,13 +488,16 @@ type Workflow[P any, R any] func(ctx DBOSContext, input P) (R, error)
type WorkflowFunc func(ctx DBOSContext, input any) (any, error)

type workflowOptions struct {
workflowName string
workflowID string
queueName string
applicationVersion string
maxRetries int
deduplicationID string
priority uint
workflowName string
workflowID string
queueName string
applicationVersion string
maxRetries int
deduplicationID string
priority uint
authenticated_user string
assumed_role string
authenticated_roles []string
}

// WorkflowOption is a functional option for configuring workflow execution parameters.
Expand Down Expand Up @@ -544,6 +547,27 @@ func withWorkflowName(name string) WorkflowOption {
}
}

// Sets the authenticated user for the workflow
func WithAuthenticatedUser(user string) WorkflowOption {
return func(p *workflowOptions) {
p.authenticated_user = user
}
}

// Sets the assumed role for the workflow
func WithAssumedRole(role string) WorkflowOption {
return func(p *workflowOptions) {
p.assumed_role = role
}
}

// Sets the authenticated role for the workflow
func WithAuthenticatedRoles(roles []string) WorkflowOption {
return func(p *workflowOptions) {
p.authenticated_roles = roles
}
}

// RunWorkflow executes a workflow function with type safety and durability guarantees.
// The workflow can be executed immediately or enqueued for later execution based on options.
// Returns a typed handle that can be used to wait for completion and retrieve results.
Expand Down Expand Up @@ -730,6 +754,9 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt
QueueName: params.queueName,
DeduplicationID: params.deduplicationID,
Priority: int(params.priority),
AuthenticatedUser: params.authenticated_user,
AssumedRole: params.assumed_role,
AuthenticatedRoles: params.authenticated_roles,
}

var earlyReturnPollingHandle *workflowPollingHandle[any]
Expand Down
29 changes: 29 additions & 0 deletions dbos/workflows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4131,3 +4131,32 @@ func TestSpecialSteps(t *testing.T) {
require.Equal(t, "success", result, "workflow should return success")
})
}
func TestWorkflowIdentity(t *testing.T) {
dbosCtx := setupDBOS(t, true, true)
RegisterWorkflow(dbosCtx, simpleWorkflow)
handle, err := RunWorkflow(
dbosCtx,
simpleWorkflow,
"test",
WithWorkflowID("my-workflow-id"),
WithAuthenticatedUser("user123"),
WithAssumedRole("admin"),
WithAuthenticatedRoles([]string{"reader", "writer"}))
require.NoError(t, err, "failed to start workflow")

// Retrieve the workflow's status.
status, err := handle.GetStatus()
require.NoError(t, err)

t.Run("CheckAuthenticatedUser", func(t *testing.T) {
assert.Equal(t, "user123", status.AuthenticatedUser)
})

t.Run("CheckAssumedRole", func(t *testing.T) {
assert.Equal(t, "admin", status.AssumedRole)
})

t.Run("CheckAuthenticatedRoles", func(t *testing.T) {
assert.Equal(t, []string{"reader", "writer"}, status.AuthenticatedRoles)
})
}
Loading