-
Notifications
You must be signed in to change notification settings - Fork 1
/
gemini.go
176 lines (153 loc) · 4.61 KB
/
gemini.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
package model
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"mime"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"cloud.google.com/go/vertexai/genai"
)
// UseGeminiModel turns each command-line argument into a prompt part, sends
// the parts to the named Gemini model and prints the generated content to
// standard output.
//
// Arguments are classified in order: a "gs://" URL becomes a file reference
// whose MIME type is derived from its extension, any other URL is fetched
// with getPartFromURL, an argument that looks like a filename is loaded with
// getPartFromFile, and everything else is sent as plain text.
//
// Every failure, including a failed generation request, is returned to the
// caller instead of terminating the process, so the caller decides how to
// report it and which exit status to use.
func UseGeminiModel(ctx context.Context, modelName string, cfg Config, args []string) error {
	log.Printf("Gemini [%s]", modelName)

	promptParts := make([]genai.Part, 0, len(args))
	for _, arg := range args {
		switch {
		case argLooksLikeGCSURL(arg):
			// Cloud Storage objects are passed by reference, not downloaded.
			promptParts = append(promptParts, genai.FileData{
				MIMEType: mime.TypeByExtension(filepath.Ext(arg)),
				FileURI:  arg,
			})
		case argLooksLikeURL(arg):
			part, err := getPartFromURL(arg)
			if err != nil {
				return err
			}
			promptParts = append(promptParts, part)
		case argLooksLikeFilename(arg):
			part, err := getPartFromFile(arg)
			if err != nil {
				return err
			}
			promptParts = append(promptParts, part)
		default:
			promptParts = append(promptParts, genai.Text(arg))
		}
	}

	// Collect the whole response first so nothing is printed on failure.
	var buf bytes.Buffer
	if err := GenerateContentGemini(ctx, modelName, cfg, &buf, promptParts); err != nil {
		return err
	}
	fmt.Printf("%s\n", buf.String())
	return nil
}
// GenerateContentGemini sends parts to the Gemini model named modelName and
// writes the result to w.
//
// A genai client is created from cfg.ProjectID and cfg.RegionID and closed
// before returning. When cfg.ConfigFile is set, the file is read and decoded
// as JSON into a genai.GenerationConfig that is applied to the model; the
// decoded config is logged unless cfg.LogType is "none".
//
// When cfg.OutputType is "json" the whole response is written as indented
// JSON followed by a newline. Otherwise the parts of the first candidate are
// formatted with %s and written joined by single spaces; if there is no
// candidate, or the first candidate carries no content, a message is logged
// and nothing is written.
//
// It returns an error if the client cannot be created, the config file cannot
// be read or decoded, the generation request fails, or the response cannot be
// encoded as JSON.
func GenerateContentGemini(ctx context.Context, modelName string, cfg Config, w io.Writer, parts []genai.Part) error {
	// TODO - There are differences between this function and the matching function in palm.go
	// due to when the config file contents are read.
	// TODO - Unlike matching functions in palm.go and anthropic.go, this one is public. Should the
	// others be made public or should this one be made private.
	client, err := genai.NewClient(ctx, cfg.ProjectID, cfg.RegionID)
	if err != nil {
		return fmt.Errorf("error creating a client: %w", err)
	}
	defer client.Close()

	gemini := client.GenerativeModel(modelName)
	if cfg.ConfigFile != "" {
		modelConfig, err := os.ReadFile(cfg.ConfigFile)
		if err != nil {
			return fmt.Errorf("error reading model config file: %w", err)
		}
		var config genai.GenerationConfig
		if err := json.Unmarshal(modelConfig, &config); err != nil {
			return fmt.Errorf("error unmarshalling GenerationConfig from file: %w", err)
		}
		gemini.GenerationConfig = config
		if cfg.LogType != "none" {
			log.Printf("config: %v", config)
		}
	}

	resp, err := gemini.GenerateContent(ctx, parts...)
	if err != nil {
		// needs more sensible parsing of error message
		if strings.Contains(err.Error(), "lookup -aiplatform.googleapis.com:") {
			log.Print("missing REGION")
		}
		if strings.Contains(err.Error(), "RESOURCE_PROJECT_INVALID") {
			log.Print("missing PROJECT_ID")
		}
		return fmt.Errorf("error generating content: %w", err)
	}

	if cfg.OutputType == "json" {
		rb, err := json.MarshalIndent(resp, "", " ")
		if err != nil {
			return fmt.Errorf("error marshalling response: %w", err)
		}
		fmt.Fprintln(w, string(rb))
		return nil
	}

	if len(resp.Candidates) == 0 {
		log.Printf("Candidate length 0")
		return nil
	}
	// Guard against a candidate without content so the dereference below
	// cannot panic.
	content := resp.Candidates[0].Content
	if content == nil {
		log.Printf("Candidate has no content")
		return nil
	}
	all := make([]string, 0, len(content.Parts))
	for _, v := range content.Parts {
		all = append(all, fmt.Sprintf("%s", v))
	}
	fmt.Fprintf(w, "%s", strings.Join(all, " "))
	return nil
}
// Thanks to eliben's https://github.com/eliben/gemini-cli/blob/main/internal/commands/prompt.go

