-
Notifications
You must be signed in to change notification settings - Fork 804
/
options.go
60 lines (51 loc) · 1.36 KB
/
options.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
package llm
// ParamOption is a function that configures a ParamOptions value by
// mutating it in place through the pointer it receives.
type ParamOption func(*ParamOptions)
// ParamOptions is the set of generation parameters configured by
// ParamOption functions. Each field is serialized under the json tag shown;
// no tag uses omitempty, so zero values are always encoded.
type ParamOptions struct {
// Model is the model to use.
Model string `json:"model"`
// MaxTokens is the maximum number of tokens to generate.
MaxTokens int `json:"max_tokens"`
// Temperature is the temperature for sampling, between 0 and 1.
// NOTE(review): the range is documented but not validated here — confirm
// against the backend, which may accept a wider range.
Temperature float32 `json:"temperature"`
// StopWords is a list of words to stop on.
StopWords []string `json:"stop_words"`
// LogitBias maps a string key to an integer bias value.
// Presumably the keys are token identifiers and the values bias their
// sampling likelihood — confirm against the backend's API.
LogitBias map[string]int `json:"logit_bias"`
}
// WithModel returns a ParamOption that sets the Model field to model.
func WithModel(model string) ParamOption {
	apply := func(opts *ParamOptions) {
		opts.Model = model
	}
	return apply
}
// WithMaxTokens returns a ParamOption that sets MaxTokens, the maximum
// number of tokens to generate. The value is stored as given and is not
// validated here.
func WithMaxTokens(maxTokens int) ParamOption {
	return ParamOption(func(opts *ParamOptions) { opts.MaxTokens = maxTokens })
}
// WithTemperature returns a ParamOption that sets the sampling Temperature.
// The value is stored as given; no range check is performed here.
func WithTemperature(temperature float32) ParamOption {
	var opt ParamOption = func(p *ParamOptions) {
		p.Temperature = temperature
	}
	return opt
}
// WithStopWords returns a ParamOption that sets StopWords to a copy of
// stopWords.
//
// The copy is made each time the option is applied, so later changes to
// the caller's slice do not leak into the configured ParamOptions, and
// vice versa. A nil argument yields nil, and an empty non-nil argument
// yields an empty non-nil slice. This preserves the nil/empty distinction
// of the argument.
func WithStopWords(stopWords []string) ParamOption {
	return func(o *ParamOptions) {
		if stopWords == nil {
			o.StopWords = nil
			return
		}
		o.StopWords = make([]string, len(stopWords))
		copy(o.StopWords, stopWords)
	}
}
// WithLogitBias returns a ParamOption that sets LogitBias to a copy of
// logitBias.
//
// The map is copied each time the option is applied, so later writes to the
// caller's map do not change the configured ParamOptions, and vice versa.
// A nil argument yields nil, and an empty non-nil argument yields an empty
// non-nil map.
func WithLogitBias(logitBias map[string]int) ParamOption {
	return func(o *ParamOptions) {
		if logitBias == nil {
			o.LogitBias = nil
			return
		}
		cp := make(map[string]int, len(logitBias))
		for k, v := range logitBias {
			cp[k] = v
		}
		o.LogitBias = cp
	}
}
// WithOptions returns a ParamOption that replaces the whole target
// ParamOptions with options. Any fields set by options applied earlier are
// overwritten, including with zero values.
//
// StopWords and LogitBias are copied on each application, so the configured
// value does not share backing storage with options. A nil slice or map
// stays nil, and an empty non-nil one stays empty and non-nil.
func WithOptions(options ParamOptions) ParamOption {
	return func(o *ParamOptions) {
		*o = options
		if options.StopWords != nil {
			o.StopWords = make([]string, len(options.StopWords))
			copy(o.StopWords, options.StopWords)
		}
		if options.LogitBias != nil {
			o.LogitBias = make(map[string]int, len(options.LogitBias))
			for k, v := range options.LogitBias {
				o.LogitBias[k] = v
			}
		}
	}
}
// ValidOptions returns a normalized copy of options. A StopWords slice with
// no elements (nil or zero-length) is replaced by nil; every other field is
// returned unchanged. options is received by value, so the caller's struct
// itself is not modified.
//
// Presumably the nil normalization exists so that the un-omitempty
// "stop_words" JSON field encodes as null rather than [] — confirm with
// callers.
func ValidOptions(options ParamOptions) ParamOptions {
	normalized := options
	if stop := normalized.StopWords; len(stop) == 0 {
		normalized.StopWords = nil
	}
	return normalized
}