diff --git a/CHANGELOG.md b/CHANGELOG.md index b14921f56..6ffbe2c83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - chore: update json-as and remove hack [#857](https://github.com/hypermodeinc/modus/pull/857) - chore: rename agent lifecycle methods and APIs [#858](https://github.com/hypermodeinc/modus/pull/858) - feat: enforce WASI reactor mode [#859](https://github.com/hypermodeinc/modus/pull/859) +- feat: return user and chat errors in API response [#863](https://github.com/hypermodeinc/modus/pull/863) ## 2025-05-22 - Go SDK 0.18.0-alpha.3 diff --git a/runtime/graphql/datasource/source.go b/runtime/graphql/datasource/source.go index 024cc8278..9a8242f89 100644 --- a/runtime/graphql/datasource/source.go +++ b/runtime/graphql/datasource/source.go @@ -22,6 +22,7 @@ import ( "github.com/puzpuzpuz/xsync/v4" "github.com/buger/jsonparser" + "github.com/tetratelabs/wazero/sys" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -80,8 +81,18 @@ func (ds *ModusDataSource) callFunction(ctx context.Context, callInfo *callInfo) // Call the function execInfo, err := ds.WasmHost.CallFunction(ctx, fnInfo, callInfo.Parameters) if err != nil { - // The full error message has already been logged. Return a generic error to the caller, which will be included in the response. - return nil, nil, errors.New("error calling function") + exitErr := &sys.ExitError{} + if errors.As(err, &exitErr) { + if exitErr.ExitCode() == 255 { + // Exit code 255 is returned when an AssemblyScript function calls `abort` or throws an unhandled exception. + // Return a generic error to the caller, which will be included in the response. + return nil, nil, errors.New("error calling function") + } + + // clear the exit error so we can show only the logged error in the response + err = nil + } + // Otherwise, continue so we can return the error in the response. } // Store the execution info into the function output map. diff --git a/runtime/models/models.go b/runtime/models/models.go index c7825ec1e..85bb15f80 100644 --- a/runtime/models/models.go +++ b/runtime/models/models.go @@ -107,12 +107,14 @@ func PostToModelEndpoint[TResult any](ctx context.Context, model *manifest.Model } } - return empty, err + if res == nil { + return empty, err + } } + // NOTE: This path occurs whether or not there's an error, as long as there was some response body content. db.WriteInferenceHistory(ctx, model, payload, res.Data, res.StartTime, res.EndTime) - - return res.Data, nil + return res.Data, err } func getModelEndpointUrl(model *manifest.ModelInfo, connection *manifest.HTTPConnectionInfo) (string, error) { diff --git a/runtime/utils/http.go b/runtime/utils/http.go index 91c560a97..0e8095129 100644 --- a/runtime/utils/http.go +++ b/runtime/utils/http.go @@ -46,16 +46,9 @@ func sendHttp(req *http.Request) ([]byte, error) { } if response.StatusCode != http.StatusOK { - if len(body) == 0 { - return nil, &HttpError{ - StatusCode: response.StatusCode, - Message: response.Status, - } - } else { - return nil, &HttpError{ - StatusCode: response.StatusCode, - Message: fmt.Sprintf("%s\n%s", response.Status, body), - } + return body, &HttpError{ + StatusCode: response.StatusCode, + Message: response.Status, } } @@ -111,20 +104,21 @@ func PostHttp[TResult any](ctx context.Context, url string, payload any, beforeS startTime := GetTime() content, err := sendHttp(req) endTime := GetTime() - if err != nil { - return nil, err - } + + // NOTE: Unlike most functions, the result and error are BOTH returned. + // This is because some error messages are returned in the body of the response. var result TResult - switch any(result).(type) { - case []byte: - result = any(content).(TResult) - case string: - result = any(string(content)).(TResult) - default: - err = JsonDeserialize(content, &result) - if err != nil { - return nil, fmt.Errorf("error deserializing response: %w", err) + if content != nil { + switch any(result).(type) { + case []byte: + result = any(content).(TResult) + case string: + result = any(string(content)).(TResult) + default: + if err := JsonDeserialize(content, &result); err != nil { + return nil, fmt.Errorf("error deserializing response: %w", err) + } } } @@ -132,7 +126,7 @@ func PostHttp[TResult any](ctx context.Context, url string, payload any, beforeS Data: result, StartTime: startTime, EndTime: endTime, - }, nil + }, err } func WriteJsonContentHeader(w http.ResponseWriter) { diff --git a/runtime/utils/http_test.go b/runtime/utils/http_test.go index dbd539bfa..92c4171c5 100644 --- a/runtime/utils/http_test.go +++ b/runtime/utils/http_test.go @@ -56,13 +56,16 @@ func Test_SendHttp_ErrorResponse(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } - _, err = sendHttp(req) + res, err := sendHttp(req) if err == nil { t.Error("Expected an error, but got nil") } - expected := "HTTP error: 500 Internal Server Error\nSomething went wrong!\n" - if err.Error() != expected { + if expected := "Something went wrong!\n"; string(res) != expected { + t.Errorf("Unexpected result. Got: %s, want: %s", string(res), expected) + } + + if expected := "HTTP error: 500 Internal Server Error"; err.Error() != expected { t.Errorf("Unexpected error message. Got: %s, want: %s", err.Error(), expected) } } diff --git a/runtime/wasmhost/fncall.go b/runtime/wasmhost/fncall.go index 7721a009a..b21598355 100644 --- a/runtime/wasmhost/fncall.go +++ b/runtime/wasmhost/fncall.go @@ -156,7 +156,7 @@ func (host *wasmHost) CallFunctionInModule(ctx context.Context, mod wasm.Module, Dur("duration_ms", duration). Bool("user_visible", true). Int32("exit_code", exitCode). - Msgf("Function ended prematurely with exit code %d. This may have been intentional, or caused by an exception or panic in your code.", exitCode) + Msgf("Function ended with exit code %d, indicating an error.", exitCode) } } else if errors.Is(err, context.Canceled) { // Cancellation is not an error, but we still want to log it. diff --git a/runtime/wasmhost/hostfns.go b/runtime/wasmhost/hostfns.go index fb54946fb..89bd797c0 100644 --- a/runtime/wasmhost/hostfns.go +++ b/runtime/wasmhost/hostfns.go @@ -278,13 +278,6 @@ func (host *wasmHost) newHostFunction(modName, funcName string, fn any, opts ... // invoke the function out := rvFunc.Call(inputs) - // check for an error - if hasErrorResult && len(out) > 0 { - if err, ok := out[len(out)-1].Interface().(error); ok && err != nil { - return err - } - } - // copy results to the results slice for i := range numResults { if hasErrorResult && i == numResults-1 { @@ -294,6 +287,13 @@ func (host *wasmHost) newHostFunction(modName, funcName string, fn any, opts ... } } + // check for an error + if hasErrorResult && len(out) > 0 { + if err, ok := out[len(out)-1].Interface().(error); ok && err != nil { + return err + } + } + return nil } @@ -309,11 +309,11 @@ func (host *wasmHost) newHostFunction(modName, funcName string, fn any, opts ... } // Call the host function - if ok := callHostFunction(ctx, wrappedFn, msgs); !ok { - return - } + // NOTE: This will log any errors, but there still might be results that need to be returned to the guest even if the function fails + // For example, an HTTP request with a 4xx status code might still return a response body with details about the error. + callHostFunction(ctx, wrappedFn, msgs) - // Encode the results (if there are any) + // Encode the results (if there are any) and write them to the stack if len(results) > 0 { if err := encodeResults(ctx, wa, plan, stack, results); err != nil { logger.Err(ctx, err).Str("host_function", fullName).Any("data", results).Msg("Error encoding results.") @@ -489,7 +489,7 @@ func writeIndirectResults(ctx context.Context, wa langsupport.WasmAdapter, plan return nil } -func callHostFunction(ctx context.Context, fn func() error, msgs hfMessages) bool { +func callHostFunction(ctx context.Context, fn func() error, msgs hfMessages) { if msgs.msgStarting != "" { l := logger.Info(ctx).Bool("user_visible", true) if msgs.msgDetail != "" { @@ -510,7 +510,6 @@ func callHostFunction(ctx context.Context, fn func() error, msgs hfMessages) boo } l.Msg(msgs.msgCancelled) } - return false } else if err != nil { if msgs.msgError != "" { l := logger.Err(ctx, err).Bool("user_visible", true).Dur("duration_ms", duration) @@ -519,15 +518,11 @@ func callHostFunction(ctx context.Context, fn func() error, msgs hfMessages) boo } l.Msg(msgs.msgError) } - return false - } else { - if msgs.msgCompleted != "" { - l := logger.Info(ctx).Bool("user_visible", true).Dur("duration_ms", duration) - if msgs.msgDetail != "" { - l.Str("detail", msgs.msgDetail) - } - l.Msg(msgs.msgCompleted) + } else if msgs.msgCompleted != "" { + l := logger.Info(ctx).Bool("user_visible", true).Dur("duration_ms", duration) + if msgs.msgDetail != "" { + l.Str("detail", msgs.msgDetail) } - return true + l.Msg(msgs.msgCompleted) } } diff --git a/sdk/go/pkg/models/models.go b/sdk/go/pkg/models/models.go index bedbd0884..0ed6b530a 100644 --- a/sdk/go/pkg/models/models.go +++ b/sdk/go/pkg/models/models.go @@ -41,8 +41,9 @@ type modelPtr[TModel any] interface { // Provides a base implementation for all models. type ModelBase[TIn, TOut any] struct { - info *ModelInfo - Debug bool + info *ModelInfo + Debug bool + Validator func(response []byte) error } // Gets the model information. @@ -98,8 +99,15 @@ func (m ModelBase[TIn, TOut]) Invoke(input *TIn) (*TOut, error) { console.Debugf("Received output for model %s: %s", modelName, *sOutputJson) } + output := []byte(*sOutputJson) + if m.Validator != nil { + if err := m.Validator(output); err != nil { + return nil, err + } + } + var result TOut - err = utils.JsonDeserialize([]byte(*sOutputJson), &result) + err = utils.JsonDeserialize(output, &result) if err != nil { return nil, fmt.Errorf("failed to deserialize model output for %s: %w", modelName, err) } diff --git a/sdk/go/pkg/models/openai/chat.go b/sdk/go/pkg/models/openai/chat.go index 6289fa618..1e9a31cb4 100644 --- a/sdk/go/pkg/models/openai/chat.go +++ b/sdk/go/pkg/models/openai/chat.go @@ -13,6 +13,7 @@ import ( "bytes" "encoding/base64" "encoding/json" + "errors" "fmt" "strings" "time" @@ -316,6 +317,47 @@ func (o *ChatModelOutput) UnmarshalJSON(data []byte) error { return nil } +// Validates the response from the chat model output. +func validateChatModelResponse(data []byte) error { + if len(data) == 0 { + return errors.New("no response received from model invocation") + } + if !json.Valid(data) { + return fmt.Errorf("invalid response received from model invocation: %s", string(data)) + } + + result := gjson.GetBytes(data, "error") + if result.Exists() { + var ce ChatModelError + if err := json.Unmarshal([]byte(result.Raw), &ce); err != nil { + return fmt.Errorf("error parsing chat model error response: %w", err) + } + return fmt.Errorf("the chat model returned an error: %w", &ce) + } + + // no error + return nil +} + +// Represents an error returned from the OpenAI Chat API. +type ChatModelError struct { + // The error type. + Type string `json:"type"` + + // A human-readable description of the error. + Message string `json:"message"` + + // The parameter related to the error, if any. + Param string `json:"param,omitempty"` + + // The error code, if any. + Code string `json:"code,omitempty"` +} + +func (e *ChatModelError) Error() string { + return e.Message +} + // An interface to any request message. type RequestMessage interface { json.Marshaler @@ -1172,6 +1214,7 @@ type FunctionDefinition struct { // Creates an input object for the OpenAI Chat API. func (m *ChatModel) CreateInput(messages ...RequestMessage) (*ChatModelInput, error) { + m.Validator = validateChatModelResponse return &ChatModelInput{ Model: strings.ToLower(m.Info().FullName), Messages: messages, diff --git a/sdk/go/tools/modus-go-build/codegen/preprocess.go b/sdk/go/tools/modus-go-build/codegen/preprocess.go index 51d26de24..37b1034e3 100644 --- a/sdk/go/tools/modus-go-build/codegen/preprocess.go +++ b/sdk/go/tools/modus-go-build/codegen/preprocess.go @@ -251,6 +251,7 @@ func writeFuncWrappers(b *bytes.Buffer, pkg *packages.Package, imports map[strin } if hasErrorReturn { + imports["os"] = "os" imports["github.com/hypermodeinc/modus/sdk/go/pkg/console"] = "console" // remove the error return value from the function signature @@ -320,6 +321,7 @@ func writeFuncWrappers(b *bytes.Buffer, pkg *packages.Package, imports map[strin b.WriteString("\tif err != nil {\n") b.WriteString("\t\tconsole.Error(err.Error())\n") + b.WriteString("\t\tos.Exit(1)\n") b.WriteString("\t}\n") if numResults > 0 {