Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ func (c *dbosContext) Value(key any) any {
return c.ctx.Value(key)
}


// WithValue returns a copy of the DBOS context with the given key-value pair.
// This is similar to context.WithValue but maintains DBOS context capabilities.
// No-op if the provided context is not a concrete dbos.dbosContext.
Expand Down Expand Up @@ -300,7 +299,7 @@ func (c *dbosContext) ListRegisteredWorkflows(_ DBOSContext, opts ...ListRegiste

// Get all registered workflows and apply filters
var filteredWorkflows []WorkflowRegistryEntry
c.workflowRegistry.Range(func(key, value interface{}) bool {
c.workflowRegistry.Range(func(key, value any) bool {
workflow := value.(WorkflowRegistryEntry)

// Filter by scheduled only
Expand Down
12 changes: 9 additions & 3 deletions dbos/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import (
"reflect"
)

const (
// nilMarker is a special marker string used to represent nil values in the database.
nilMarker = "__DBOS_NIL"
)

type serializer[T any] interface {
Encode(data T) (*string, error)
Decode(data *string) (T, error)
Expand All @@ -21,8 +26,9 @@ func newJSONSerializer[T any]() serializer[T] {
func (j *jsonSerializer[T]) Encode(data T) (*string, error) {
// Check if the value is nil (for pointer types, slice, map, etc.)
if isNilValue(data) {
// For nil values, return nil pointer
return nil, nil
// For nil values, return the special marker so it can be stored in the database
marker := string(nilMarker)
return &marker, nil
}

// Check if the value is a zero value (but not nil)
Expand All @@ -36,7 +42,7 @@ func (j *jsonSerializer[T]) Encode(data T) (*string, error) {

func (j *jsonSerializer[T]) Decode(data *string) (T, error) {
// If data is a nil pointer, return nil (for pointer types) or zero value (for non-pointer types)
if data == nil {
if data == nil || *data == nilMarker {
return getNilOrZeroValue[T](), nil
}

Expand Down
39 changes: 37 additions & 2 deletions dbos/serialization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -141,8 +142,8 @@ func testAllSerializationPaths[T any](
wf := wfs[0]
if isNilExpected {
// Should be an empty string
assert.Nil(t, wf.Input, "Workflow input should be nil")
assert.Nil(t, wf.Output, "Workflow output should be nil")
require.Nil(t, wf.Input, "Workflow input should be nil")
require.Nil(t, wf.Output, "Workflow output should be nil")
} else {
require.NotNil(t, wf.Input)
require.NotNil(t, wf.Output)
Expand Down Expand Up @@ -175,6 +176,40 @@ func testAllSerializationPaths[T any](
}
}
})

// If nil is expected, verify the nil marker is stored in the database
if isNilExpected {
t.Run("DatabaseNilMarker", func(t *testing.T) {
// Get the database pool to query directly
dbosCtx, ok := executor.(*dbosContext)
require.True(t, ok, "expected dbosContext")
sysDB, ok := dbosCtx.systemDB.(*sysDB)
require.True(t, ok, "expected sysDB")

// Query the database directly to check for the marker
ctx := context.Background()
query := fmt.Sprintf(`SELECT inputs, output FROM %s.workflow_status WHERE workflow_uuid = $1`, pgx.Identifier{sysDB.schema}.Sanitize())

var inputString, outputString *string
err := sysDB.pool.QueryRow(ctx, query, workflowID).Scan(&inputString, &outputString)
require.NoError(t, err, "failed to query workflow status")

// Both input and output should be the nil marker
require.NotNil(t, inputString, "input should not be NULL in database")
assert.Equal(t, nilMarker, *inputString, "input should be the nil marker")

require.NotNil(t, outputString, "output should not be NULL in database")
assert.Equal(t, nilMarker, *outputString, "output should be the nil marker")

// Also check the step output in operation_outputs
stepQuery := fmt.Sprintf(`SELECT output FROM %s.operation_outputs WHERE workflow_uuid = $1 ORDER BY function_id LIMIT 1`, pgx.Identifier{sysDB.schema}.Sanitize())
var stepOutputString *string
err = sysDB.pool.QueryRow(ctx, stepQuery, workflowID).Scan(&stepOutputString)
require.NoError(t, err, "failed to query step output")
require.NotNil(t, stepOutputString, "step output should not be NULL in database")
assert.Equal(t, nilMarker, *stepOutputString, "step output should be the nil marker")
})
}
}

// Helper function to test Send/Recv communication
Expand Down
33 changes: 18 additions & 15 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -2077,28 +2077,30 @@ func (c *dbosContext) ListWorkflows(_ DBOSContext, opts ...ListWorkflowsOption)
if !ok {
return nil, fmt.Errorf("workflow input must be encoded string, got %T", workflows[i].Input)
}
if encodedInput == nil {
continue
}
decodedBytes, err := base64.StdEncoding.DecodeString(*encodedInput)
if err != nil {
return nil, fmt.Errorf("failed to decode base64 workflow input for %s: %w", workflows[i].ID, err)
if encodedInput == nil || *encodedInput == nilMarker {
workflows[i].Input = nil
} else {
decodedBytes, err := base64.StdEncoding.DecodeString(*encodedInput)
if err != nil {
return nil, fmt.Errorf("failed to decode base64 workflow input for %s: %w", workflows[i].ID, err)
}
workflows[i].Input = string(decodedBytes)
}
workflows[i].Input = string(decodedBytes)
}
if params.loadOutput && workflows[i].Output != nil {
encodedOutput, ok := workflows[i].Output.(*string)
if !ok {
return nil, fmt.Errorf("workflow output must be encoded *string, got %T", workflows[i].Output)
}
if encodedOutput == nil {
continue
}
decodedBytes, err := base64.StdEncoding.DecodeString(*encodedOutput)
if err != nil {
return nil, fmt.Errorf("failed to decode base64 workflow output for %s: %w", workflows[i].ID, err)
if encodedOutput == nil || *encodedOutput == nilMarker {
workflows[i].Output = nil
} else {
decodedBytes, err := base64.StdEncoding.DecodeString(*encodedOutput)
if err != nil {
return nil, fmt.Errorf("failed to decode base64 workflow output for %s: %w", workflows[i].ID, err)
}
workflows[i].Output = string(decodedBytes)
}
workflows[i].Output = string(decodedBytes)
}
}
}
Expand Down Expand Up @@ -2205,7 +2207,8 @@ func (c *dbosContext) GetWorkflowSteps(_ DBOSContext, workflowID string) ([]Step
if loadOutput {
for i := range steps {
encodedOutput := steps[i].Output
if encodedOutput == nil {
if encodedOutput == nil || *encodedOutput == nilMarker {
stepInfos[i].Output = nil
continue
}
decodedBytes, err := base64.StdEncoding.DecodeString(*encodedOutput)
Expand Down
Loading