From e69518090ea9f5be6cbdd329bca846565c65134a Mon Sep 17 00:00:00 2001
From: grantdfoster
Date: Mon, 29 Sep 2025 21:20:28 +0200
Subject: [PATCH 1/7] fix: default model

---
 args/llm.go | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/args/llm.go b/args/llm.go
index 2f5c8c6..c38cae5 100644
--- a/args/llm.go
+++ b/args/llm.go
@@ -18,8 +18,9 @@ const (
 	LLMDefaultMaxTokens       uint    = 300
 	LLMDefaultTemperature     float64 = 0.1
 	LLMDefaultMultipleColumns bool    = false
-	LLMDefaultModel string = "gemini-1.5-flash-8b"
-	LLMDefaultItems uint = 1
+	// LLMDefaultModel string = "gemini-1.5-flash-8b"
+	LLMDefaultModel string = "gemini-1.5-flash"
+	LLMDefaultItems uint = 1
 )
 
 type LLMProcessorArguments struct {

From 3be328a57c3ed79da34c4d69c66e03c691942f80 Mon Sep 17 00:00:00 2001
From: grantdfoster
Date: Mon, 29 Sep 2025 22:36:33 +0200
Subject: [PATCH 2/7] feat: support many models

---
 args/llm.go      | 18 ++++++++++++------
 args/llm_test.go |  3 ++-
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/args/llm.go b/args/llm.go
index c38cae5..5386b30 100644
--- a/args/llm.go
+++ b/args/llm.go
@@ -18,11 +18,13 @@ const (
 	LLMDefaultMaxTokens       uint    = 300
 	LLMDefaultTemperature     float64 = 0.1
 	LLMDefaultMultipleColumns bool    = false
-	// LLMDefaultModel string = "gemini-1.5-flash-8b"
-	LLMDefaultModel string = "gemini-1.5-flash"
-	LLMDefaultItems uint = 1
+	LLMDefaultGeminiModel string = "gemini-1.5-flash-8b"
+	LLMDefaultClaudeModel string = "claude-3-5-haiku-latest"
+	LLMDefaultItems uint = 1
 )
 
+var SupportedModels = map[string]bool{LLMDefaultGeminiModel: true, LLMDefaultClaudeModel: true}
+
 type LLMProcessorArguments struct {
 	DatasetId string `json:"dataset_id"`
 	Prompt    string `json:"prompt"`
@@ -72,13 +74,17 @@ func (l *LLMProcessorArguments) Validate() error {
 	return nil
 }
 
-func (l LLMProcessorArguments) ToLLMProcessorRequest() teetypes.LLMProcessorRequest {
+func (l LLMProcessorArguments) ToLLMProcessorRequest(model string) (teetypes.LLMProcessorRequest, error) {
+	if !SupportedModels[model] {
+		return teetypes.LLMProcessorRequest{}, fmt.Errorf("model %s is not supported", model)
+	}
+
 	return teetypes.LLMProcessorRequest{
 		InputDatasetId:  l.DatasetId,
 		Prompt:          l.Prompt,
 		MaxTokens:       l.MaxTokens,
 		Temperature:     strconv.FormatFloat(l.Temperature, 'f', -1, 64),
 		MultipleColumns: LLMDefaultMultipleColumns, // overrides default in actor API
-		Model:           LLMDefaultModel,           // overrides default in actor API
-	}
+		Model:           model, // overrides default in actor API
+	}, nil
 }
diff --git a/args/llm_test.go b/args/llm_test.go
index aa35128..05ede73 100644
--- a/args/llm_test.go
+++ b/args/llm_test.go
@@ -100,7 +100,8 @@ var _ = Describe("LLMProcessorArguments", func() {
 			MaxTokens:   42,
 			Temperature: 0.7,
 		}
-		req := llmArgs.ToLLMProcessorRequest()
+		req, err := llmArgs.ToLLMProcessorRequest("gemini-1.5-flash-8b")
+		Expect(err).ToNot(HaveOccurred())
 		Expect(req.InputDatasetId).To(Equal("ds1"))
 		Expect(req.Prompt).To(Equal("p"))
 		Expect(req.MaxTokens).To(Equal(uint(42)))
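After PATCH 2/7, ToLLMProcessorRequest returns an error for unsupported models, so callers must move to the two-value form. A minimal caller sketch under this intermediate signature (the variable names and logging are illustrative, not part of the series; only the args identifiers come from the diffs):

	llmArgs := args.LLMProcessorArguments{
		DatasetId:   "ds1",
		Prompt:      "summarize each row",
		MaxTokens:   300,
		Temperature: 0.1,
	}
	// Unsupported model names now surface as errors instead of being
	// silently replaced by a package default.
	req, err := llmArgs.ToLLMProcessorRequest(args.LLMDefaultClaudeModel)
	if err != nil {
		log.Fatalf("building LLM request: %v", err)
	}
	_ = req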
From 2487f86599117a196ee28d8d063384c277d27e24 Mon Sep 17 00:00:00 2001
From: grantdfoster
Date: Mon, 29 Sep 2025 22:38:32 +0200
Subject: [PATCH 3/7] fix: be more idiomatic

---
 args/llm.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/args/llm.go b/args/llm.go
index 5386b30..e9b960b 100644
--- a/args/llm.go
+++ b/args/llm.go
@@ -75,7 +75,7 @@ func (l *LLMProcessorArguments) Validate() error {
 }
 
 func (l LLMProcessorArguments) ToLLMProcessorRequest(model string) (teetypes.LLMProcessorRequest, error) {
-	if !SupportedModels[model] {
+	if _, ok := SupportedModels[model]; !ok {
 		return teetypes.LLMProcessorRequest{}, fmt.Errorf("model %s is not supported", model)
 	}
 

From 8bfdd0e58faae7e9d8d175002646486af26bae02 Mon Sep 17 00:00:00 2001
From: grantdfoster
Date: Mon, 29 Sep 2025 22:55:34 +0200
Subject: [PATCH 4/7] chore: add dynamic key

---
 args/llm.go      | 18 +++++++++++-------
 args/llm_test.go |  2 +-
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/args/llm.go b/args/llm.go
index e9b960b..3baa0d1 100644
--- a/args/llm.go
+++ b/args/llm.go
@@ -74,17 +74,21 @@ func (l *LLMProcessorArguments) Validate() error {
 	return nil
 }
 
-func (l LLMProcessorArguments) ToLLMProcessorRequest(model string) (teetypes.LLMProcessorRequest, error) {
+func (l LLMProcessorArguments) ToLLMProcessorRequest(model string, key string) (teetypes.LLMProcessorRequest, error) {
 	if _, ok := SupportedModels[model]; !ok {
 		return teetypes.LLMProcessorRequest{}, fmt.Errorf("model %s is not supported", model)
 	}
+	if key == "" {
+		return teetypes.LLMProcessorRequest{}, fmt.Errorf("key is required")
+	}
 
 	return teetypes.LLMProcessorRequest{
-		InputDatasetId:  l.DatasetId,
-		Prompt:          l.Prompt,
-		MaxTokens:       l.MaxTokens,
-		Temperature:     strconv.FormatFloat(l.Temperature, 'f', -1, 64),
-		MultipleColumns: LLMDefaultMultipleColumns, // overrides default in actor API
-		Model:           model, // overrides default in actor API
+		InputDatasetId:    l.DatasetId,
+		LLMProviderApiKey: key,
+		Prompt:            l.Prompt,
+		MaxTokens:         l.MaxTokens,
+		Temperature:       strconv.FormatFloat(l.Temperature, 'f', -1, 64),
+		MultipleColumns:   LLMDefaultMultipleColumns, // overrides default in actor API
+		Model:             model, // overrides default in actor API
 	}, nil
 }
diff --git a/args/llm_test.go b/args/llm_test.go
index 05ede73..e22179a 100644
--- a/args/llm_test.go
+++ b/args/llm_test.go
@@ -100,7 +100,7 @@ var _ = Describe("LLMProcessorArguments", func() {
 			MaxTokens:   42,
 			Temperature: 0.7,
 		}
-		req, err := llmArgs.ToLLMProcessorRequest("gemini-1.5-flash-8b")
+		req, err := llmArgs.ToLLMProcessorRequest("gemini-1.5-flash-8b", "api-key")
 		Expect(err).ToNot(HaveOccurred())
 		Expect(req.InputDatasetId).To(Equal("ds1"))
 		Expect(req.Prompt).To(Equal("p"))

From f74666805e1bbcf6ed409a8bec69f05119a2fe4f Mon Sep 17 00:00:00 2001
From: grantdfoster
Date: Tue, 30 Sep 2025 17:08:40 +0200
Subject: [PATCH 5/7] chore: use util set

---
 args/llm.go | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/args/llm.go b/args/llm.go
index 3baa0d1..d3e4ac8 100644
--- a/args/llm.go
+++ b/args/llm.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"strconv"
 
+	"github.com/masa-finance/tee-types/pkg/util"
 	teetypes "github.com/masa-finance/tee-types/types"
 )
 
@@ -23,7 +24,7 @@ const (
 	LLMDefaultItems uint = 1
 )
 
-var SupportedModels = map[string]bool{LLMDefaultGeminiModel: true, LLMDefaultClaudeModel: true}
+var SupportedModels = util.NewSet(LLMDefaultGeminiModel, LLMDefaultClaudeModel)
 
 type LLMProcessorArguments struct {
 	DatasetId string `json:"dataset_id"`
@@ -75,7 +76,7 @@ func (l *LLMProcessorArguments) Validate() error {
 }
 
 func (l LLMProcessorArguments) ToLLMProcessorRequest(model string, key string) (teetypes.LLMProcessorRequest, error) {
-	if _, ok := SupportedModels[model]; !ok {
+	if !SupportedModels.Contains(model) {
 		return teetypes.LLMProcessorRequest{}, fmt.Errorf("model %s is not supported", model)
 	}
 	if key == "" {
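PATCH 5/7 replaces the raw map with util.NewSet from tee-types; the diff only exercises NewSet and Contains. For readers without that dependency handy, a simplified sketch of the semantics the patch relies on (an assumption about the util package's shape, not its actual implementation):

	// Hypothetical stand-in for github.com/masa-finance/tee-types/pkg/util.
	type Set[T comparable] struct{ m map[T]struct{} }

	func NewSet[T comparable](items ...T) *Set[T] {
		s := &Set[T]{m: make(map[T]struct{}, len(items))}
		for _, item := range items {
			s.m[item] = struct{}{} // empty struct values carry no per-entry payload
		}
		return s
	}

	// Contains reports whether item is a member of the set.
	func (s *Set[T]) Contains(item T) bool {
		_, ok := s.m[item]
		return ok
	}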
From 16c53f48d56de90577bbf76843146c9a285b220b Mon Sep 17 00:00:00 2001
From: grantdfoster
Date: Tue, 30 Sep 2025 17:09:43 +0200
Subject: [PATCH 6/7] chore: fix llm test

---
 args/llm_test.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/args/llm_test.go b/args/llm_test.go
index e22179a..1ace964 100644
--- a/args/llm_test.go
+++ b/args/llm_test.go
@@ -100,14 +100,14 @@ var _ = Describe("LLMProcessorArguments", func() {
 			MaxTokens:   42,
 			Temperature: 0.7,
 		}
-		req, err := llmArgs.ToLLMProcessorRequest("gemini-1.5-flash-8b", "api-key")
+		req, err := llmArgs.ToLLMProcessorRequest(args.LLMDefaultGeminiModel, "api-key")
 		Expect(err).ToNot(HaveOccurred())
 		Expect(req.InputDatasetId).To(Equal("ds1"))
 		Expect(req.Prompt).To(Equal("p"))
 		Expect(req.MaxTokens).To(Equal(uint(42)))
 		Expect(req.Temperature).To(Equal("0.7"))
 		Expect(req.MultipleColumns).To(BeFalse())
-		Expect(req.Model).To(Equal("gemini-1.5-flash-8b"))
+		Expect(req.Model).To(Equal(args.LLMDefaultGeminiModel))
 	})
 })
})

From 19ca0ab12ecf427b7124aba8caae415719d4c871 Mon Sep 17 00:00:00 2001
From: grantdfoster
Date: Tue, 30 Sep 2025 17:15:45 +0200
Subject: [PATCH 7/7] chore: add key test

---
 args/llm_test.go | 1 +
 1 file changed, 1 insertion(+)

diff --git a/args/llm_test.go b/args/llm_test.go
index 1ace964..a9b02c2 100644
--- a/args/llm_test.go
+++ b/args/llm_test.go
@@ -108,6 +108,7 @@ var _ = Describe("LLMProcessorArguments", func() {
 		Expect(req.Temperature).To(Equal("0.7"))
 		Expect(req.MultipleColumns).To(BeFalse())
 		Expect(req.Model).To(Equal(args.LLMDefaultGeminiModel))
+		Expect(req.LLMProviderApiKey).To(Equal("api-key"))
 	})
 })
})
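Taken together, the series leaves ToLLMProcessorRequest with a (model, key) signature validated against SupportedModels. An end-to-end sketch of the final API (the import path, environment-variable name, and main wrapper are assumptions; only the args identifiers come from the diffs):

	package main

	import (
		"fmt"
		"log"
		"os"

		"github.com/masa-finance/tee-types/args" // assumed module path
	)

	func main() {
		llmArgs := args.LLMProcessorArguments{
			DatasetId:   "ds1",
			Prompt:      "summarize each row",
			MaxTokens:   300,
			Temperature: 0.1,
		}
		// Both checks added in the series apply here: the model must be in
		// SupportedModels, and the provider API key must be non-empty.
		req, err := llmArgs.ToLLMProcessorRequest(args.LLMDefaultGeminiModel, os.Getenv("LLM_API_KEY"))
		if err != nil {
			log.Fatalf("building LLM request: %v", err)
		}
		fmt.Println(req.Model) // gemini-1.5-flash-8b
	}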