feat: implement llm cache #131

Merged 1 commit on Sep 14, 2023
41 changes: 41 additions & 0 deletions examples/llm/cache/main.go
@@ -0,0 +1,41 @@
package main

import (
	"bufio"
	"context"
	"fmt"
	"os"
	"strings"

	openaiembedder "github.com/henomis/lingoose/embedder/openai"
	simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex"
	"github.com/henomis/lingoose/llm/cache"
	"github.com/henomis/lingoose/llm/openai"
)

func main() {
	embedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
	index := simplevectorindex.New("db", ".", embedder)
	llm := openai.NewCompletion().WithCompletionCache(cache.New(embedder, index).WithTopK(3))

	for {
		text := askUserInput("What is your question?")

		response, err := llm.Completion(context.Background(), text)
		if err != nil {
			fmt.Println(err)
			continue
		}

		fmt.Println(response)
	}
}

func askUserInput(question string) string {
	fmt.Printf("%s > ", question)
	reader := bufio.NewReader(os.Stdin)
	name, _ := reader.ReadString('\n')
	name = strings.TrimSuffix(name, "\n")
	return name
}
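The example tunes the cache only through WithTopK. A hedged variant, also using the WithScoreThreshold option defined in llm/cache/cache.go below (0.9 is the default similarity cutoff; the 0.8 here is an illustrative value, not taken from the PR):

llm := openai.NewCompletion().WithCompletionCache(
	cache.New(embedder, index).
		WithTopK(3).
		WithScoreThreshold(0.8), // accept slightly looser matches than the 0.9 default
)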
22 changes: 11 additions & 11 deletions index/simpleVectorIndex/simpleVectorIndex.go
@@ -6,9 +6,9 @@ import (
 	"fmt"
 	"math"
 	"os"
-	"strconv"
 	"strings"

+	"github.com/google/uuid"
 	"github.com/henomis/lingoose/document"
 	"github.com/henomis/lingoose/embedder"
 	"github.com/henomis/lingoose/index"
@@ -53,7 +53,6 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu
 		return fmt.Errorf("%s: %w", index.ErrInternal, err)
 	}

-	id := 0
 	for i := 0; i < len(documents); i += defaultBatchSize {

 		end := i + defaultBatchSize
@@ -72,8 +71,11 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu
 		}

 		for j, document := range documents[i:end] {
-			s.data = append(s.data, buildDataFromEmbeddingAndDocument(id, embeddings[j], document))
-			id++
+			id, err := uuid.NewUUID()
+			if err != nil {
+				return err
+			}
+			s.data = append(s.data, buildDataFromEmbeddingAndDocument(id.String(), embeddings[j], document))
 		}

 	}
@@ -87,14 +89,14 @@ func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Docu
 }

 func buildDataFromEmbeddingAndDocument(
-	id int,
+	id string,
 	embedding embedder.Embedding,
 	document document.Document,
 ) data {
 	metadata := index.DeepCopyMetadata(document.Metadata)
 	metadata[index.DefaultKeyContent] = document.Content
 	return data{
-		ID:       fmt.Sprintf("%d", id),
+		ID:       id,
 		Values:   embedding,
 		Metadata: metadata,
 	}
@@ -148,13 +150,11 @@ func (s *Index) Add(ctx context.Context, item *index.Data) error {
 	}

 	if item.ID == "" {
-		lastID := s.data[len(s.data)-1].ID
-		lastIDAsInt, err := strconv.Atoi(lastID)
+		id, err := uuid.NewUUID()
 		if err != nil {
-			return fmt.Errorf("%s: %w", index.ErrInternal, err)
+			return err
 		}
-
-		item.ID = fmt.Sprintf("%d", lastIDAsInt+1)
+		item.ID = id.String()
 	}

 	s.data = append(
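For context on the change above: sequential integer IDs forced Add to read back the last stored ID and strconv.Atoi it, which panics on an empty index and collides under concurrent writers; a time-based UUID needs no prior state. A minimal standalone sketch of the google/uuid call the diff now relies on (uuid.NewUUID returns a version 1, time-based UUID and an error):

package main

import (
	"fmt"

	"github.com/google/uuid"
)

func main() {
	// NewUUID returns a version 1 (time-based) UUID and an error.
	id, err := uuid.NewUUID()
	if err != nil {
		panic(err)
	}
	fmt.Println(id.String()) // prints something like "f47ac10b-58cc-11ee-..."
}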
108 changes: 108 additions & 0 deletions llm/cache/cache.go
@@ -0,0 +1,108 @@
package cache

import (
	"context"
	"fmt"

	"github.com/henomis/lingoose/embedder"
	"github.com/henomis/lingoose/index"
	indexoption "github.com/henomis/lingoose/index/option"
	"github.com/henomis/lingoose/types"
)

type Embedder interface {
	Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error)
}

type Index interface {
	Search(context.Context, []float64, ...indexoption.Option) (index.SearchResults, error)
	Add(context.Context, *index.Data) error
}

var ErrCacheMiss = fmt.Errorf("cache miss")

const (
	defaultTopK            = 1
	defaultScoreThreshold  = 0.9
	cacheAnswerMetadataKey = "cache-answer"
)

type Cache struct {
	embedder       Embedder
	index          Index
	topK           int
	scoreThreshold float64
}

type CacheResult struct {
	Answer    []string
	Embedding []float64
}

func New(embedder Embedder, index Index) *Cache {
	return &Cache{
		embedder:       embedder,
		index:          index,
		topK:           defaultTopK,
		scoreThreshold: defaultScoreThreshold,
	}
}

func (c *Cache) WithTopK(topK int) *Cache {
	c.topK = topK
	return c
}

func (c *Cache) WithScoreThreshold(scoreThreshold float64) *Cache {
	c.scoreThreshold = scoreThreshold
	return c
}
func (c *Cache) Get(ctx context.Context, query string) (*CacheResult, error) {
	embedding, err := c.embedder.Embed(ctx, []string{query})
	if err != nil {
		return nil, err
	}

	results, err := c.index.Search(ctx, embedding[0], indexoption.WithTopK(c.topK))
	if err != nil {
		return nil, err
	}

	answers, cacheHit := c.extractResults(results)
	if cacheHit {
		return &CacheResult{
			Answer:    answers,
			Embedding: embedding[0],
		}, nil
	}

	// On a miss, still return the query embedding so the caller can Set()
	// the eventual answer without embedding the query a second time
	// (Completion in llm/openai/openai.go dereferences this result after a miss).
	return &CacheResult{
		Embedding: embedding[0],
	}, ErrCacheMiss
}

func (c *Cache) Set(ctx context.Context, embedding []float64, answer string) error {
	return c.index.Add(ctx, &index.Data{
		Values: embedding,
		Metadata: types.Meta{
			cacheAnswerMetadataKey: answer,
		},
	})
}

func (c *Cache) extractResults(results index.SearchResults) ([]string, bool) {
	var output []string

	for _, result := range results {
		if result.Score > c.scoreThreshold {
			answer, ok := result.Metadata[cacheAnswerMetadataKey]
			if !ok {
				continue
			}

			output = append(output, answer.(string))
		}
	}

	return output, len(output) > 0
}
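To make the Get/Set contract concrete, a hedged sketch of driving the cache directly rather than through the OpenAI wrapper. The embedder/index wiring mirrors examples/llm/cache/main.go, and the question and answer strings are placeholders. Note that on a miss Get still returns the query embedding, so Set can store the answer without re-embedding:

c := cache.New(embedder, index).WithTopK(3)

result, err := c.Get(ctx, "What is LinGoose?")
switch {
case err == nil:
	// Cache hit: one or more stored answers passed the score threshold.
	fmt.Println(strings.Join(result.Answer, "\n"))
case err == cache.ErrCacheMiss:
	// Cache miss: produce the answer elsewhere, then store it alongside
	// the embedding Get already computed for the query.
	answer := "LinGoose is a Go framework for LLM apps." // placeholder
	if err := c.Set(ctx, result.Embedding, answer); err != nil {
		// handle the indexing error
	}
default:
	// embedding or search failed
}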
27 changes: 27 additions & 0 deletions llm/openai/openai.go
@@ -10,6 +10,7 @@ import (
 	"strings"

 	"github.com/henomis/lingoose/chat"
+	"github.com/henomis/lingoose/llm/cache"
 	"github.com/henomis/lingoose/types"
 	"github.com/mitchellh/mapstructure"
 	"github.com/sashabaranov/go-openai"
@@ -71,6 +72,7 @@ type OpenAI struct {
 	functionsMaxIterations uint
 	calledFunctionName     *string
 	finishReason           string
+	cache                  *cache.Cache
 }

 func New(model Model, temperature float32, maxTokens int, verbose bool) *OpenAI {
@@ -130,6 +132,12 @@ func (o *OpenAI) WithVerbose(verbose bool) *OpenAI {
 	return o
 }

+// WithCompletionCache sets the completion cache for the OpenAI instance.
+func (o *OpenAI) WithCompletionCache(cache *cache.Cache) *OpenAI {
+	o.cache = cache
+	return o
+}
+
 // CalledFunctionName returns the name of the function that was called.
 func (o *OpenAI) CalledFunctionName() *string {
 	return o.calledFunctionName
@@ -160,11 +168,30 @@ func NewChat() *OpenAI {

 // Completion returns a single completion for the given prompt.
 func (o *OpenAI) Completion(ctx context.Context, prompt string) (string, error) {
+	var cacheResult *cache.CacheResult
+	var err error
+
+	if o.cache != nil {
+		cacheResult, err = o.cache.Get(ctx, prompt)
+		if err == nil {
+			return strings.Join(cacheResult.Answer, "\n"), nil
+		} else if err != cache.ErrCacheMiss {
+			return "", fmt.Errorf("%s: %w", ErrOpenAICompletion, err)
+		}
+	}
+
 	outputs, err := o.BatchCompletion(ctx, []string{prompt})
 	if err != nil {
 		return "", err
 	}

+	if o.cache != nil {
+		err = o.cache.Set(ctx, cacheResult.Embedding, outputs[0])
+		if err != nil {
+			return "", fmt.Errorf("%s: %w", ErrOpenAICompletion, err)
+		}
+	}
+
 	return outputs[0], nil
 }
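A side note on the sentinel comparison above: ErrCacheMiss is a plain fmt.Errorf value, so direct equality works, but errors.Is would keep the check correct if the miss error is ever returned wrapped. A minimal sketch of that variant (same control flow; only the comparison and an extra "errors" import change):

	if o.cache != nil {
		cacheResult, err = o.cache.Get(ctx, prompt)
		if err == nil {
			return strings.Join(cacheResult.Answer, "\n"), nil
		} else if !errors.Is(err, cache.ErrCacheMiss) { // requires importing "errors"
			return "", fmt.Errorf("%s: %w", ErrOpenAICompletion, err)
		}
	}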
