-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprompts.go
72 lines (64 loc) · 1.86 KB
/
prompts.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
package chatmodels
import (
"context"
"fmt"
"github.com/jackmcguire1/alexa-chatgpt/internal/pkg/utils"
)
// AutoComplete sends prompt to the backend selected by model and returns the
// generated text.
//
// Routing:
//   - CHAT_MODEL_GEMINI goes to client.GeminiAPI; the first part of the first
//     candidate is formatted with fmt.Sprint.
//   - CHAT_MODEL_META, CHAT_MODEL_SQL, CHAT_MODEL_OPEN, CHAT_MODEL_AWQ and
//     CHAT_MODEL_QWEN go to client.CloudflareApiClient, using the model name
//     looked up in CHAT_MODEL_TO_CF_MODEL.
//   - Any other value goes to client.GPTApi; the content of the first choice
//     is returned.
//
// Backend errors are returned as-is. An error is also returned when Gemini
// yields no candidate or no parts, or when the GPT response has no choices.
func (client *Client) AutoComplete(ctx context.Context, prompt string, model ChatModel) (string, error) {
	switch model {
	case CHAT_MODEL_GEMINI:
		reply, err := client.GeminiAPI.GeminiChat(ctx, prompt)
		if err != nil {
			return "", err
		}
		// NOTE(review): if Content is a pointer type, a nil Content on the
		// first candidate would panic here — confirm against the Gemini types.
		if len(reply.Candidates) == 0 || len(reply.Candidates[0].Content.Parts) == 0 {
			return "", fmt.Errorf("did not get enough info back from google %s", utils.ToJSON(reply))
		}
		firstPart := reply.Candidates[0].Content.Parts[0]
		return fmt.Sprint(firstPart), nil

	case CHAT_MODEL_META, CHAT_MODEL_SQL, CHAT_MODEL_OPEN, CHAT_MODEL_AWQ, CHAT_MODEL_QWEN:
		cfModel := CHAT_MODEL_TO_CF_MODEL[model]
		return client.CloudflareApiClient.GenerateText(ctx, prompt, cfModel)
	}

	// Every model not handled above is served by the GPT backend.
	completion, err := client.GPTApi.AutoComplete(ctx, prompt)
	if err != nil {
		return "", err
	}
	if len(completion.Choices) == 0 {
		return "", fmt.Errorf("missing choices")
	}
	return completion.Choices[0].Message.Content, nil
}
// GenerateImage forwards prompt to client.CloudflareApiClient.GenerateImage
// and returns that call's bytes and error unchanged.
//
// The model is not validated here: the Cloudflare model name is looked up in
// CHAT_MODEL_TO_CF_MODEL for whatever value is supplied, so a model absent
// from that table is forwarded as the table's zero value.
//
// The earlier switch had a single case that fell through to default, which
// made every model take this same path and left the trailing
// "unidentified image generation model" return unreachable; both were removed
// without changing behaviour.
func (client *Client) GenerateImage(ctx context.Context, prompt string, model ChatModel) ([]byte, error) {
	return client.CloudflareApiClient.GenerateImage(ctx, prompt, CHAT_MODEL_TO_CF_MODEL[model])
}
// Translate sends prompt to client.CloudflareApiClient.GenerateTranslation and
// returns that call's result unchanged.
//
// Empty arguments are replaced with defaults before the request is built:
// sourceLang becomes "en", targetLang becomes "jp", and model becomes
// CHAT_MODEL_TRANSLATIONS. The Cloudflare model name is looked up in
// CHAT_MODEL_TO_CF_MODEL.
func (client *Client) Translate(
	ctx context.Context,
	prompt string,
	sourceLang string,
	targetLang string,
	model ChatModel,
) (string, error) {
	if model == "" {
		model = CHAT_MODEL_TRANSLATIONS
	}

	req := &GenerateTranslationRequest{
		SourceLanguage: sourceLang,
		TargetLanguage: targetLang,
		Prompt:         prompt,
		Model:          CHAT_MODEL_TO_CF_MODEL[model],
	}
	if req.SourceLanguage == "" {
		req.SourceLanguage = "en"
	}
	if req.TargetLanguage == "" {
		// NOTE(review): "jp" is a country code; the ISO 639-1 language code for
		// Japanese is "ja" — confirm which one the translation backend expects.
		req.TargetLanguage = "jp"
	}

	return client.CloudflareApiClient.GenerateTranslation(ctx, req)
}