From 15ae97ca77fc42378d6df653a64ebb124177d022 Mon Sep 17 00:00:00 2001 From: Sean Goedecke Date: Mon, 7 Oct 2024 11:22:08 +1100 Subject: [PATCH 1/5] Support non-streamed responses, such as those from o1 models --- README.md | 13 ++++++++++- cmd/run/run.go | 23 ++++++++++---------- internal/azure_models/client.go | 25 +++++++++++++++++++--- internal/azure_models/types.go | 2 +- internal/sse/eventreader.go | 5 +++++ internal/sse/mockeventreader.go | 38 +++++++++++++++++++++++++++++++++ 6 files changed, 90 insertions(+), 16 deletions(-) create mode 100644 internal/sse/mockeventreader.go diff --git a/README.md b/README.md index 6e4fbbea..b8dc4e5d 100644 --- a/README.md +++ b/README.md @@ -62,4 +62,15 @@ cat README.md | gh models run gpt-4o-mini "summarize this text" ### Building -Run `script/build`. +Run `script/build`. Now you can run the binary locally, e.g. `./gh-models list` + +### Releasing + +`gh extension upgrade github/gh-models` or `gh extension install github/gh-models` will pull the latest release, not the latest commit, so all changes require cutting a new release: + +```shell +git tag v0.0.x main +git push origin tag v0.0.x +``` + +This will trigger the `release` action that runs the actual production build. \ No newline at end of file diff --git a/cmd/run/run.go b/cmd/run/run.go index 11b174bd..47e67d04 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -391,18 +391,19 @@ func NewRunCommand() *cobra.Command { sp.Stop() for _, choice := range completion.Choices { - if choice.Delta != nil { - if choice.Delta.Content == nil { - continue - } - - messageBuilder.WriteString(*choice.Delta.Content) - io.WriteString(out, *choice.Delta.Content) + if choice.Delta != nil && choice.Delta.Content != nil { + content := choice.Delta.Content + messageBuilder.WriteString(*content) + io.WriteString(out, *content) + } else if choice.Message != nil && choice.Message.Content != nil { + content := choice.Message.Content + messageBuilder.WriteString(*content) + io.WriteString(out, *content) + } - // Introduce a small delay in between response tokens to better simulate a conversation - if terminal.IsTerminalOutput() { - time.Sleep(10 * time.Millisecond) - } + // Introduce a small delay in between response tokens to better simulate a conversation + if terminal.IsTerminalOutput() { + time.Sleep(10 * time.Millisecond) } } } diff --git a/internal/azure_models/client.go b/internal/azure_models/client.go index db80447e..ce02546c 100644 --- a/internal/azure_models/client.go +++ b/internal/azure_models/client.go @@ -31,7 +31,12 @@ func NewClient(authToken string) *Client { } func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatCompletionResponse, error) { - req.Stream = true + // Check if the model name is `o1-mini` or `o1-preview` + if req.Model == "o1-mini" || req.Model == "o1-preview" { + req.Stream = false + } else { + req.Stream = true + } bodyBytes, err := json.Marshal(req) if err != nil { @@ -54,13 +59,27 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple } if resp.StatusCode != http.StatusOK { - // If we aren't going to return an SSE stream, then ensure the response body is closed. defer resp.Body.Close() return nil, c.handleHTTPError(resp) } var chatCompletionResponse ChatCompletionResponse - chatCompletionResponse.Reader = sse.NewEventReader[ChatCompletion](resp.Body) + + if req.Stream { + // Handle streamed response + chatCompletionResponse.Reader = sse.NewEventReader[ChatCompletion](resp.Body) + } else { + // Handle non-streamed response + defer resp.Body.Close() + var completion ChatCompletion + if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { + return nil, err + } + + // Create a mock reader that returns the decoded completion + mockReader := sse.NewMockEventReader([]ChatCompletion{completion}) + chatCompletionResponse.Reader = mockReader + } return &chatCompletionResponse, nil } diff --git a/internal/azure_models/types.go b/internal/azure_models/types.go index 9a9cd0f4..bb3ecf6e 100644 --- a/internal/azure_models/types.go +++ b/internal/azure_models/types.go @@ -50,7 +50,7 @@ type ChatCompletion struct { } type ChatCompletionResponse struct { - Reader *sse.EventReader[ChatCompletion] + Reader sse.EventReaderInterface[ChatCompletion] } type modelCatalogSearchResponse struct { diff --git a/internal/sse/eventreader.go b/internal/sse/eventreader.go index 5eddcc87..c5252b32 100644 --- a/internal/sse/eventreader.go +++ b/internal/sse/eventreader.go @@ -10,6 +10,11 @@ import ( "strings" ) +type EventReaderInterface[T any] interface { + Read() (T, error) + Close() error +} + // Reader is an interface for reading events from an SSE stream. type Reader[T any] interface { // Read reads the next event from the stream. diff --git a/internal/sse/mockeventreader.go b/internal/sse/mockeventreader.go new file mode 100644 index 00000000..f5108cb8 --- /dev/null +++ b/internal/sse/mockeventreader.go @@ -0,0 +1,38 @@ +package sse + +import ( + "bufio" + "bytes" + "io" +) + +// MockEventReader is a mock implementation of the sse.EventReader. This lets us use EventReader as a common interface +// for models that support streaming (like gpt-4o) and models that do not (like the o1 class of models) +type MockEventReader[T any] struct { + reader io.ReadCloser + scanner *bufio.Scanner + events []T + index int +} + +func NewMockEventReader[T any](events []T) *MockEventReader[T] { + data := []byte{} + reader := io.NopCloser(bytes.NewReader(data)) + scanner := bufio.NewScanner(reader) + return &MockEventReader[T]{reader: reader, scanner: scanner, events: events, index: 0} +} + +func (mer *MockEventReader[T]) Read() (T, error) { + if mer.index >= len(mer.events) { + var zero T + return zero, io.EOF + } + event := mer.events[mer.index] + mer.index++ + return event, nil +} + +func (mer *MockEventReader[T]) Close() error { + return mer.reader.Close() +} + From 3b15075a0ebfee0853fb77ca9b13b900f6d89299 Mon Sep 17 00:00:00 2001 From: Sean Goedecke Date: Mon, 7 Oct 2024 11:23:07 +1100 Subject: [PATCH 2/5] go fmt --- internal/sse/eventreader.go | 4 ++-- internal/sse/mockeventreader.go | 33 ++++++++++++++++----------------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/internal/sse/eventreader.go b/internal/sse/eventreader.go index c5252b32..84b16f73 100644 --- a/internal/sse/eventreader.go +++ b/internal/sse/eventreader.go @@ -11,8 +11,8 @@ import ( ) type EventReaderInterface[T any] interface { - Read() (T, error) - Close() error + Read() (T, error) + Close() error } // Reader is an interface for reading events from an SSE stream. diff --git a/internal/sse/mockeventreader.go b/internal/sse/mockeventreader.go index f5108cb8..aa015a79 100644 --- a/internal/sse/mockeventreader.go +++ b/internal/sse/mockeventreader.go @@ -2,37 +2,36 @@ package sse import ( "bufio" - "bytes" + "bytes" "io" ) // MockEventReader is a mock implementation of the sse.EventReader. This lets us use EventReader as a common interface // for models that support streaming (like gpt-4o) and models that do not (like the o1 class of models) type MockEventReader[T any] struct { - reader io.ReadCloser - scanner *bufio.Scanner - events []T - index int + reader io.ReadCloser + scanner *bufio.Scanner + events []T + index int } func NewMockEventReader[T any](events []T) *MockEventReader[T] { - data := []byte{} - reader := io.NopCloser(bytes.NewReader(data)) - scanner := bufio.NewScanner(reader) - return &MockEventReader[T]{reader: reader, scanner: scanner, events: events, index: 0} + data := []byte{} + reader := io.NopCloser(bytes.NewReader(data)) + scanner := bufio.NewScanner(reader) + return &MockEventReader[T]{reader: reader, scanner: scanner, events: events, index: 0} } func (mer *MockEventReader[T]) Read() (T, error) { if mer.index >= len(mer.events) { - var zero T - return zero, io.EOF - } - event := mer.events[mer.index] - mer.index++ - return event, nil + var zero T + return zero, io.EOF + } + event := mer.events[mer.index] + mer.index++ + return event, nil } func (mer *MockEventReader[T]) Close() error { - return mer.reader.Close() + return mer.reader.Close() } - From f90e447488449b00440b2890298dc75756fec4da Mon Sep 17 00:00:00 2001 From: Sean Goedecke Date: Mon, 7 Oct 2024 11:30:54 +1100 Subject: [PATCH 3/5] Cleanup --- internal/azure_models/client.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/azure_models/client.go b/internal/azure_models/client.go index ce02546c..2fcfbe5f 100644 --- a/internal/azure_models/client.go +++ b/internal/azure_models/client.go @@ -59,6 +59,7 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple } if resp.StatusCode != http.StatusOK { + // If we aren't going to return an SSE stream, then ensure the response body is closed. defer resp.Body.Close() return nil, c.handleHTTPError(resp) } @@ -69,8 +70,6 @@ func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatComple // Handle streamed response chatCompletionResponse.Reader = sse.NewEventReader[ChatCompletion](resp.Body) } else { - // Handle non-streamed response - defer resp.Body.Close() var completion ChatCompletion if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { return nil, err From 228c26b3e5db597cc9c93ffb08e13a83376728cc Mon Sep 17 00:00:00 2001 From: Sean Goedecke Date: Mon, 7 Oct 2024 12:13:07 +1100 Subject: [PATCH 4/5] remove duplicate interface --- internal/azure_models/types.go | 2 +- internal/sse/eventreader.go | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/internal/azure_models/types.go b/internal/azure_models/types.go index bb3ecf6e..c3f7acfb 100644 --- a/internal/azure_models/types.go +++ b/internal/azure_models/types.go @@ -50,7 +50,7 @@ type ChatCompletion struct { } type ChatCompletionResponse struct { - Reader sse.EventReaderInterface[ChatCompletion] + Reader sse.Reader[ChatCompletion] } type modelCatalogSearchResponse struct { diff --git a/internal/sse/eventreader.go b/internal/sse/eventreader.go index 84b16f73..5eddcc87 100644 --- a/internal/sse/eventreader.go +++ b/internal/sse/eventreader.go @@ -10,11 +10,6 @@ import ( "strings" ) -type EventReaderInterface[T any] interface { - Read() (T, error) - Close() error -} - // Reader is an interface for reading events from an SSE stream. type Reader[T any] interface { // Read reads the next event from the stream. From 7c146b7947b6d48742108dcce28eb974ed7b7ca3 Mon Sep 17 00:00:00 2001 From: Sean Goedecke Date: Mon, 7 Oct 2024 12:35:37 +1100 Subject: [PATCH 5/5] add comment --- cmd/run/run.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/run/run.go b/cmd/run/run.go index 47e67d04..69928884 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -391,6 +391,8 @@ func NewRunCommand() *cobra.Command { sp.Stop() for _, choice := range completion.Choices { + // Streamed responses from the OpenAI API have their data in `.Delta`, while + // non-streamed responses use `.Message`, so let's support both if choice.Delta != nil && choice.Delta.Content != nil { content := choice.Delta.Content messageBuilder.WriteString(*content)