Assist - OpenAI library port (#25948)
* Assist - OpenAI library port

* Add tests

* Address code review comments

* Added comment

* Partial backport of #26058

* Move AI messages to a new file

* Prevent blocking on error

* Add comments. Fix typo
jakule committed Jun 2, 2023
1 parent 6990a3a commit ba6387b
Showing 7 changed files with 626 additions and 0 deletions.
3 changes: 3 additions & 0 deletions go.mod
@@ -237,6 +237,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/daviddengcn/go-colortext v1.0.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.9.0 // indirect
github.com/docker/distribution v2.8.2+incompatible // indirect
github.com/dvsekhvalnov/jose2go v1.5.0 // indirect
github.com/elastic/elastic-transport-go/v8 v8.2.0 // indirect
@@ -343,12 +344,14 @@ require (
github.com/rs/zerolog v1.28.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46 // indirect
github.com/sashabaranov/go-openai v1.9.3
github.com/shabbyrobe/gocovmerge v0.0.0-20190829150210-3e036491d500 // indirect
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 // indirect
github.com/siddontang/go-log v0.0.0-20180807004314-8d05993dda07 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/thales-e-security/pool v0.0.2 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tiktoken-go/tokenizer v0.1.0
github.com/x448/float16 v0.8.4 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.1 // indirect
6 changes: 6 additions & 0 deletions go.sum
@@ -512,6 +512,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
github.com/dimchansky/utfbom v1.1.0/go.mod h1:rO41eb7gLfo8SF1jd9F8HplJm1Fewwi4mQvIirEdv+8=
github.com/dimchansky/utfbom v1.1.1/go.mod h1:SxdoEBH5qIqFocHMyGOXVAybYJdr71b1Q/j0mACtrfE=
github.com/dlclark/regexp2 v1.9.0 h1:pTK/l/3qYIKaRXuHnEnIf7Y5NxfRPfpb7dis6/gdlVI=
github.com/dlclark/regexp2 v1.9.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko=
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
@@ -1476,6 +1478,8 @@ github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb
github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46 h1:GHRpF1pTW19a8tTFrMLUcfWwyC0pnifVo2ClaLq+hP8=
github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8=
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
github.com/sashabaranov/go-openai v1.9.3 h1:uNak3Rn5pPsKRs9bdT7RqRZEyej/zdZOEI2/8wvrFtM=
github.com/sashabaranov/go-openai v1.9.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sassoftware/go-rpmutils v0.0.0-20190420191620-a8f1baeba37b/go.mod h1:am+Fp8Bt506lA3Rk3QCmSqmYmLMnPDhdDUcosQCAx+I=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/schollz/progressbar/v3 v3.13.1 h1:o8rySDYiQ59Mwzy2FELeHY5ZARXZTVJC7iHD6PEFUiE=
@@ -1572,6 +1576,8 @@ github.com/tj/assert v0.0.0-20171129193455-018094318fb0/go.mod h1:mZ9/Rh9oLWpLLD
github.com/tj/go-elastic v0.0.0-20171221160941-36157cbbebc2/go.mod h1:WjeM0Oo1eNAjXGDx2yma7uG2XoyRZTq1uv3M/o7imD0=
github.com/tj/go-kinesis v0.0.0-20171128231115-08b17f58cb1b/go.mod h1:/yhzCV0xPfx6jb1bBgRFjl5lytqVqZXEaeqWP8lTEao=
github.com/tj/go-spin v1.1.0/go.mod h1:Mg1mzmePZm4dva8Qz60H2lHwmJ2loum4VIrLgVnKwh4=
github.com/tiktoken-go/tokenizer v0.1.0 h1:c1fXriHSR/NmhMDTwUDLGiNhHwTV+ElABGvqhCWLRvY=
github.com/tiktoken-go/tokenizer v0.1.0/go.mod h1:7SZW3pZUKWLJRilTvWCa86TOVIiiJhYj3FQ5V3alWcg=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tmc/grpc-websocket-proxy v0.0.0-20200427203606-3cfed13b9966/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
229 changes: 229 additions & 0 deletions lib/ai/chat.go
@@ -0,0 +1,229 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"strings"

	"github.com/gravitational/trace"
	"github.com/sashabaranov/go-openai"
	"github.com/tiktoken-go/tokenizer"
)

const maxResponseTokens = 2000

// Chat represents a conversation between a user and an assistant with context memory.
type Chat struct {
	client    *Client
	messages  []openai.ChatCompletionMessage
	tokenizer tokenizer.Codec
}

// Insert inserts a message into the conversation. This is commonly the
// user's input but may also take the form of a system message used for instructions.
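// For example (illustrative contents, not part of this change):
//
//	chat.Insert(openai.ChatMessageRoleSystem, "You are a helpful assistant.")
//	chat.Insert(openai.ChatMessageRoleUser, "Which nodes can I access?")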
func (chat *Chat) Insert(role string, content string) Message {
	chat.messages = append(chat.messages, openai.ChatCompletionMessage{
		Role:    role,
		Content: content,
	})

	return Message{
		Role:    role,
		Content: content,
		Idx:     len(chat.messages) - 1,
	}
}

// PromptTokens uses the chat's tokenizer to calculate
// the total number of tokens in the prompt
//
// Ref: https://github.com/openai/openai-cookbook/blob/594fc6c952425810e9ea5bd1a275c8ca5f32e8f9/examples/How_to_count_tokens_with_tiktoken.ipynb
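//
// For example, a chat containing two messages whose contents encode to 10 and
// 20 tokens is counted as 3 + (10+1+3) + (20+1+3) = 41 prompt tokens.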
func (chat *Chat) PromptTokens() (int, error) {
	// perRequest is the number of tokens used up for each completion request
	const perRequest = 3
	// perRole is the number of tokens used to encode a message's role
	const perRole = 1
	// perMessage is the token "overhead" for each message
	const perMessage = 3

	sum := perRequest
	for _, m := range chat.messages {
		tokens, _, err := chat.tokenizer.Encode(m.Content)
		if err != nil {
			return 0, trace.Wrap(err)
		}
		sum += len(tokens)
		sum += perRole
		sum += perMessage
	}

	return sum, nil
}

// Summary creates a short summary for the given input.
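// For example, given "How do I add a node to my cluster?" the model might
// return a title along the lines of "Adding a cluster node" (illustrative).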
func (chat *Chat) Summary(ctx context.Context, message string) (string, error) {
	resp, err := chat.client.svc.CreateChatCompletion(
		ctx,
		openai.ChatCompletionRequest{
			Model: openai.GPT4,
			Messages: []openai.ChatCompletionMessage{
				{Role: openai.ChatMessageRoleSystem, Content: promptSummarizeTitle},
				{Role: openai.ChatMessageRoleUser, Content: message},
			},
		},
	)
	if err != nil {
		return "", trace.Wrap(err)
	}

	return resp.Choices[0].Message.Content, nil
}

// Complete completes the conversation with a message from the assistant based on the current context.
// On success, it returns one of the following message types, with the number of
// tokens used for the completion recorded on the message where applicable:
//   - Message: a plain message from the assistant
//   - CompletionCommand: a command extracted from the assistant's response
//   - StreamingMessage: a message whose content is streamed from the assistant
func (chat *Chat) Complete(ctx context.Context) (any, error) {
	var numTokens int

	// if the chat contains only the initial message, return the predefined response instead of querying GPT-4
	if len(chat.messages) == 1 {
		return &Message{
			Role:    openai.ChatMessageRoleAssistant,
			Content: initialAIResponse,
			Idx:     len(chat.messages) - 1,
		}, nil
	}

	// if not, copy the current chat log to a new slice and append the suffix instruction
	messages := make([]openai.ChatCompletionMessage, len(chat.messages)+1)
	copy(messages, chat.messages)
	messages[len(messages)-1] = openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleUser,
		Content: promptExtractInstruction,
	}

	// create a streaming completion request; we do this to optimistically stream the response when
	// we don't believe it's a payload
	stream, err := chat.client.svc.CreateChatCompletionStream(
		ctx,
		openai.ChatCompletionRequest{
			Model:     openai.GPT4,
			Messages:  messages,
			MaxTokens: maxResponseTokens,
			Stream:    true,
		},
	)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	var (
		response openai.ChatCompletionStreamResponse
		trimmed  string
	)
	for trimmed == "" {
		// fetch the first delta to check for a possible JSON payload
		response, err = stream.Recv()
		if err != nil {
			return nil, trace.Wrap(err)
		}
		numTokens++

		trimmed = strings.TrimSpace(response.Choices[0].Delta.Content)
	}

	// if it looks like a JSON payload, wait for the entire response and try to parse it
	if strings.HasPrefix(trimmed, "{") {
		payload := strings.Builder{}
		payload.WriteString(response.Choices[0].Delta.Content)

		for {
			response, err := stream.Recv()
			if errors.Is(err, io.EOF) {
				break
			}
			if err != nil {
				return nil, trace.Wrap(err)
			}
			numTokens++

			payload.WriteString(response.Choices[0].Delta.Content)
		}

		// if we can parse it, return the parsed payload, otherwise return a non-streaming message
		var c CompletionCommand
		err = json.Unmarshal([]byte(payload.String()), &c)
		switch err {
		case nil:
			c.NumTokens = numTokens
			return &c, nil
		default:
			return &Message{
				Role:      openai.ChatMessageRoleAssistant,
				Content:   payload.String(),
				Idx:       len(chat.messages) - 1,
				NumTokens: numTokens,
			}, nil
		}
	}

	// if it doesn't look like a JSON payload, return a streaming message to the caller
	chunks := make(chan string, 1)
	errCh := make(chan error)
	chunks <- response.Choices[0].Delta.Content
	go func() {
		defer close(chunks)

		for {
			response, err := stream.Recv()
			switch {
			case errors.Is(err, io.EOF):
				return
			case err != nil:
				select {
				case <-ctx.Done():
				case errCh <- trace.Wrap(err):
				}
				return
			}

			select {
			case chunks <- response.Choices[0].Delta.Content:
			case <-ctx.Done():
				return
			}
		}
	}()

	return &StreamingMessage{
		Role:   openai.ChatMessageRoleAssistant,
		Idx:    len(chat.messages) - 1,
		Chunks: chunks,
		Error:  errCh,
	}, nil
}
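
As a note for reviewers, below is a minimal sketch of how a caller might consume the three result types of Complete. It is not part of this change: how the Chat itself is constructed (client, tokenizer) happens elsewhere in this PR, and the consumeCompletion helper and the "fmt" import are illustrative assumptions.

// consumeCompletion is a sketch, not part of this commit: it sends one user
// message and handles each message type Complete can return. Assumes "fmt"
// is added to the imports; Chat construction is elided.
func consumeCompletion(ctx context.Context, chat *Chat) error {
	chat.Insert(openai.ChatMessageRoleUser, "Which nodes are online?")

	out, err := chat.Complete(ctx)
	if err != nil {
		return trace.Wrap(err)
	}

	switch msg := out.(type) {
	case *Message:
		// a fully buffered assistant message
		fmt.Println(msg.Content)
	case *CompletionCommand:
		// a structured command parsed from a JSON payload
		fmt.Printf("received a command (%d tokens used)\n", msg.NumTokens)
	case *StreamingMessage:
		// print chunks as they arrive; Chunks is closed when the stream ends
		for {
			select {
			case chunk, ok := <-msg.Chunks:
				if !ok {
					return nil
				}
				fmt.Print(chunk)
			case err := <-msg.Error:
				return trace.Wrap(err)
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	}
	return nil
}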