// filenameExtensionRE matches a trailing dot followed by one or more ASCII
// letters, such as ".png". It is compiled once at package scope instead of on
// every call.
var filenameExtensionRE = regexp.MustCompile(`\.[a-zA-Z]+$`)

// argLooksLikeFilename reports whether a command-line argument looks like a
// filename, which we consider to be an argument that ends in a dot separator
// followed by an alphabetical extension and that does not look like a URL
// (it contains no "://").
func argLooksLikeFilename(arg string) bool {
	return filenameExtensionRE.MatchString(arg) && !strings.Contains(arg, "://")
}
// argLooksLikeGCSURL reports whether arg starts with the "gs://" prefix.
// The comparison is case-sensitive and nothing after the prefix is checked.
func argLooksLikeGCSURL(arg string) bool {
	const gcsPrefix = "gs://"
	return strings.HasPrefix(arg, gcsPrefix)
}
// argLooksLikeURL reports whether arg is an absolute URL that names a scheme
// and a host, such as "https://example.com/cat.png".
//
// url.ParseRequestURI on its own also accepts absolute paths ("/tmp/cat.png")
// and opaque forms ("note:remember this"). Those are rejected here so that
// local paths and ordinary prompt text are not mistaken for something to
// download.
func argLooksLikeURL(arg string) bool {
	u, err := url.ParseRequestURI(arg)
	if err != nil {
		return false
	}
	return u.Scheme != "" && u.Host != ""
}
// getPartFromFile loads the image file at path and wraps its bytes in an
// image prompt part.
//
// The image format is chosen from the file extension, compared without regard
// to case: ".jpg" and ".jpeg" map to "jpeg", ".png" maps to "png". Any other
// extension yields an error, and the file is not read in that case. An error
// from reading the file is returned unchanged.
func getPartFromFile(path string) (genai.Part, error) {
	ext := filepath.Ext(path)

	// Settle the format before touching the disk so an unsupported file is
	// rejected without being loaded into memory.
	var format string
	switch strings.ToLower(strings.TrimSpace(ext)) {
	case ".jpg", ".jpeg":
		format = "jpeg"
	case ".png":
		format = "png"
	default:
		return nil, fmt.Errorf("invalid image file extension: %s", ext)
	}

	b, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	return genai.ImageData(format, b), nil
}
// getPartFromURL downloads the resource at url with an HTTP GET and wraps the
// response body in an image prompt part.
//
// The image format is the subtype of the response's Content-Type header, with
// any parameters (for example "; charset=binary") stripped. An error is
// returned when the request fails, the response status is not 2xx, the body
// cannot be read, or the Content-Type header is missing or malformed.
func getPartFromURL(url string) (genai.Part, error) {
	resp, err := http.Get(url)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch image from url: %w", err)
	}
	defer resp.Body.Close()

	// An error page must not be forwarded to the model as if it were the
	// requested image.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("failed to fetch image from url: unexpected status %s", resp.Status)
	}

	urlData, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read image bytes: %w", err)
	}

	mimeType := resp.Header.Get("Content-Type")
	mediaType, _, err := mime.ParseMediaType(mimeType)
	if err != nil {
		return nil, fmt.Errorf("invalid mime type %q: %w", mimeType, err)
	}
	// ParseMediaType accepts a bare token without a subtype, so the slash
	// still has to be checked.
	_, subtype, ok := strings.Cut(mediaType, "/")
	if !ok || subtype == "" {
		return nil, fmt.Errorf("invalid mime type %v", mimeType)
	}
	return genai.ImageData(subtype, urlData), nil
}