-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.go
98 lines (74 loc) · 2.37 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
)
// Bedrock connection and model settings used by main.
const (
	// defaultRegion is the region used when AWS_REGION is unset or empty.
	defaultRegion = "us-east-1"

	// claudePromptFormat is the Human/Assistant turn template wrapped around
	// the user prompt; the %s verb receives the prompt text.
	claudePromptFormat = "\n\nHuman:%s\n\nAssistant:"

	// claudeV2ModelID is the Bedrock model ID for Claude v2. The list of IDs is at
	// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
	claudeV2ModelID = "anthropic.claude-v2"
)
// prompt is the user turn sent to the model. It is a small phone directory
// delimited by <directory>...</directory> tags, followed by the extraction
// instruction. The instruction refers to content "within the directory", so
// the data must be enclosed by a matching opening and closing tag.
const prompt = `<directory>
Phone directory:
John Latrabe, 800-232-1995, john909709@geemail.com
Josie Lana, 800-759-2905, josie@josielananier.com
Keven Stevens, 800-980-7000, drkevin22@geemail.com
Phone directory will be kept up to date by the HR manager.
</directory>
Please output the email addresses within the directory, one per line, in the order in which they appear within the text. If there are no email addresses in the text, output "N/A".`
func main() {
region := os.Getenv("AWS_REGION")
if region == "" {
region = defaultRegion
}
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region))
if err != nil {
log.Fatal(err)
}
brc := bedrockruntime.NewFromConfig(cfg)
payload := Request{
Prompt: fmt.Sprintf(claudePromptFormat, prompt),
MaxTokensToSample: 2048,
Temperature: 0.5,
TopK: 250,
TopP: 1,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
log.Fatal(err)
}
//log.Println("raw request", string(payloadBytes))
output, err := brc.InvokeModel(context.Background(), &bedrockruntime.InvokeModelInput{
Body: payloadBytes,
ModelId: aws.String(claudeV2ModelID),
ContentType: aws.String("application/json"),
})
if err != nil {
log.Fatal("failed to invoke model: ", err)
}
//log.Println("raw response ", string(output.Body))
var resp Response
err = json.Unmarshal(output.Body, &resp)
if err != nil {
log.Fatal("failed to unmarshal", err)
}
fmt.Println("response from LLM\n", resp.Completion)
}
// Request is the JSON body that main marshals and passes to Bedrock
// InvokeModel for the claudeV2ModelID model. The JSON field names are given
// by the struct tags.
//
// NOTE(review): Temperature, TopP, TopK and StopSequences are tagged
// omitempty, so a zero value (0, or an empty slice) is left out of the JSON
// instead of being sent. Setting Temperature to 0 therefore does NOT request
// temperature 0; presumably the model's own default applies instead.
// Confirm against the Bedrock Claude parameter docs before relying on zero
// values.
type Request struct {
	// Prompt is the complete prompt text, including the Human/Assistant turns.
	Prompt string `json:"prompt"`
	// MaxTokensToSample limits the completion length (presumably in tokens,
	// per the field name; verify against the model docs). Always sent, even
	// when zero.
	MaxTokensToSample int `json:"max_tokens_to_sample"`
	// Temperature is the sampling temperature; omitted from the JSON when 0.
	Temperature float64 `json:"temperature,omitempty"`
	// TopP is the nucleus-sampling threshold; omitted from the JSON when 0.
	TopP float64 `json:"top_p,omitempty"`
	// TopK limits sampling to the K most likely tokens; omitted when 0.
	TopK int `json:"top_k,omitempty"`
	// StopSequences lists strings at which generation should stop; omitted
	// from the JSON when nil or empty.
	StopSequences []string `json:"stop_sequences,omitempty"`
}
// Response is the subset of the InvokeModel response body that main reads.
// Only the "completion" field (the generated text) is decoded; any other
// fields in the body are ignored by json.Unmarshal.
type Response struct {
	// Completion is the text generated by the model.
	Completion string `json:"completion"`
}