Skip to content

Commit

Permalink
chore: make constructors composable (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis committed May 10, 2023
1 parent 4f19a46 commit 082fd2b
Show file tree
Hide file tree
Showing 15 changed files with 118 additions and 187 deletions.
12 changes: 7 additions & 5 deletions embedder/openai/openai.go
Expand Up @@ -62,16 +62,18 @@ type openAIEmbedder struct {
model Model
}

func New(model Model) (*openAIEmbedder, error) {
func New(model Model) *openAIEmbedder {
openAIKey := os.Getenv("OPENAI_API_KEY")
if openAIKey == "" {
return nil, fmt.Errorf("OPENAI_API_KEY not set")
}

return &openAIEmbedder{
openAIClient: openai.NewClient(openAIKey),
model: model,
}, nil
}
}

func (o *openAIEmbedder) WithAPIKey(apiKey string) *openAIEmbedder {
o.openAIClient = openai.NewClient(apiKey)
return o
}

func (o *openAIEmbedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
Expand Down
5 changes: 1 addition & 4 deletions examples/chat/main.go
Expand Up @@ -22,10 +22,7 @@ func main() {
},
)

llmOpenAI, err := openai.New(openai.GPT3Dot5Turbo, openai.DefaultOpenAITemperature, openai.DefaultOpenAIMaxTokens, true)
if err != nil {
panic(err)
}
llmOpenAI := openai.New(openai.GPT3Dot5Turbo, openai.DefaultOpenAITemperature, openai.DefaultOpenAIMaxTokens, true)

response, err := llmOpenAI.Chat(context.Background(), chat)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions examples/embeddings/knowledge_base/db.json

Large diffs are not rendered by default.

28 changes: 6 additions & 22 deletions examples/embeddings/knowledge_base/main.go
Expand Up @@ -22,29 +22,19 @@ const (

func main() {

openaiEmbedder, err := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
if err != nil {
panic(err)
}

docsVectorIndex, err := index.NewSimpleVectorIndex("db", ".", openaiEmbedder)
if err != nil {
panic(err)
}
openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)

docsVectorIndex := index.NewSimpleVectorIndex("db", ".", openaiEmbedder)
indexIsEmpty, _ := docsVectorIndex.IsEmpty()

if indexIsEmpty {
err = ingestData(openaiEmbedder)
err := ingestData(openaiEmbedder)
if err != nil {
panic(err)
}
}

llmOpenAI, err := openai.NewChat()
if err != nil {
panic(err)
}
llmOpenAI := openai.NewChat()

fmt.Println("Enter a query to search the knowledge base. Type 'quit' to exit.")
query := ""
Expand Down Expand Up @@ -115,15 +105,9 @@ func ingestData(openaiEmbedder index.Embedder) error {

fmt.Printf("Learning Knowledge Base...")

docsVectorIndex, err := index.NewSimpleVectorIndex("db", ".", openaiEmbedder)
if err != nil {
return err
}
docsVectorIndex := index.NewSimpleVectorIndex("db", ".", openaiEmbedder)

loader, err := loader.NewPDFToTextLoader("/usr/bin/pdftotext", "./kb")
if err != nil {
return err
}
loader := loader.NewPDFToTextLoader("/usr/bin/pdftotext", "./kb")

documents, err := loader.Load()
if err != nil {
Expand Down
27 changes: 5 additions & 22 deletions examples/embeddings/pinecone/main.go
Expand Up @@ -20,10 +20,7 @@ var pineconeClient *pineconego.PineconeGo

func main() {

openaiEmbedder, err := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
if err != nil {
panic(err)
}
openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)

pineconeApiKey := os.Getenv("PINECONE_API_KEY")
if pineconeApiKey == "" {
Expand All @@ -42,7 +39,7 @@ func main() {
panic(err)
}

pineconeIndex, err := index.NewPinecone(
pineconeIndex := index.NewPinecone(
index.PineconeOptions{
IndexName: "test",
ProjectID: projectID,
Expand All @@ -51,9 +48,6 @@ func main() {
},
openaiEmbedder,
)
if err != nil {
panic(err)
}

indexIsEmpty, err := pineconeIndex.IsEmpty(context.Background())
if err != nil {
Expand Down Expand Up @@ -86,10 +80,7 @@ func main() {
fmt.Println("----------")
}

llmOpenAI, err := openai.NewCompletion()
if err != nil {
panic(err)
}
llmOpenAI := openai.NewCompletion()

prompt1, err := prompt.NewPromptTemplate(
"Based on the following context answer to the question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}",
Expand Down Expand Up @@ -130,7 +121,7 @@ func getProjectID(pineconeEnvironment, pineconeApiKey string) (string, error) {

func ingestData(projectID string, openaiEmbedder index.Embedder) error {

pineconeIndex, err := index.NewPinecone(
pineconeIndex := index.NewPinecone(
index.PineconeOptions{
IndexName: "test",
ProjectID: projectID,
Expand All @@ -139,16 +130,8 @@ func ingestData(projectID string, openaiEmbedder index.Embedder) error {
},
openaiEmbedder,
)
if err != nil {
return err
}

loader, err := loader.NewDirectoryLoader(".", ".txt")
if err != nil {
return err
}

documents, err := loader.Load()
documents, err := loader.NewDirectoryLoader(".", ".txt").Load()
if err != nil {
return err
}
Expand Down
38 changes: 9 additions & 29 deletions examples/embeddings/simpleVector/main.go
Expand Up @@ -14,20 +14,13 @@ import (

func main() {

openaiEmbedder, err := openaiembedder.New(openaiembedder.AdaEmbeddingV2)
if err != nil {
panic(err)
}

docsVectorIndex, err := index.NewSimpleVectorIndex("docs", ".", openaiEmbedder)
if err != nil {
panic(err)
}
openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2)

docsVectorIndex := index.NewSimpleVectorIndex("docs", ".", openaiEmbedder)
indexIsEmpty, _ := docsVectorIndex.IsEmpty()

if indexIsEmpty {
err = ingestData(openaiEmbedder)
err := ingestData(openaiEmbedder)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -56,11 +49,7 @@ func main() {
documentContext += similarity.Document.Content + "\n\n"
}

llmOpenAI, err := openai.NewCompletion()
if err != nil {
panic(err)
}

llmOpenAI := openai.NewCompletion()
prompt1, err := prompt.NewPromptTemplate(
"Based on the following context answer to the question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}",
map[string]string{
Expand All @@ -77,28 +66,19 @@ func main() {
panic(err)
}

_, err = llmOpenAI.Completion(context.Background(), prompt1.String())

output, err := llmOpenAI.Completion(context.Background(), prompt1.String())
if err != nil {
panic(err)
}

fmt.Println(output)
}

func ingestData(openaiEmbedder index.Embedder) error {

fmt.Printf("Ingesting data...")

docsVectorIndex, err := index.NewSimpleVectorIndex("docs", ".", openaiEmbedder)
if err != nil {
return err
}

loader, err := loader.NewDirectoryLoader(".", ".txt")
if err != nil {
return err
}

documents, err := loader.Load()
documents, err := loader.NewDirectoryLoader(".", ".txt").Load()
if err != nil {
return err
}
Expand All @@ -107,7 +87,7 @@ func ingestData(openaiEmbedder index.Embedder) error {

documentChunks := textSplitter.SplitDocuments(documents)

err = docsVectorIndex.LoadFromDocuments(context.Background(), documentChunks)
err = index.NewSimpleVectorIndex("docs", ".", openaiEmbedder).LoadFromDocuments(context.Background(), documentChunks)
if err != nil {
return err
}
Expand Down
11 changes: 2 additions & 9 deletions examples/pipeline/chat/main.go
Expand Up @@ -18,15 +18,8 @@ func main() {

cache := ram.New()

llmChatOpenAI, err := openai.NewChat()
if err != nil {
panic(err)
}

llmOpenAI, err := openai.NewCompletion()
if err != nil {
panic(err)
}
llmChatOpenAI := openai.NewChat()
llmOpenAI := openai.NewCompletion()

prompt1, _ := prompt.NewPromptTemplate(
"You are a {{.mode}} {{.role}}",
Expand Down
5 changes: 1 addition & 4 deletions examples/pipeline/openai/main.go
Expand Up @@ -16,10 +16,7 @@ func main() {

cache := ram.New()

llmOpenAI, err := openai.NewCompletion()
if err != nil {
panic(err)
}
llmOpenAI := openai.NewCompletion()

llmOpenAI.SetCallback(func(response types.Meta) {
fmt.Printf("USAGE: %#v\n", response)
Expand Down
5 changes: 1 addition & 4 deletions examples/pipeline/splitter/main.go
Expand Up @@ -13,10 +13,7 @@ import (

func main() {

llmOpenAI, err := openai.NewCompletion()
if err != nil {
panic(err)
}
llmOpenAI := openai.NewCompletion()

llm := pipeline.Llm{
LlmEngine: llmOpenAI,
Expand Down
16 changes: 7 additions & 9 deletions index/pinecone.go
Expand Up @@ -37,17 +37,10 @@ type PineconeOptions struct {
BatchUpsertSize *int
}

func NewPinecone(options PineconeOptions, embedder Embedder) (*pinecone, error) {
func NewPinecone(options PineconeOptions, embedder Embedder) *pinecone {

apiKey := os.Getenv("PINECONE_API_KEY")
if apiKey == "" {
return nil, fmt.Errorf("PINECONE_API_KEY is not set")
}

environment := os.Getenv("PINECONE_ENVIRONMENT")
if environment == "" {
return nil, fmt.Errorf("PINECONE_ENVIRONMENT is not set")
}

pineconeClient := pineconego.New(environment, apiKey)

Expand All @@ -64,7 +57,12 @@ func NewPinecone(options PineconeOptions, embedder Embedder) (*pinecone, error)
namespace: options.Namespace,
includeContent: options.IncludeContent,
batchUpsertSize: batchUpsertSize,
}, nil
}
}

func (p *pinecone) WithAPIKeyAndEnvironment(apiKey, environment string) *pinecone {
p.pineconeClient = pineconego.New(environment, apiKey)
return p
}

func (s *pinecone) LoadFromDocuments(ctx context.Context, documents []document.Document) error {
Expand Down
18 changes: 8 additions & 10 deletions index/simpleVectorIndex.go
Expand Up @@ -28,23 +28,15 @@ type simpleVectorIndex struct {
embedder Embedder
}

func NewSimpleVectorIndex(name string, outputPath string, embedder Embedder) (*simpleVectorIndex, error) {
func NewSimpleVectorIndex(name string, outputPath string, embedder Embedder) *simpleVectorIndex {
simpleVectorIndex := &simpleVectorIndex{
data: []simpleVectorIndexData{},
outputPath: outputPath,
name: name,
embedder: embedder,
}

_, err := os.Stat(simpleVectorIndex.database())
if err == nil {
err = simpleVectorIndex.load()
if err != nil {
return nil, fmt.Errorf("%s: %w", ErrInternal, err)
}
}

return simpleVectorIndex, nil
return simpleVectorIndex
}

func (s *simpleVectorIndex) LoadFromDocuments(ctx context.Context, documents []document.Document) error {
Expand Down Expand Up @@ -114,6 +106,12 @@ func (s *simpleVectorIndex) database() string {
}

func (s *simpleVectorIndex) IsEmpty() (bool, error) {

err := s.load()
if err != nil {
return true, fmt.Errorf("%s: %w", ErrInternal, err)
}

return len(s.data) == 0, nil
}

Expand Down
16 changes: 9 additions & 7 deletions llm/openai/openai.go
Expand Up @@ -54,28 +54,30 @@ type openAI struct {
callback OpenAICallback
}

func New(model Model, temperature float32, maxTokens int, verbose bool) (*openAI, error) {
func New(model Model, temperature float32, maxTokens int, verbose bool) *openAI {

openAIKey := os.Getenv("OPENAI_API_KEY")
if openAIKey == "" {
return nil, fmt.Errorf("OPENAI_API_KEY not set")
}

return &openAI{
openAIClient: openai.NewClient(openAIKey),
model: model,
temperature: temperature,
maxTokens: maxTokens,
verbose: verbose,
}, nil
}
}

func (o *openAI) WithStop(stop []string) *openAI {
o.stop = stop
return o
}

func NewCompletion() (*openAI, error) {
func (o *openAI) WithAPIKey(apiKey string) *openAI {
o.openAIClient = openai.NewClient(apiKey)
return o
}

func NewCompletion() *openAI {
return New(
GPT3TextDavinci003,
DefaultOpenAITemperature,
Expand All @@ -84,7 +86,7 @@ func NewCompletion() (*openAI, error) {
)
}

func NewChat() (*openAI, error) {
func NewChat() *openAI {
return New(
GPT3Dot5Turbo,
DefaultOpenAITemperature,
Expand Down

0 comments on commit 082fd2b

Please sign in to comment.