From 8e165dc9aadc9f7045b91dd1b02d6404940dc023 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Mon, 9 Oct 2023 17:41:54 +0200 Subject: [PATCH] Feat Add headers to openai responses (#506) * feat: add headers to http response * chore: add test * fix: rename to httpHeader --- audio.go | 19 ++++++++++++++++++- chat.go | 2 ++ chat_test.go | 30 ++++++++++++++++++++++++++++++ client.go | 20 +++++++++++++++++++- completion.go | 2 ++ edits.go | 2 ++ embeddings.go | 4 ++++ engines.go | 4 ++++ files.go | 4 ++++ fine_tunes.go | 8 ++++++++ fine_tuning_job.go | 4 ++++ image.go | 2 ++ models.go | 6 ++++++ moderation.go | 2 ++ 14 files changed, 107 insertions(+), 2 deletions(-) diff --git a/audio.go b/audio.go index 9f469159..4cbe4fe6 100644 --- a/audio.go +++ b/audio.go @@ -63,6 +63,21 @@ type AudioResponse struct { Transient bool `json:"transient"` } `json:"segments"` Text string `json:"text"` + + httpHeader +} + +type audioTextResponse struct { + Text string `json:"text"` + + httpHeader +} + +func (r *audioTextResponse) ToAudioResponse() AudioResponse { + return AudioResponse{ + Text: r.Text, + httpHeader: r.httpHeader, + } } // CreateTranscription — API call to create a transcription. Returns transcribed text. @@ -104,7 +119,9 @@ func (c *Client) callAudioAPI( if request.HasJSONResponse() { err = c.sendRequest(req, &response) } else { - err = c.sendRequest(req, &response.Text) + var textResponse audioTextResponse + err = c.sendRequest(req, &textResponse) + response = textResponse.ToAudioResponse() } if err != nil { return AudioResponse{}, err diff --git a/chat.go b/chat.go index 8d29b323..df0e5f97 100644 --- a/chat.go +++ b/chat.go @@ -142,6 +142,8 @@ type ChatCompletionResponse struct { Model string `json:"model"` Choices []ChatCompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateChatCompletion — API call to Create a completion for the chat message. diff --git a/chat_test.go b/chat_test.go index 38d66fa6..52cd0bde 100644 --- a/chat_test.go +++ b/chat_test.go @@ -16,6 +16,11 @@ import ( "github.com/sashabaranov/go-openai/jsonschema" ) +const ( + xCustomHeader = "X-CUSTOM-HEADER" + xCustomHeaderValue = "test" +) + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -68,6 +73,30 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + a := resp.Header().Get(xCustomHeader) + _ = a + if resp.Header().Get(xCustomHeader) != xCustomHeaderValue { + t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue) + } +} + // TestChatCompletionsFunctions tests including a function call. func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -281,6 +310,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { TotalTokens: inputTokens + completionTokens, } resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) fmt.Fprintln(w, string(resBytes)) } diff --git a/client.go b/client.go index 5779a8e1..19902285 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,20 @@ type Client struct { createFormBuilder func(io.Writer) utils.FormBuilder } +type Response interface { + SetHeader(http.Header) +} + +type httpHeader http.Header + +func (h *httpHeader) SetHeader(header http.Header) { + *h = httpHeader(header) +} + +func (h httpHeader) Header() http.Header { + return http.Header(h) +} + // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) @@ -82,7 +96,7 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ... return req, nil } -func (c *Client) sendRequest(req *http.Request, v any) error { +func (c *Client) sendRequest(req *http.Request, v Response) error { req.Header.Set("Accept", "application/json; charset=utf-8") // Check whether Content-Type is already set, Upload Files API requires @@ -103,6 +117,10 @@ func (c *Client) sendRequest(req *http.Request, v any) error { return c.handleErrorResp(res) } + if v != nil { + v.SetHeader(res.Header) + } + return decodeResponse(res.Body, v) } diff --git a/completion.go b/completion.go index 7b9ae89e..c7ff94af 100644 --- a/completion.go +++ b/completion.go @@ -154,6 +154,8 @@ type CompletionResponse struct { Model string `json:"model"` Choices []CompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well diff --git a/edits.go b/edits.go index 831aade2..97d02602 100644 --- a/edits.go +++ b/edits.go @@ -28,6 +28,8 @@ type EditsResponse struct { Created int64 `json:"created"` Usage Usage `json:"usage"` Choices []EditsChoice `json:"choices"` + + httpHeader } // Edits Perform an API call to the Edits endpoint. diff --git a/embeddings.go b/embeddings.go index 660bc24c..7e2aa7eb 100644 --- a/embeddings.go +++ b/embeddings.go @@ -150,6 +150,8 @@ type EmbeddingResponse struct { Data []Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader } type base64String string @@ -182,6 +184,8 @@ type EmbeddingResponseBase64 struct { Data []Base64Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader } // ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse. diff --git a/engines.go b/engines.go index adf6025c..5a0dba85 100644 --- a/engines.go +++ b/engines.go @@ -12,11 +12,15 @@ type Engine struct { Object string `json:"object"` Owner string `json:"owner"` Ready bool `json:"ready"` + + httpHeader } // EnginesList is a list of engines. type EnginesList struct { Engines []Engine `json:"data"` + + httpHeader } // ListEngines Lists the currently available engines, and provides basic diff --git a/files.go b/files.go index 8b933c36..9e521fbb 100644 --- a/files.go +++ b/files.go @@ -25,11 +25,15 @@ type File struct { Status string `json:"status"` Purpose string `json:"purpose"` StatusDetails string `json:"status_details"` + + httpHeader } // FilesList is a list of files that belong to the user or organization. type FilesList struct { Files []File `json:"data"` + + httpHeader } // CreateFile uploads a jsonl file to GPT3 diff --git a/fine_tunes.go b/fine_tunes.go index 7d3b59db..ca840781 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -41,6 +41,8 @@ type FineTune struct { ValidationFiles []File `json:"validation_files"` TrainingFiles []File `json:"training_files"` UpdatedAt int64 `json:"updated_at"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -69,6 +71,8 @@ type FineTuneHyperParams struct { type FineTuneList struct { Object string `json:"object"` Data []FineTune `json:"data"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -77,6 +81,8 @@ type FineTuneList struct { type FineTuneEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -86,6 +92,8 @@ type FineTuneDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 07b0c337..9dcb49de 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -21,6 +21,8 @@ type FineTuningJob struct { ValidationFile string `json:"validation_file,omitempty"` ResultFiles []string `json:"result_files"` TrainedTokens int `json:"trained_tokens"` + + httpHeader } type Hyperparameters struct { @@ -39,6 +41,8 @@ type FineTuningJobEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` HasMore bool `json:"has_more"` + + httpHeader } type FineTuningJobEvent struct { diff --git a/image.go b/image.go index cb96f4f5..4addcdb1 100644 --- a/image.go +++ b/image.go @@ -33,6 +33,8 @@ type ImageRequest struct { type ImageResponse struct { Created int64 `json:"created,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"` + + httpHeader } // ImageResponseDataInner represents a response data structure for image API. diff --git a/models.go b/models.go index c207f0a8..d94f9883 100644 --- a/models.go +++ b/models.go @@ -15,6 +15,8 @@ type Model struct { Permission []Permission `json:"permission"` Root string `json:"root"` Parent string `json:"parent"` + + httpHeader } // Permission struct represents an OpenAPI permission. @@ -38,11 +40,15 @@ type FineTuneModelDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } // ModelsList is a list of models, including those that belong to the user or organization. type ModelsList struct { Models []Model `json:"data"` + + httpHeader } // ListModels Lists the currently available models, diff --git a/moderation.go b/moderation.go index a32f123f..f8d20ee5 100644 --- a/moderation.go +++ b/moderation.go @@ -69,6 +69,8 @@ type ModerationResponse struct { ID string `json:"id"` Model string `json:"model"` Results []Result `json:"results"` + + httpHeader } // Moderations — perform a moderation api call over a string.