30 changes: 21 additions & 9 deletions args/llm.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"strconv"
 
+	"github.com/masa-finance/tee-types/pkg/util"
 	teetypes "github.com/masa-finance/tee-types/types"
 )
 
@@ -18,10 +19,13 @@ const (
 	LLMDefaultMaxTokens       uint    = 300
 	LLMDefaultTemperature     float64 = 0.1
 	LLMDefaultMultipleColumns bool    = false
-	LLMDefaultModel           string  = "gemini-1.5-flash-8b"
+	LLMDefaultGeminiModel     string  = "gemini-1.5-flash-8b"
+	LLMDefaultClaudeModel     string  = "claude-3-5-haiku-latest"
 	LLMDefaultItems           uint    = 1
 )
 
+var SupportedModels = util.NewSet(LLMDefaultGeminiModel, LLMDefaultClaudeModel)
+
 type LLMProcessorArguments struct {
 	DatasetId string `json:"dataset_id"`
 	Prompt    string `json:"prompt"`
@@ -71,13 +75,21 @@ func (l *LLMProcessorArguments) Validate() error {
 	return nil
 }
 
-func (l LLMProcessorArguments) ToLLMProcessorRequest() teetypes.LLMProcessorRequest {
-	return teetypes.LLMProcessorRequest{
-		InputDatasetId:  l.DatasetId,
-		Prompt:          l.Prompt,
-		MaxTokens:       l.MaxTokens,
-		Temperature:     strconv.FormatFloat(l.Temperature, 'f', -1, 64),
-		MultipleColumns: LLMDefaultMultipleColumns, // overrides default in actor API
-		Model:           LLMDefaultModel,           // overrides default in actor API
-	}
-}
+func (l LLMProcessorArguments) ToLLMProcessorRequest(model string, key string) (teetypes.LLMProcessorRequest, error) {
+	if !SupportedModels.Contains(model) {
+		return teetypes.LLMProcessorRequest{}, fmt.Errorf("model %s is not supported", model)
+	}
+	if key == "" {
+		return teetypes.LLMProcessorRequest{}, fmt.Errorf("key is required")
+	}
+
+	return teetypes.LLMProcessorRequest{
+		InputDatasetId:    l.DatasetId,
+		LLMProviderApiKey: key,
+		Prompt:            l.Prompt,
+		MaxTokens:         l.MaxTokens,
+		Temperature:       strconv.FormatFloat(l.Temperature, 'f', -1, 64),
+		MultipleColumns:   LLMDefaultMultipleColumns, // overrides default in actor API
+		Model:             model,                     // overrides default in actor API
+	}, nil
+}
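
With this change the caller now chooses the model and supplies the provider API key instead of relying on the baked-in default. A minimal caller sketch under assumptions (the `llmArgs` value, the `cfg.GeminiApiKey` field, and the surrounding error handling are hypothetical and not part of this diff):

	// Hypothetical caller: build the request with the Gemini default model
	// and a provider key taken from some configuration source.
	req, err := llmArgs.ToLLMProcessorRequest(args.LLMDefaultGeminiModel, cfg.GeminiApiKey)
	if err != nil {
		// Unsupported model or missing key is rejected before any API call.
		return fmt.Errorf("building LLM processor request: %w", err)
	}
	// req.Model and req.LLMProviderApiKey are now populated; submit req to the actor API.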
6 changes: 4 additions & 2 deletions args/llm_test.go
@@ -100,13 +100,15 @@ var _ = Describe("LLMProcessorArguments", func() {
 				MaxTokens:   42,
 				Temperature: 0.7,
 			}
-			req := llmArgs.ToLLMProcessorRequest()
+			req, err := llmArgs.ToLLMProcessorRequest(args.LLMDefaultGeminiModel, "api-key")
+			Expect(err).ToNot(HaveOccurred())
 			Expect(req.InputDatasetId).To(Equal("ds1"))
 			Expect(req.Prompt).To(Equal("p"))
 			Expect(req.MaxTokens).To(Equal(uint(42)))
 			Expect(req.Temperature).To(Equal("0.7"))
 			Expect(req.MultipleColumns).To(BeFalse())
-			Expect(req.Model).To(Equal("gemini-1.5-flash-8b"))
+			Expect(req.Model).To(Equal(args.LLMDefaultGeminiModel))
+			Expect(req.LLMProviderApiKey).To(Equal("api-key"))
 		})
 	})
 })
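
The updated spec only exercises the happy path. A sketch of an additional spec for the new validation branches (not part of this PR; it assumes the same `llmArgs` fixture used above):

	// Hypothetical extra specs covering the error paths of ToLLMProcessorRequest.
	It("returns an error for an unsupported model", func() {
		_, err := llmArgs.ToLLMProcessorRequest("not-a-real-model", "api-key")
		Expect(err).To(HaveOccurred())
	})

	It("returns an error when the provider API key is empty", func() {
		_, err := llmArgs.ToLLMProcessorRequest(args.LLMDefaultClaudeModel, "")
		Expect(err).To(HaveOccurred())
	})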