diff --git a/args/llm.go b/args/llm.go
index 1816094..2f5c8c6 100644
--- a/args/llm.go
+++ b/args/llm.go
@@ -4,6 +4,7 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "strconv"
 
     teetypes "github.com/masa-finance/tee-types/types"
 )
@@ -11,21 +12,22 @@ import (
 var (
     ErrLLMDatasetIdRequired = errors.New("dataset id is required")
     ErrLLMPromptRequired    = errors.New("prompt is required")
-    ErrLLMMaxTokensNegative = errors.New("max tokens must be non-negative")
 )
 
 const (
-    LLMDefaultMaxTokens       = 300
-    LLMDefaultTemperature     = "0.1"
-    LLMDefaultMultipleColumns = false
-    LLMDefaultModel           = "gemini-1.5-flash-8b"
+    LLMDefaultMaxTokens       uint    = 300
+    LLMDefaultTemperature     float64 = 0.1
+    LLMDefaultMultipleColumns bool    = false
+    LLMDefaultModel           string  = "gemini-1.5-flash-8b"
+    LLMDefaultItems           uint    = 1
 )
 
 type LLMProcessorArguments struct {
-    DatasetId   string `json:"dataset_id"`
-    Prompt      string `json:"prompt"`
-    MaxTokens   int    `json:"max_tokens"`
-    Temperature string `json:"temperature"`
+    DatasetId   string  `json:"dataset_id"`
+    Prompt      string  `json:"prompt"`
+    MaxTokens   uint    `json:"max_tokens"`
+    Temperature float64 `json:"temperature"`
+    Items       uint    `json:"items"`
 }
 
 // UnmarshalJSON implements custom JSON unmarshaling with validation
@@ -48,11 +50,14 @@ func (l *LLMProcessorArguments) UnmarshalJSON(data []byte) error {
 }
 
 func (l *LLMProcessorArguments) setDefaultValues() {
+    if l.Temperature == 0 {
+        l.Temperature = LLMDefaultTemperature
+    }
     if l.MaxTokens == 0 {
         l.MaxTokens = LLMDefaultMaxTokens
     }
-    if l.Temperature == "" {
-        l.Temperature = LLMDefaultTemperature
+    if l.Items == 0 {
+        l.Items = LLMDefaultItems
     }
 }
 
@@ -63,9 +68,6 @@ func (l *LLMProcessorArguments) Validate() error {
     if l.Prompt == "" {
         return ErrLLMPromptRequired
     }
-    if l.MaxTokens < 0 {
-        return fmt.Errorf("%w: got %v", ErrLLMMaxTokensNegative, l.MaxTokens)
-    }
     return nil
 }
 
@@ -74,7 +76,7 @@ func (l LLMProcessorArguments) ToLLMProcessorRequest() teetypes.LLMProcessorRequ
         InputDatasetId:  l.DatasetId,
         Prompt:          l.Prompt,
         MaxTokens:       l.MaxTokens,
-        Temperature:     l.Temperature,
+        Temperature:     strconv.FormatFloat(l.Temperature, 'f', -1, 64),
         MultipleColumns: LLMDefaultMultipleColumns, // overrides default in actor API
         Model:           LLMDefaultModel,           // overrides default in actor API
     }
diff --git a/args/llm_test.go b/args/llm_test.go
index 3884ebf..aa35128 100644
--- a/args/llm_test.go
+++ b/args/llm_test.go
@@ -21,8 +21,9 @@ var _ = Describe("LLMProcessorArguments", func() {
             Expect(err).ToNot(HaveOccurred())
             err = json.Unmarshal([]byte(jsonData), &llmArgs)
             Expect(err).ToNot(HaveOccurred())
-            Expect(llmArgs.MaxTokens).To(Equal(300))
-            Expect(llmArgs.Temperature).To(Equal("0.1"))
+            Expect(llmArgs.Temperature).To(Equal(0.1))
+            Expect(llmArgs.MaxTokens).To(Equal(uint(300)))
+            Expect(llmArgs.Items).To(Equal(uint(1)))
         })
 
         It("should override default values", func() {
@@ -30,14 +31,16 @@ var _ = Describe("LLMProcessorArguments", func() {
             llmArgs := args.LLMProcessorArguments{
                 DatasetId:   "ds1",
                 Prompt:      "summarize: ${markdown}",
                 MaxTokens:   123,
-                Temperature: "0.7",
+                Temperature: 0.7,
+                Items:       3,
             }
             jsonData, err := json.Marshal(llmArgs)
             Expect(err).ToNot(HaveOccurred())
             err = json.Unmarshal([]byte(jsonData), &llmArgs)
             Expect(err).ToNot(HaveOccurred())
-            Expect(llmArgs.MaxTokens).To(Equal(123))
-            Expect(llmArgs.Temperature).To(Equal("0.7"))
+            Expect(llmArgs.Temperature).To(Equal(0.7))
+            Expect(llmArgs.MaxTokens).To(Equal(uint(123)))
+            Expect(llmArgs.Items).To(Equal(uint(3)))
         })
 
         It("should fail unmarshal when dataset_id is missing", func() {
@@ -61,7 +64,8 @@ var _ = Describe("LLMProcessorArguments", func() {
                 DatasetId:   "ds1",
                 Prompt:      "p",
                 MaxTokens:   10,
-                Temperature: "0.2",
+                Temperature: 0.2,
+                Items:       1,
             }
             err := llmArgs.Validate()
             Expect(err).ToNot(HaveOccurred())
@@ -71,7 +75,7 @@ var _ = Describe("LLMProcessorArguments", func() {
             llmArgs := &args.LLMProcessorArguments{
                 Prompt:      "p",
                 MaxTokens:   10,
-                Temperature: "0.2",
+                Temperature: 0.2,
             }
             err := llmArgs.Validate()
             Expect(errors.Is(err, args.ErrLLMDatasetIdRequired)).To(BeTrue())
@@ -81,53 +85,25 @@ var _ = Describe("LLMProcessorArguments", func() {
             llmArgs := &args.LLMProcessorArguments{
                 DatasetId:   "ds1",
                 MaxTokens:   10,
-                Temperature: "0.2",
+                Temperature: 0.2,
             }
             err := llmArgs.Validate()
             Expect(errors.Is(err, args.ErrLLMPromptRequired)).To(BeTrue())
         })
-
-        It("should fail when max tokens is negative", func() {
-            llmArgs := &args.LLMProcessorArguments{
-                DatasetId:   "ds1",
-                Prompt:      "p",
-                MaxTokens:   -1,
-                Temperature: "0.2",
-            }
-            err := llmArgs.Validate()
-            Expect(errors.Is(err, args.ErrLLMMaxTokensNegative)).To(BeTrue())
-            Expect(err.Error()).To(ContainSubstring("got -1"))
-        })
     })
 
     Describe("ToLLMProcessorRequest", func() {
-        It("should map fields and defaults correctly", func() {
-            llmArgs := args.LLMProcessorArguments{
-                DatasetId:   "ds1",
-                Prompt:      "p",
-                MaxTokens:   0, // default applied in To*
-                Temperature: "",
-            }
-            req := llmArgs.ToLLMProcessorRequest()
-            Expect(req.InputDatasetId).To(Equal("ds1"))
-            Expect(req.Prompt).To(Equal("p"))
-            Expect(req.MaxTokens).To(Equal(0))
-            Expect(req.Temperature).To(Equal(""))
-            Expect(req.MultipleColumns).To(BeFalse())
-            Expect(req.Model).To(Equal("gemini-1.5-flash-8b"))
-        })
-
-        It("should map fields correctly when set", func() {
+        It("should map request fields to actor request fields", func() {
             llmArgs := args.LLMProcessorArguments{
                 DatasetId:   "ds1",
                 Prompt:      "p",
                 MaxTokens:   42,
-                Temperature: "0.7",
+                Temperature: 0.7,
             }
             req := llmArgs.ToLLMProcessorRequest()
             Expect(req.InputDatasetId).To(Equal("ds1"))
             Expect(req.Prompt).To(Equal("p"))
-            Expect(req.MaxTokens).To(Equal(42))
+            Expect(req.MaxTokens).To(Equal(uint(42)))
             Expect(req.Temperature).To(Equal("0.7"))
             Expect(req.MultipleColumns).To(BeFalse())
             Expect(req.Model).To(Equal("gemini-1.5-flash-8b"))
diff --git a/types/llm.go b/types/llm.go
index fb67693..ca99075 100644
--- a/types/llm.go
+++ b/types/llm.go
@@ -5,9 +5,9 @@ type LLMProcessorRequest struct {
     LLMProviderApiKey string `json:"llmProviderApiKey"` // encrypted api key by miner
     Model             string `json:"model"`
     MultipleColumns   bool   `json:"multipleColumns"`
-    Prompt            string `json:"prompt"` // example: summarize the content of this webpage: ${markdown}
-    Temperature       string `json:"temperature"`
-    MaxTokens         int    `json:"maxTokens"`
+    Prompt            string `json:"prompt"`      // example: summarize the content of this webpage: ${markdown}
+    Temperature       string `json:"temperature"` // the actor expects a string
+    MaxTokens         uint   `json:"maxTokens"`
 }
 
 type LLMProcessorResult struct {