diff --git a/dbos/dbos.go b/dbos/dbos.go index 06daa504..890eacda 100644 --- a/dbos/dbos.go +++ b/dbos/dbos.go @@ -3,14 +3,13 @@ package dbos import ( "context" "crypto/sha256" + "encoding/gob" "encoding/hex" "fmt" + "io" "log/slog" "net/url" "os" - "reflect" - "runtime" - "sort" "time" "github.com/robfig/cron/v3" @@ -23,36 +22,6 @@ var ( _DEFAULT_ADMIN_SERVER_PORT = 3001 ) -func computeApplicationVersion() string { - if len(registry) == 0 { - fmt.Println("DBOS: No registered workflows found, cannot compute application version") - return "" - } - - // Collect all function names and sort them for consistent hashing - var functionNames []string - for fqn := range registry { - functionNames = append(functionNames, fqn) - } - sort.Strings(functionNames) - - hasher := sha256.New() - - for _, fqn := range functionNames { - workflowEntry := registry[fqn] - - // Try to get function source location and other identifying info - if pc := runtime.FuncForPC(reflect.ValueOf(workflowEntry.wrappedFunction).Pointer()); pc != nil { - // Get the function's entry point - this reflects the actual compiled code - entry := pc.Entry() - fmt.Fprintf(hasher, "%x", entry) - } - } - - return hex.EncodeToString(hasher.Sum(nil)) - -} - var workflowScheduler *cron.Cron // Global because accessed during workflow registration before the dbos singleton is initialized var logger *slog.Logger // Global because accessed everywhere inside the library @@ -141,6 +110,10 @@ func Initialize(inputConfig Config) error { // Set global logger logger = config.Logger + // Register types we serialize with gob + var t time.Time + gob.Register(t) + // Initialize global variables with environment variables, providing defaults if not set _APP_VERSION = os.Getenv("DBOS__APPVERSION") if _APP_VERSION == "" { @@ -273,3 +246,32 @@ func Shutdown() { } dbos = nil } + +func GetBinaryHash() (string, error) { + execPath, err := os.Executable() + if err != nil { + return "", err + } + + file, err := os.Open(execPath) + if err != nil { + return "", err + } + defer file.Close() + + hasher := sha256.New() + if _, err := io.Copy(hasher, file); err != nil { + return "", err + } + + return hex.EncodeToString(hasher.Sum(nil)), nil +} + +func computeApplicationVersion() string { + hash, err := GetBinaryHash() + if err != nil { + fmt.Printf("DBOS: Failed to compute binary hash: %v\n", err) + return "" + } + return hash +} diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go index 9a639415..192b4c8c 100644 --- a/dbos/dbos_test.go +++ b/dbos/dbos_test.go @@ -1,9 +1,6 @@ package dbos import ( - "context" - "encoding/hex" - "maps" "testing" ) @@ -60,28 +57,3 @@ func TestConfigValidationErrorTypes(t *testing.T) { } }) } -func TestAppVersion(t *testing.T) { - if _, err := hex.DecodeString(_APP_VERSION); err != nil { - t.Fatalf("APP_VERSION is not a valid hex string: %v", err) - } - - // Save the original registry content - originalRegistry := make(map[string]workflowRegistryEntry) - maps.Copy(originalRegistry, registry) - - // Restore the registry after the test - defer func() { - registry = originalRegistry - }() - - // Replace the registry and verify the hash is different - registry = make(map[string]workflowRegistryEntry) - - WithWorkflow(func(ctx context.Context, input string) (string, error) { - return "new-registry-workflow-" + input, nil - }) - hash2 := computeApplicationVersion() - if _APP_VERSION == hash2 { - t.Fatalf("APP_VERSION hash did not change after replacing registry") - } -} diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index 5040f48e..7e6acc39 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "testing" + "time" ) /** Test serialization and deserialization @@ -13,6 +14,7 @@ import ( [x] Workflow inputs/outputs [x] Step inputs/outputs [x] Direct handlers, polling handler, list workflows results, get step infos +[x] Set/get event with user defined types */ var ( @@ -289,3 +291,180 @@ func TestWorkflowEncoding(t *testing.T) { } }) } + +type UserDefinedEventData struct { + ID int `json:"id"` + Name string `json:"name"` + Details struct { + Description string `json:"description"` + Tags []string `json:"tags"` + } `json:"details"` +} + +func setEventUserDefinedTypeWorkflow(ctx context.Context, input string) (string, error) { + eventData := UserDefinedEventData{ + ID: 42, + Name: "test-event", + Details: struct { + Description string `json:"description"` + Tags []string `json:"tags"` + }{ + Description: "This is a test event with user-defined data", + Tags: []string{"test", "user-defined", "serialization"}, + }, + } + + err := SetEvent(ctx, WorkflowSetEventInput[UserDefinedEventData]{Key: input, Message: eventData}) + if err != nil { + return "", err + } + return "user-defined-event-set", nil +} + +var setEventUserDefinedTypeWf = WithWorkflow(setEventUserDefinedTypeWorkflow) + +func TestSetEventSerialize(t *testing.T) { + setupDBOS(t) + + t.Run("SetEventUserDefinedType", func(t *testing.T) { + // Start a workflow that sets an event with a user-defined type + setHandle, err := setEventUserDefinedTypeWf(context.Background(), "user-defined-key") + if err != nil { + t.Fatalf("failed to start workflow with user-defined event type: %v", err) + } + + // Wait for the workflow to complete + result, err := setHandle.GetResult(context.Background()) + if err != nil { + t.Fatalf("failed to get result from user-defined event workflow: %v", err) + } + if result != "user-defined-event-set" { + t.Fatalf("expected result to be 'user-defined-event-set', got '%s'", result) + } + + // Retrieve the event to verify it was properly serialized and can be deserialized + retrievedEvent, err := GetEvent[UserDefinedEventData](context.Background(), WorkflowGetEventInput{ + TargetWorkflowID: setHandle.GetWorkflowID(), + Key: "user-defined-key", + Timeout: 3 * time.Second, + }) + if err != nil { + t.Fatalf("failed to get user-defined event: %v", err) + } + + // Verify the retrieved data matches what we set + if retrievedEvent.ID != 42 { + t.Fatalf("expected ID to be 42, got %d", retrievedEvent.ID) + } + if retrievedEvent.Name != "test-event" { + t.Fatalf("expected Name to be 'test-event', got '%s'", retrievedEvent.Name) + } + if retrievedEvent.Details.Description != "This is a test event with user-defined data" { + t.Fatalf("expected Description to be 'This is a test event with user-defined data', got '%s'", retrievedEvent.Details.Description) + } + if len(retrievedEvent.Details.Tags) != 3 { + t.Fatalf("expected 3 tags, got %d", len(retrievedEvent.Details.Tags)) + } + expectedTags := []string{"test", "user-defined", "serialization"} + for i, tag := range retrievedEvent.Details.Tags { + if tag != expectedTags[i] { + t.Fatalf("expected tag %d to be '%s', got '%s'", i, expectedTags[i], tag) + } + } + }) +} + + +func sendUserDefinedTypeWorkflow(ctx context.Context, destinationID string) (string, error) { + // Create an instance of our user-defined type inside the workflow + sendData := UserDefinedEventData{ + ID: 42, + Name: "test-send-message", + Details: struct { + Description string `json:"description"` + Tags []string `json:"tags"` + }{ + Description: "This is a test send message with user-defined data", + Tags: []string{"test", "user-defined", "serialization", "send"}, + }, + } + + // Send should automatically register this type with gob + // Note the explicit type parameter since compiler cannot infer UserDefinedEventData from string input + err := Send(ctx, WorkflowSendInput[UserDefinedEventData]{ + DestinationID: destinationID, + Topic: "user-defined-topic", + Message: sendData, + }) + if err != nil { + return "", err + } + return "user-defined-message-sent", nil +} + +func recvUserDefinedTypeWorkflow(ctx context.Context, input string) (UserDefinedEventData, error) { + // Receive the user-defined type message + result, err := Recv[UserDefinedEventData](ctx, WorkflowRecvInput{ + Topic: "user-defined-topic", + Timeout: 3 * time.Second, + }) + return result, err +} + +var sendUserDefinedTypeWf = WithWorkflow(sendUserDefinedTypeWorkflow) +var recvUserDefinedTypeWf = WithWorkflow(recvUserDefinedTypeWorkflow) + +func TestSendSerialize(t *testing.T) { + setupDBOS(t) + + t.Run("SendUserDefinedType", func(t *testing.T) { + // Start a receiver workflow first + recvHandle, err := recvUserDefinedTypeWf(context.Background(), "recv-input") + if err != nil { + t.Fatalf("failed to start receive workflow: %v", err) + } + + // Start a sender workflow that sends a message with a user-defined type + sendHandle, err := sendUserDefinedTypeWf(context.Background(), recvHandle.GetWorkflowID()) + if err != nil { + t.Fatalf("failed to start workflow with user-defined send type: %v", err) + } + + // Wait for the sender workflow to complete + sendResult, err := sendHandle.GetResult(context.Background()) + if err != nil { + t.Fatalf("failed to get result from user-defined send workflow: %v", err) + } + if sendResult != "user-defined-message-sent" { + t.Fatalf("expected result to be 'user-defined-message-sent', got '%s'", sendResult) + } + + // Wait for the receiver workflow to complete and get the message + receivedData, err := recvHandle.GetResult(context.Background()) + if err != nil { + t.Fatalf("failed to get result from receive workflow: %v", err) + } + + // Verify the received data matches what we sent + if receivedData.ID != 42 { + t.Fatalf("expected ID to be 42, got %d", receivedData.ID) + } + if receivedData.Name != "test-send-message" { + t.Fatalf("expected Name to be 'test-send-message', got '%s'", receivedData.Name) + } + if receivedData.Details.Description != "This is a test send message with user-defined data" { + t.Fatalf("expected Description to be 'This is a test send message with user-defined data', got '%s'", receivedData.Details.Description) + } + + // Verify tags + expectedTags := []string{"test", "user-defined", "serialization", "send"} + if len(receivedData.Details.Tags) != len(expectedTags) { + t.Fatalf("expected %d tags, got %d", len(expectedTags), len(receivedData.Details.Tags)) + } + for i, tag := range receivedData.Details.Tags { + if tag != expectedTags[i] { + t.Fatalf("expected tag %d to be '%s', got '%s'", i, expectedTags[i], tag) + } + } + }) +} diff --git a/dbos/system_database.go b/dbos/system_database.go index 3016e044..3bc6bc0a 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -38,9 +38,9 @@ type SystemDatabase interface { CheckOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error) RecordChildGetResult(ctx context.Context, input recordChildGetResultDBInput) error GetWorkflowSteps(ctx context.Context, workflowID string) ([]StepInfo, error) - Send(ctx context.Context, input WorkflowSendInput) error + Send(ctx context.Context, input workflowSendInputInternal) error Recv(ctx context.Context, input WorkflowRecvInput) (any, error) - SetEvent(ctx context.Context, input WorkflowSetEventInput) error + SetEvent(ctx context.Context, input workflowSetEventInputInternal) error GetEvent(ctx context.Context, input WorkflowGetEventInput) (any, error) Sleep(ctx context.Context, duration time.Duration) (time.Duration, error) } @@ -1102,82 +1102,92 @@ func (s *systemDatabase) notificationListenerLoop(ctx context.Context) { const _DBOS_NULL_TOPIC = "__null__topic__" +type workflowSendInputInternal struct { + destinationID string + message any + topic string +} + // Send is a special type of step that sends a message to another workflow. -// Three differences with a normal steps: durability and the function run in the same transaction, and we forbid nested step execution -func (s *systemDatabase) Send(ctx context.Context, input WorkflowSendInput) error { +// Can be called both within a workflow (as a step) or outside a workflow (directly). +// When called within a workflow: durability and the function run in the same transaction, and we forbid nested step execution +func (s *systemDatabase) Send(ctx context.Context, input workflowSendInputInternal) error { functionName := "DBOS.send" - // Get workflow state from context + // Get workflow state from context (optional for Send as we can send from outside a workflow) wfState, ok := ctx.Value(workflowStateKey).(*workflowState) - if !ok || wfState == nil { - return newStepExecutionError("", functionName, "workflow state not found in context: are you running this step within a workflow?") - } + var stepID int + var isInWorkflow bool - if wfState.isWithinStep { - return newStepExecutionError(wfState.workflowID, functionName, "cannot call Send within a step") + if ok && wfState != nil { + isInWorkflow = true + if wfState.isWithinStep { + return newStepExecutionError(wfState.workflowID, functionName, "cannot call Send within a step") + } + stepID = wfState.NextStepID() } - stepID := wfState.NextStepID() - tx, err := s.pool.Begin(ctx) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } defer tx.Rollback(ctx) - // Check if operation was already executed and do nothing if so - checkInput := checkOperationExecutionDBInput{ - workflowID: wfState.workflowID, - stepID: stepID, - stepName: functionName, - tx: tx, - } - recordedResult, err := s.CheckOperationExecution(ctx, checkInput) - if err != nil { - return err - } - if recordedResult != nil { - // when hitting this case, recordedResult will be &{ } - return nil + // Check if operation was already executed and do nothing if so (only if in workflow) + if isInWorkflow { + checkInput := checkOperationExecutionDBInput{ + workflowID: wfState.workflowID, + stepID: stepID, + stepName: functionName, + tx: tx, + } + recordedResult, err := s.CheckOperationExecution(ctx, checkInput) + if err != nil { + return err + } + if recordedResult != nil { + // when hitting this case, recordedResult will be &{ } + return nil + } } // Set default topic if not provided topic := _DBOS_NULL_TOPIC - if len(input.Topic) > 0 { - topic = input.Topic + if len(input.topic) > 0 { + topic = input.topic } // Serialize the message. It must have been registered with encoding/gob by the user if not a basic type. - messageString, err := serialize(input.Message) + messageString, err := serialize(input.message) if err != nil { return fmt.Errorf("failed to serialize message: %w", err) } - insertQuery := `INSERT INTO dbos.notifications (destination_uuid, topic, message) - VALUES ($1, $2, $3)` - - _, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, messageString) + insertQuery := `INSERT INTO dbos.notifications (destination_uuid, topic, message) VALUES ($1, $2, $3)` + _, err = tx.Exec(ctx, insertQuery, input.destinationID, topic, messageString) if err != nil { // Check for foreign key violation (destination workflow doesn't exist) if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23503" { - return newNonExistentWorkflowError(input.DestinationID) + return newNonExistentWorkflowError(input.destinationID) } return fmt.Errorf("failed to insert notification: %w", err) } - // Record the operation result - recordInput := recordOperationResultDBInput{ - workflowID: wfState.workflowID, - stepID: stepID, - stepName: functionName, - output: nil, - err: nil, - tx: tx, - } + // Record the operation result if this is called within a workflow + if isInWorkflow { + recordInput := recordOperationResultDBInput{ + workflowID: wfState.workflowID, + stepID: stepID, + stepName: functionName, + output: nil, + err: nil, + tx: tx, + } - err = s.RecordOperationResult(ctx, recordInput) - if err != nil { - return fmt.Errorf("failed to record operation result: %w", err) + err = s.RecordOperationResult(ctx, recordInput) + if err != nil { + return fmt.Errorf("failed to record operation result: %w", err) + } } // Commit transaction @@ -1334,7 +1344,12 @@ func (s *systemDatabase) Recv(ctx context.Context, input WorkflowRecvInput) (any return message, nil } -func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInput) error { +type workflowSetEventInputInternal struct { + key string + message any +} + +func (s *systemDatabase) SetEvent(ctx context.Context, input workflowSetEventInputInternal) error { functionName := "DBOS.setEvent" // Get workflow state from context @@ -1372,7 +1387,7 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInp } // Serialize the message. It must have been registered with encoding/gob by the user if not a basic type. - messageString, err := serialize(input.Message) + messageString, err := serialize(input.message) if err != nil { return fmt.Errorf("failed to serialize message: %w", err) } @@ -1383,7 +1398,7 @@ func (s *systemDatabase) SetEvent(ctx context.Context, input WorkflowSetEventInp ON CONFLICT (workflow_uuid, key) DO UPDATE SET value = EXCLUDED.value` - _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.Key, messageString) + _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.key, messageString) if err != nil { return fmt.Errorf("failed to insert/update workflow event: %w", err) } diff --git a/dbos/workflow.go b/dbos/workflow.go index 308a5ddf..226fa71d 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -712,14 +712,22 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op /******* WORKFLOW COMMUNICATIONS ********/ /****************************************/ -type WorkflowSendInput struct { +type WorkflowSendInput[R any] struct { DestinationID string - Message any + Message R Topic string } -func Send(ctx context.Context, input WorkflowSendInput) error { - return dbos.systemDB.Send(ctx, input) +// Send sends a message to another workflow. +// Send automatically registers the type of R for gob encoding +func Send[R any](ctx context.Context, input WorkflowSendInput[R]) error { + var typedMessage R + gob.Register(typedMessage) + return dbos.systemDB.Send(ctx, workflowSendInputInternal{ + destinationID: input.DestinationID, + message: input.Message, + topic: input.Topic, + }) } type WorkflowRecvInput struct { @@ -744,13 +752,21 @@ func Recv[R any](ctx context.Context, input WorkflowRecvInput) (R, error) { return typedMessage, nil } -type WorkflowSetEventInput struct { +type WorkflowSetEventInput[R any] struct { Key string - Message any + Message R } -func SetEvent(ctx context.Context, input WorkflowSetEventInput) error { - return dbos.systemDB.SetEvent(ctx, input) +// Sets an event from a workflow. +// The event is a key value pair +// SetEvent automatically registers the type of R for gob encoding +func SetEvent[R any](ctx context.Context, input WorkflowSetEventInput[R]) error { + var typedMessage R + gob.Register(typedMessage) + return dbos.systemDB.SetEvent(ctx, workflowSetEventInputInternal{ + key: input.Key, + message: input.Message, + }) } type WorkflowGetEventInput struct { diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 54ff6fab..e09b77ad 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -973,6 +973,7 @@ var ( recvIdempotencyWf = WithWorkflow(receiveIdempotencyWorkflow) receiveIdempotencyStartEvent = NewEvent() receiveIdempotencyStopEvent = NewEvent() + sendWithinStepWf = WithWorkflow(workflowThatCallsSendInStep) numConcurrentRecvWfs = 5 concurrentRecvReadyEvents = make([]*Event, numConcurrentRecvWfs) concurrentRecvStartEvent = NewEvent() @@ -984,15 +985,15 @@ type sendWorkflowInput struct { } func sendWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { - err := Send(ctx, WorkflowSendInput{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message1"}) + err := Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message1"}) if err != nil { return "", err } - err = Send(ctx, WorkflowSendInput{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message2"}) + err = Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message2"}) if err != nil { return "", err } - err = Send(ctx, WorkflowSendInput{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message3"}) + err = Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "message3"}) if err != nil { return "", err } @@ -1036,7 +1037,7 @@ func receiveWorkflowCoordinated(ctx context.Context, input struct { func sendStructWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { testStruct := sendRecvType{Value: "test-struct-value"} - err := Send(ctx, WorkflowSendInput{DestinationID: input.DestinationID, Topic: input.Topic, Message: testStruct}) + err := Send(ctx, WorkflowSendInput[sendRecvType]{DestinationID: input.DestinationID, Topic: input.Topic, Message: testStruct}) return "", err } @@ -1045,7 +1046,7 @@ func receiveStructWorkflow(ctx context.Context, topic string) (sendRecvType, err } func sendIdempotencyWorkflow(ctx context.Context, input sendWorkflowInput) (string, error) { - err := Send(ctx, WorkflowSendInput{DestinationID: input.DestinationID, Topic: input.Topic, Message: "m1"}) + err := Send(ctx, WorkflowSendInput[string]{DestinationID: input.DestinationID, Topic: input.Topic, Message: "m1"}) if err != nil { return "", err } @@ -1063,6 +1064,22 @@ func receiveIdempotencyWorkflow(ctx context.Context, topic string) (string, erro return msg, nil } +func stepThatCallsSend(ctx context.Context, input sendWorkflowInput) (string, error) { + err := Send(ctx, WorkflowSendInput[string]{ + DestinationID: input.DestinationID, + Topic: input.Topic, + Message: "message-from-step", + }) + if err != nil { + return "", err + } + return "send-completed", nil +} + +func workflowThatCallsSendInStep(ctx context.Context, input sendWorkflowInput) (string, error) { + return RunAsStep(ctx, stepThatCallsSend, input) +} + type sendRecvType struct { Value string } @@ -1185,13 +1202,13 @@ func TestSendRecv(t *testing.T) { } }) - t.Run("SendRecvMustRunInsideWorkflows", func(t *testing.T) { + t.Run("RecvMustRunInsideWorkflows", func(t *testing.T) { ctx := context.Background() - // Attempt to run Send outside of a workflow context - err := Send(ctx, WorkflowSendInput{DestinationID: "test-id", Topic: "test-topic", Message: "test-message"}) + // Attempt to run Recv outside of a workflow context + _, err := Recv[string](ctx, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second}) if err == nil { - t.Fatal("expected error when running Send outside of workflow context, but got none") + t.Fatal("expected error when running Recv outside of workflow context, but got none") } // Check the error type @@ -1209,26 +1226,35 @@ func TestSendRecv(t *testing.T) { if !strings.Contains(err.Error(), expectedMessagePart) { t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error()) } + }) - // Attempt to run Recv outside of a workflow context - _, err = Recv[string](ctx, WorkflowRecvInput{Topic: "test-topic", Timeout: 1 * time.Second}) - if err == nil { - t.Fatal("expected error when running Recv outside of workflow context, but got none") + t.Run("SendOutsideWorkflow", func(t *testing.T) { + // Start a receive workflow to have a valid destination + receiveHandle, err := receiveWf(context.Background(), "outside-workflow-topic") + if err != nil { + t.Fatalf("failed to start receive workflow: %v", err) } - // Check the error type - dbosErr, ok = err.(*DBOSError) - if !ok { - t.Fatalf("expected error to be of type *DBOSError, got %T", err) + // Send messages from outside a workflow context (should work now) + ctx := context.Background() + for i := range 3 { + err = Send(ctx, WorkflowSendInput[string]{ + DestinationID: receiveHandle.GetWorkflowID(), + Topic: "outside-workflow-topic", + Message: fmt.Sprintf("message%d", i+1), + }) + if err != nil { + t.Fatalf("failed to send message%d from outside workflow: %v", i+1, err) + } } - if dbosErr.Code != StepExecutionError { - t.Fatalf("expected error code to be StepExecutionError, got %v", dbosErr.Code) + // Verify the receive workflow gets all messages + result, err := receiveHandle.GetResult(context.Background()) + if err != nil { + t.Fatalf("failed to get result from receive workflow: %v", err) } - - // Test the specific message from the error - if !strings.Contains(err.Error(), expectedMessagePart) { - t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error()) + if result != "message1-message2-message3" { + t.Fatalf("expected result to be 'message1-message2-message3', got '%s'", result) } }) t.Run("SendRecvIdempotency", func(t *testing.T) { @@ -1292,6 +1318,54 @@ func TestSendRecv(t *testing.T) { } }) + t.Run("SendCannotBeCalledWithinStep", func(t *testing.T) { + // Start a receive workflow to have a valid destination + receiveHandle, err := receiveWf(context.Background(), "send-within-step-topic") + if err != nil { + t.Fatalf("failed to start receive workflow: %v", err) + } + + // Execute the workflow that tries to call Send within a step + handle, err := sendWithinStepWf(context.Background(), sendWorkflowInput{ + DestinationID: receiveHandle.GetWorkflowID(), + Topic: "send-within-step-topic", + }) + if err != nil { + t.Fatalf("failed to start workflow: %v", err) + } + + // Expect the workflow to fail with the specific error + _, err = handle.GetResult(context.Background()) + if err == nil { + t.Fatal("expected error when calling Send within a step, but got none") + } + + // Check the error type + dbosErr, ok := err.(*DBOSError) + if !ok { + t.Fatalf("expected error to be of type *DBOSError, got %T", err) + } + + if dbosErr.Code != StepExecutionError { + t.Fatalf("expected error code to be StepExecutionError, got %v", dbosErr.Code) + } + + // Test the specific message from the error + expectedMessagePart := "cannot call Send within a step" + if !strings.Contains(err.Error(), expectedMessagePart) { + t.Fatalf("expected error message to contain %q, but got %q", expectedMessagePart, err.Error()) + } + + // Wait for the receive workflow to time out + result, err := receiveHandle.GetResult(context.Background()) + if err != nil { + t.Fatalf("failed to get result from receive workflow: %v", err) + } + if result != "--" { + t.Fatalf("expected receive workflow result to be '--' (timeout), got '%s'", result) + } + }) + t.Run("ConcurrentRecv", func(t *testing.T) { // Test concurrent receivers - only 1 should timeout, others should get errors receiveTopic := "concurrent-recv-topic" @@ -1396,7 +1470,7 @@ type setEventWorkflowInput struct { } func setEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - err := SetEvent(ctx, WorkflowSetEventInput{Key: input.Key, Message: input.Message}) + err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: input.Key, Message: input.Message}) if err != nil { return "", err } @@ -1417,7 +1491,7 @@ func getEventWorkflow(ctx context.Context, input setEventWorkflowInput) (string, func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { // Set the first event - err := SetEvent(ctx, WorkflowSetEventInput{Key: "event1", Message: "first-event-message"}) + err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: "event1", Message: "first-event-message"}) if err != nil { return "", err } @@ -1426,7 +1500,7 @@ func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (str setSecondEventSignal.Wait() // Set the second event - err = SetEvent(ctx, WorkflowSetEventInput{Key: "event2", Message: "second-event-message"}) + err = SetEvent(ctx, WorkflowSetEventInput[string]{Key: "event2", Message: "second-event-message"}) if err != nil { return "", err } @@ -1435,7 +1509,7 @@ func setTwoEventsWorkflow(ctx context.Context, input setEventWorkflowInput) (str } func setEventIdempotencyWorkflow(ctx context.Context, input setEventWorkflowInput) (string, error) { - err := SetEvent(ctx, WorkflowSetEventInput{Key: input.Key, Message: input.Message}) + err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: input.Key, Message: input.Message}) if err != nil { return "", err } @@ -1596,7 +1670,7 @@ func TestSetGetEvent(t *testing.T) { ctx := context.Background() // Attempt to run SetEvent outside of a workflow context - err := SetEvent(ctx, WorkflowSetEventInput{Key: "test-key", Message: "test-message"}) + err := SetEvent(ctx, WorkflowSetEventInput[string]{Key: "test-key", Message: "test-message"}) if err == nil { t.Fatal("expected error when running SetEvent outside of workflow context, but got none") }