-
Notifications
You must be signed in to change notification settings - Fork 0
/
google_gemini.go
114 lines (93 loc) · 3.08 KB
/
google_gemini.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package gemini
import (
	"context"
	"fmt"
	"os"
	"strings"

	"github.com/darmenliu/nuwa-terminal-chat/pkg/llms"
	"github.com/google/generative-ai-go/genai"
	lcllms "github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/googleai"
	"google.golang.org/api/option"
)
// Gemini is a wrapper around the Gemini API.
//
// It holds two separate clients. Client/Model (the genai SDK) back
// GenerateContent and are closed by CloseBackend. google (langchaingo's
// googleai) backs Chat, which sends the accumulated chatHistory on every call.
// Values are normally constructed with NewGemini.
type Gemini struct {
	// Client is the genai SDK client; CloseBackend closes it.
	Client *genai.Client

	// Model is the genai model used by GenerateContent. NewGemini sets its
	// SystemInstruction to the system prompt.
	Model *genai.GenerativeModel

	// google is the langchaingo googleai client used by Chat.
	google *googleai.GoogleAI

	// chatHistory is the conversation sent to google on every Chat call.
	// NewGemini seeds it with a single system-role message; Chat appends
	// human messages and AI replies to it.
	chatHistory []lcllms.MessageContent

	// SystemPrompt is the most recently supplied system prompt text.
	SystemPrompt string
}
// NewGemini returns a new Gemini client for modelName, authenticated with the
// API key held in the GEMINI_API_KEY environment variable.
//
// Two clients are created:
//   - a langchaingo googleai client (default model modelName), used by Chat;
//   - a genai client whose model has systemPrompt installed as its system
//     instruction, used by GenerateContent.
//
// systemPrompt is also stored as the first, system-role entry of the chat
// history that Chat sends.
//
// It returns an error if GEMINI_API_KEY is unset or empty, or if either client
// cannot be created.
func NewGemini(ctx context.Context, modelName string, systemPrompt string) (llms.Model, error) {
	genaiKey := os.Getenv("GEMINI_API_KEY")
	if genaiKey == "" {
		return nil, fmt.Errorf("GEMINI_API_KEY is not set")
	}

	llm, err := googleai.New(ctx, googleai.WithAPIKey(genaiKey), googleai.WithDefaultModel(modelName))
	if err != nil {
		return nil, fmt.Errorf("Failed to create GoogleAI client: %w", err)
	}

	// Reuse the key validated above instead of reading the environment a
	// second time, so both clients are guaranteed to share the same credentials.
	client, err := genai.NewClient(ctx, option.WithAPIKey(genaiKey))
	if err != nil {
		return nil, fmt.Errorf("Failed to create Gemini client: %w", err)
	}

	model := client.GenerativeModel(modelName)
	model.SystemInstruction = &genai.Content{
		Parts: []genai.Part{genai.Text(systemPrompt)},
	}

	history := []lcllms.MessageContent{
		lcllms.TextParts(lcllms.ChatMessageTypeSystem, systemPrompt),
	}

	return &Gemini{
		Client:       client,
		Model:        model,
		google:       llm,
		chatHistory:  history,
		SystemPrompt: systemPrompt,
	}, nil
}
// ContentToString concatenates the text of every genai.Text part of content,
// in order, into a single string. Parts of any other type are skipped.
// A nil content yields the empty string instead of panicking.
func (g *Gemini) ContentToString(content *genai.Content) string {
	if content == nil {
		return ""
	}
	// strings.Builder avoids the quadratic copying of repeated += on strings.
	var sb strings.Builder
	for _, part := range content.Parts {
		// A single comma-ok assertion both tests and extracts the text.
		if text, ok := part.(genai.Text); ok {
			sb.WriteString(string(text))
		}
	}
	return sb.String()
}
// GenerateContent sends prompt as a single, history-free request to g.Model
// (whose system instruction NewGemini sets) and returns the text of the first
// candidate in the response.
//
// It returns an error if the request fails, if the response has no
// candidates, or if the first candidate carries no content.
func (g *Gemini) GenerateContent(ctx context.Context, prompt string) (string, error) {
	resp, err := g.Model.GenerateContent(ctx, genai.Text(prompt))
	if err != nil {
		return "", fmt.Errorf("Failed to generate content: %w", err)
	}
	// Indexing Candidates[0] unguarded panics on an empty candidate list, and
	// both the candidate and its Content are pointers that may be nil.
	if resp == nil || len(resp.Candidates) == 0 {
		return "", fmt.Errorf("Failed to generate content: response has no candidates")
	}
	candidate := resp.Candidates[0]
	if candidate == nil || candidate.Content == nil {
		return "", fmt.Errorf("Failed to generate content: first candidate has no content")
	}
	return g.ContentToString(candidate.Content), nil
}
// Chat sends message as the next human turn of the running conversation and
// returns the model's reply.
//
// The full chat history (seeded by NewGemini with a system message) is sent
// through the langchaingo googleai client. On success, message and the reply
// are both appended to the history. On failure, the history is restored to its
// previous length. Otherwise an unanswered human turn would linger and be
// resent ahead of the next message.
func (g *Gemini) Chat(ctx context.Context, message string) (string, error) {
	prevLen := len(g.chatHistory)
	g.chatHistory = append(g.chatHistory, lcllms.TextParts(lcllms.ChatMessageTypeHuman, message))
	rollback := func() { g.chatHistory = g.chatHistory[:prevLen] }

	resp, err := g.google.GenerateContent(ctx, g.chatHistory)
	if err != nil {
		rollback()
		return "", fmt.Errorf("Failed to generate content: %w", err)
	}
	// Guard the Choices[0] access: an empty or nil choice would otherwise panic.
	if resp == nil || len(resp.Choices) == 0 || resp.Choices[0] == nil {
		rollback()
		return "", fmt.Errorf("Failed to generate content: response has no choices")
	}

	reply := resp.Choices[0].Content
	g.chatHistory = append(g.chatHistory, lcllms.TextParts(lcllms.ChatMessageTypeAI, reply))
	return reply, nil
}
// SetSystemPrompt replaces the system prompt used by both request paths.
//
// It records prompt in g.SystemPrompt and installs it as g.Model's system
// instruction, which GenerateContent uses. It also replaces the leading
// system-role entry of the chat history sent by Chat, inserting one at the
// front if the history has none. Earlier conversation turns are kept. ctx is
// currently unused, and the returned error is always nil.
func (g *Gemini) SetSystemPrompt(ctx context.Context, prompt string) error {
	g.SystemPrompt = prompt

	// Storing the string alone has no effect on requests: GenerateContent reads
	// Model.SystemInstruction and Chat reads chatHistory, so update both.
	if g.Model != nil {
		g.Model.SystemInstruction = &genai.Content{
			Parts: []genai.Part{genai.Text(prompt)},
		}
	}

	system := lcllms.TextParts(lcllms.ChatMessageTypeSystem, prompt)
	if len(g.chatHistory) > 0 && g.chatHistory[0].Role == lcllms.ChatMessageTypeSystem {
		g.chatHistory[0] = system
	} else {
		g.chatHistory = append([]lcllms.MessageContent{system}, g.chatHistory...)
	}
	return nil
}
// CloseBackend closes the genai client (g.Client) and returns its error.
// It is a no-op returning nil when no client was set, e.g. on a zero-value
// Gemini, rather than panicking on a nil receiver field.
//
// NOTE(review): the langchaingo googleai client created in NewGemini (g.google)
// is not closed here; confirm whether it holds resources that need releasing.
func (g *Gemini) CloseBackend() error {
	if g.Client == nil {
		return nil
	}
	return g.Client.Close()
}