diff --git a/args/llm.go b/args/llm.go
index 2f5c8c6..d3e4ac8 100644
--- a/args/llm.go
+++ b/args/llm.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"strconv"
 
+	"github.com/masa-finance/tee-types/pkg/util"
 	teetypes "github.com/masa-finance/tee-types/types"
 )
 
@@ -18,10 +19,13 @@ const (
 	LLMDefaultMaxTokens       uint    = 300
 	LLMDefaultTemperature     float64 = 0.1
 	LLMDefaultMultipleColumns bool    = false
-	LLMDefaultModel           string  = "gemini-1.5-flash-8b"
+	LLMDefaultGeminiModel     string  = "gemini-1.5-flash-8b"
+	LLMDefaultClaudeModel     string  = "claude-3-5-haiku-latest"
 	LLMDefaultItems           uint    = 1
 )
 
+var SupportedModels = util.NewSet(LLMDefaultGeminiModel, LLMDefaultClaudeModel)
+
 type LLMProcessorArguments struct {
 	DatasetId string  `json:"dataset_id"`
 	Prompt    string  `json:"prompt"`
@@ -71,13 +75,21 @@ func (l *LLMProcessorArguments) Validate() error {
 	return nil
 }
 
-func (l LLMProcessorArguments) ToLLMProcessorRequest() teetypes.LLMProcessorRequest {
-	return teetypes.LLMProcessorRequest{
-		InputDatasetId:  l.DatasetId,
-		Prompt:          l.Prompt,
-		MaxTokens:       l.MaxTokens,
-		Temperature:     strconv.FormatFloat(l.Temperature, 'f', -1, 64),
-		MultipleColumns: LLMDefaultMultipleColumns, // overrides default in actor API
-		Model:           LLMDefaultModel,           // overrides default in actor API
+func (l LLMProcessorArguments) ToLLMProcessorRequest(model string, key string) (teetypes.LLMProcessorRequest, error) {
+	if !SupportedModels.Contains(model) {
+		return teetypes.LLMProcessorRequest{}, fmt.Errorf("model %s is not supported", model)
+	}
+	if key == "" {
+		return teetypes.LLMProcessorRequest{}, fmt.Errorf("key is required")
 	}
+
+	return teetypes.LLMProcessorRequest{
+		InputDatasetId:    l.DatasetId,
+		LLMProviderApiKey: key,
+		Prompt:            l.Prompt,
+		MaxTokens:         l.MaxTokens,
+		Temperature:       strconv.FormatFloat(l.Temperature, 'f', -1, 64),
+		MultipleColumns:   LLMDefaultMultipleColumns, // overrides default in actor API
+		Model:             model,                     // overrides default in actor API
+	}, nil
 }
diff --git a/args/llm_test.go b/args/llm_test.go
index aa35128..a9b02c2 100644
--- a/args/llm_test.go
+++ b/args/llm_test.go
@@ -100,13 +100,15 @@ var _ = Describe("LLMProcessorArguments", func() {
 				MaxTokens:   42,
 				Temperature: 0.7,
 			}
-			req := llmArgs.ToLLMProcessorRequest()
+			req, err := llmArgs.ToLLMProcessorRequest(args.LLMDefaultGeminiModel, "api-key")
+			Expect(err).ToNot(HaveOccurred())
 			Expect(req.InputDatasetId).To(Equal("ds1"))
 			Expect(req.Prompt).To(Equal("p"))
 			Expect(req.MaxTokens).To(Equal(uint(42)))
 			Expect(req.Temperature).To(Equal("0.7"))
 			Expect(req.MultipleColumns).To(BeFalse())
-			Expect(req.Model).To(Equal("gemini-1.5-flash-8b"))
+			Expect(req.Model).To(Equal(args.LLMDefaultGeminiModel))
+			Expect(req.LLMProviderApiKey).To(Equal("api-key"))
 		})
 	})
 })
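
Usage note: a minimal caller-side sketch of the new ToLLMProcessorRequest(model, key) signature. The struct fields, the model constants, and the error cases come from this change; the args import path, the main wrapper, and the LLM_API_KEY environment variable name are assumptions for illustration only.

package main

import (
	"log"
	"os"

	"github.com/masa-finance/tee-types/args"
)

func main() {
	llmArgs := args.LLMProcessorArguments{
		DatasetId:   "ds1",
		Prompt:      "summarize each row",
		MaxTokens:   300,
		Temperature: 0.1,
	}
	// The model must be in args.SupportedModels and the key non-empty;
	// otherwise ToLLMProcessorRequest now returns an error instead of
	// silently substituting a default model.
	req, err := llmArgs.ToLLMProcessorRequest(args.LLMDefaultClaudeModel, os.Getenv("LLM_API_KEY"))
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("model=%s maxTokens=%d", req.Model, req.MaxTokens)
}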