feat: add support for openai completion and chat stream #73

Merged: 1 commit, May 31, 2023
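This change adds streaming support to the OpenAI LLM wrapper: a new CompletionStream method for completion models and a new ChatStream method for chat models, each delivering output chunks to a caller-supplied callback as they arrive. The existing metadata callback is renamed from OpenAICallback to OpenAIUsageCallback to distinguish it from the new OpenAIStreamCallback type, and two runnable examples are added under examples/llm/openai/stream/.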
36 changes: 36 additions & 0 deletions examples/llm/openai/stream/chat/main.go
@@ -0,0 +1,36 @@
package main

import (
	"context"
	"fmt"

	"github.com/henomis/lingoose/chat"
	"github.com/henomis/lingoose/llm/openai"
	"github.com/henomis/lingoose/prompt"
)

func main() {

	chat := chat.New(
		chat.PromptMessage{
			Type:   chat.MessageTypeSystem,
			Prompt: prompt.New("You are a professional joke writer"),
		},
		chat.PromptMessage{
			Type:   chat.MessageTypeUser,
			Prompt: prompt.New("Write a joke about geese"),
		},
	)

	llm := openai.NewChat()

	err := llm.ChatStream(context.Background(), func(output string) {
		fmt.Printf("%s", output)
	}, chat)
	if err != nil {
		panic(err)
	}

	fmt.Println()

}
23 changes: 23 additions & 0 deletions examples/llm/openai/stream/completion/main.go
@@ -0,0 +1,23 @@
package main

import (
	"context"
	"fmt"

	"github.com/henomis/lingoose/llm/openai"
)

func main() {

	llm := openai.NewCompletion()

	err := llm.CompletionStream(context.Background(), func(output string) {
		fmt.Printf("%s", output)
	}, "Tell me a joke about geese")
	if err != nil {
		panic(err)
	}

	fmt.Println()

}
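Both example programs print each chunk to stdout as it arrives. They assume a valid OpenAI API key is available to the client; the lingoose OpenAI wrapper appears to read it from the environment (conventionally the OPENAI_API_KEY variable, which is an assumption rather than something this diff shows).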
137 changes: 121 additions & 16 deletions llm/openai/openai.go
@@ -3,7 +3,9 @@ package openai

 import (
 	"context"
+	"errors"
 	"fmt"
+	"io"
 	"os"
 	"strings"

@@ -42,16 +44,17 @@ const (
 	GPT3Babbage Model = openai.GPT3Babbage
 )

-type OpenAICallback func(types.Meta)
+type OpenAIUsageCallback func(types.Meta)
+type OpenAIStreamCallback func(string)

 type openAI struct {
-	openAIClient *openai.Client
-	model        Model
-	temperature  float32
-	maxTokens    int
-	stop         []string
-	verbose      bool
-	callback     OpenAICallback
+	openAIClient  *openai.Client
+	model         Model
+	temperature   float32
+	maxTokens     int
+	stop          []string
+	verbose       bool
+	usageCallback OpenAIUsageCallback
 }

 func New(model Model, temperature float32, maxTokens int, verbose bool) *openAI {
@@ -82,8 +85,8 @@ func (o *openAI) WithMaxTokens(maxTokens int) *openAI {
 	return o
 }

-func (o *openAI) WithCallback(callback OpenAICallback) *openAI {
-	o.callback = callback
+func (o *openAI) WithCallback(callback OpenAIUsageCallback) *openAI {
+	o.usageCallback = callback
 	return o
 }

@@ -139,8 +142,8 @@ func (o *openAI) Completion(ctx context.Context, prompt string) (string, error)
return "", fmt.Errorf("%s: %w", ErrOpenAICompletion, err)
}

if o.callback != nil {
o.setMetadata(response.Usage)
if o.usageCallback != nil {
o.setUsageMetadata(response.Usage)
}

if len(response.Choices) == 0 {
@@ -155,6 +158,57 @@ func (o *openAI) Completion(ctx context.Context, prompt string) (string, error)
 	return output, nil
 }

+func (o *openAI) CompletionStream(ctx context.Context, callbackFn OpenAIStreamCallback, prompt string) error {
+
+	stream, err := o.openAIClient.CreateCompletionStream(
+		ctx,
+		openai.CompletionRequest{
+			Model:       string(o.model),
+			Prompt:      prompt,
+			MaxTokens:   o.maxTokens,
+			Temperature: o.temperature,
+			N:           DefaultOpenAINumResults,
+			TopP:        DefaultOpenAITopP,
+			Stop:        o.stop,
+		},
+	)
+	if err != nil {
+		return fmt.Errorf("%s: %w", ErrOpenAICompletion, err)
+	}
+
+	defer stream.Close()
+
+	for {
+		response, err := stream.Recv()
+		if errors.Is(err, io.EOF) {
+			break
+		}
+		if err != nil {
+			return fmt.Errorf("%s: %w", ErrOpenAICompletion, err)
+		}
+
+		if o.usageCallback != nil {
+			o.setUsageMetadata(response.Usage)
+		}
+
+		if len(response.Choices) == 0 {
+			return fmt.Errorf("%s: no choices returned", ErrOpenAICompletion)
+		}
+
+		output := response.Choices[0].Text
+		if o.verbose {
+			debugCompletion(prompt, output)
+		}
+
+		callbackFn(output)
+	}
+
+	return nil
+}

 func (o *openAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {

 	messages, err := buildMessages(prompt)
@@ -179,8 +233,8 @@ func (o *openAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {
return "", fmt.Errorf("%s: %w", ErrOpenAIChat, err)
}

if o.callback != nil {
o.setMetadata(response.Usage)
if o.usageCallback != nil {
o.setUsageMetadata(response.Usage)
}

if len(response.Choices) == 0 {
@@ -196,11 +250,62 @@ func (o *openAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {
 	return content, nil
 }

+func (o *openAI) ChatStream(ctx context.Context, callbackFn OpenAIStreamCallback, prompt *chat.Chat) error {
+
+	messages, err := buildMessages(prompt)
+	if err != nil {
+		return fmt.Errorf("%s: %w", ErrOpenAIChat, err)
+	}
+
+	stream, err := o.openAIClient.CreateChatCompletionStream(
+		ctx,
+		openai.ChatCompletionRequest{
+			Model:       string(o.model),
+			Messages:    messages,
+			MaxTokens:   o.maxTokens,
+			Temperature: o.temperature,
+			N:           DefaultOpenAINumResults,
+			TopP:        DefaultOpenAITopP,
+			Stop:        o.stop,
+		},
+	)
+	if err != nil {
+		return fmt.Errorf("%s: %w", ErrOpenAIChat, err)
+	}
+
+	defer stream.Close()
+
+	for {
+		response, err := stream.Recv()
+		if errors.Is(err, io.EOF) {
+			break
+		}
+		if err != nil {
+			return fmt.Errorf("%s: %w", ErrOpenAIChat, err)
+		}
+
+		// Unlike the completion stream, the chat stream response carries
+		// no usage data, so the usage callback cannot be invoked here.
+
+		if len(response.Choices) == 0 {
+			return fmt.Errorf("%s: no choices returned", ErrOpenAIChat)
+		}
+
+		content := response.Choices[0].Delta.Content
+		if o.verbose {
+			debugChat(prompt, content)
+		}
+
+		callbackFn(content)
+	}
+
+	return nil
+}

 func (o *openAI) SetStop(stop []string) {
 	o.stop = stop
 }

-func (o *openAI) setMetadata(usage openai.Usage) {
+func (o *openAI) setUsageMetadata(usage openai.Usage) {

 	callbackMetadata := make(types.Meta)

@@ -209,7 +314,7 @@ func (o *openAI) setMetadata(usage openai.Usage) {
 		return
 	}

-	o.callback(callbackMetadata)
+	o.usageCallback(callbackMetadata)
 }

 func buildMessages(prompt *chat.Chat) ([]openai.ChatCompletionMessage, error) {
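A usage note on the streaming API introduced above: the callback receives one chunk per stream event, so a caller that wants the complete response in addition to live output can buffer the chunks. Below is a minimal sketch against the CompletionStream signature added in this PR; the strings.Builder buffering is illustrative and not part of the change.

package main

import (
	"context"
	"fmt"
	"strings"

	"github.com/henomis/lingoose/llm/openai"
)

func main() {
	llm := openai.NewCompletion()

	// Accumulate each streamed chunk so the full completion is
	// available once the stream ends.
	var full strings.Builder
	err := llm.CompletionStream(context.Background(), func(output string) {
		fmt.Print(output)        // print chunks as they arrive
		full.WriteString(output) // and buffer them
	}, "Tell me a joke about geese")
	if err != nil {
		panic(err)
	}

	fmt.Printf("\n\nbuffered %d bytes of output\n", full.Len())
}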