args/llm.go (17 additions, 15 deletions)
@@ -4,28 +4,30 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"strconv"
 
 	teetypes "github.com/masa-finance/tee-types/types"
 )
 
 var (
 	ErrLLMDatasetIdRequired = errors.New("dataset id is required")
 	ErrLLMPromptRequired    = errors.New("prompt is required")
-	ErrLLMMaxTokensNegative = errors.New("max tokens must be non-negative")
 )
 
 const (
-	LLMDefaultMaxTokens       = 300
-	LLMDefaultTemperature     = "0.1"
-	LLMDefaultMultipleColumns = false
-	LLMDefaultModel           = "gemini-1.5-flash-8b"
+	LLMDefaultMaxTokens       uint    = 300
+	LLMDefaultTemperature     float64 = 0.1
+	LLMDefaultMultipleColumns bool    = false
+	LLMDefaultModel           string  = "gemini-1.5-flash-8b"
+	LLMDefaultItems           uint    = 1
 )
 
 type LLMProcessorArguments struct {
-	DatasetId   string `json:"dataset_id"`
-	Prompt      string `json:"prompt"`
-	MaxTokens   int    `json:"max_tokens"`
-	Temperature string `json:"temperature"`
+	DatasetId   string  `json:"dataset_id"`
+	Prompt      string  `json:"prompt"`
+	MaxTokens   uint    `json:"max_tokens"`
+	Temperature float64 `json:"temperature"`
+	Items       uint    `json:"items"`
 }
 
 // UnmarshalJSON implements custom JSON unmarshaling with validation
@@ -48,11 +50,14 @@ func (l *LLMProcessorArguments) UnmarshalJSON(data []byte) error {
 }
 
 func (l *LLMProcessorArguments) setDefaultValues() {
+	if l.Temperature == 0 {
+		l.Temperature = LLMDefaultTemperature
+	}
 	if l.MaxTokens == 0 {
 		l.MaxTokens = LLMDefaultMaxTokens
 	}
-	if l.Temperature == "" {
-		l.Temperature = LLMDefaultTemperature
+	if l.Items == 0 {
+		l.Items = LLMDefaultItems
 	}
 }
 
@@ -63,9 +68,6 @@ func (l *LLMProcessorArguments) Validate() error {
 	if l.Prompt == "" {
 		return ErrLLMPromptRequired
 	}
-	if l.MaxTokens < 0 {
-		return fmt.Errorf("%w: got %v", ErrLLMMaxTokensNegative, l.MaxTokens)
-	}
 	return nil
 }
 
@@ -74,7 +76,7 @@ func (l LLMProcessorArguments) ToLLMProcessorRequest() teetypes.LLMProcessorRequ
 		InputDatasetId:  l.DatasetId,
 		Prompt:          l.Prompt,
 		MaxTokens:       l.MaxTokens,
-		Temperature:     l.Temperature,
+		Temperature:     strconv.FormatFloat(l.Temperature, 'f', -1, 64),
 		MultipleColumns: LLMDefaultMultipleColumns, // overrides default in actor API
 		Model:           LLMDefaultModel,           // overrides default in actor API
 	}
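With Temperature now held as a float64, the conversion to the string the actor expects happens only in ToLLMProcessorRequest, via the strconv.FormatFloat call shown above. A minimal standalone sketch (stdlib only, nothing assumed beyond the call in the diff) of what that 'f'/precision -1/64-bit combination emits:

```go
package main

import (
	"fmt"
	"strconv"
)

func main() {
	// Precision -1 picks the fewest digits that parse back to the same
	// float64, so defaults like 0.1 serialize as "0.1", not "0.100000".
	for _, temp := range []float64{0.1, 0.7, 1.0} {
		fmt.Println(strconv.FormatFloat(temp, 'f', -1, 64))
	}
	// Output:
	// 0.1
	// 0.7
	// 1
}
```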
args/llm_test.go (15 additions, 39 deletions)
@@ -21,23 +21,26 @@ var _ = Describe("LLMProcessorArguments", func() {
 			Expect(err).ToNot(HaveOccurred())
 			err = json.Unmarshal([]byte(jsonData), &llmArgs)
 			Expect(err).ToNot(HaveOccurred())
-			Expect(llmArgs.MaxTokens).To(Equal(300))
-			Expect(llmArgs.Temperature).To(Equal("0.1"))
+			Expect(llmArgs.Temperature).To(Equal(0.1))
+			Expect(llmArgs.MaxTokens).To(Equal(uint(300)))
+			Expect(llmArgs.Items).To(Equal(uint(1)))
 		})
 
 		It("should override default values", func() {
 			llmArgs := args.LLMProcessorArguments{
 				DatasetId:   "ds1",
 				Prompt:      "summarize: ${markdown}",
 				MaxTokens:   123,
-				Temperature: "0.7",
+				Temperature: 0.7,
+				Items:       3,
 			}
 			jsonData, err := json.Marshal(llmArgs)
 			Expect(err).ToNot(HaveOccurred())
 			err = json.Unmarshal([]byte(jsonData), &llmArgs)
 			Expect(err).ToNot(HaveOccurred())
-			Expect(llmArgs.MaxTokens).To(Equal(123))
-			Expect(llmArgs.Temperature).To(Equal("0.7"))
+			Expect(llmArgs.Temperature).To(Equal(0.7))
+			Expect(llmArgs.MaxTokens).To(Equal(uint(123)))
+			Expect(llmArgs.Items).To(Equal(uint(3)))
 		})
 
 		It("should fail unmarshal when dataset_id is missing", func() {
@@ -61,7 +64,8 @@ var _ = Describe("LLMProcessorArguments", func() {
 				DatasetId:   "ds1",
 				Prompt:      "p",
 				MaxTokens:   10,
-				Temperature: "0.2",
+				Temperature: 0.2,
+				Items:       1,
 			}
 			err := llmArgs.Validate()
 			Expect(err).ToNot(HaveOccurred())
@@ -71,7 +75,7 @@ var _ = Describe("LLMProcessorArguments", func() {
 			llmArgs := &args.LLMProcessorArguments{
 				Prompt:      "p",
 				MaxTokens:   10,
-				Temperature: "0.2",
+				Temperature: 0.2,
 			}
 			err := llmArgs.Validate()
 			Expect(errors.Is(err, args.ErrLLMDatasetIdRequired)).To(BeTrue())
@@ -81,53 +85,25 @@ var _ = Describe("LLMProcessorArguments", func() {
 			llmArgs := &args.LLMProcessorArguments{
 				DatasetId:   "ds1",
 				MaxTokens:   10,
-				Temperature: "0.2",
+				Temperature: 0.2,
 			}
 			err := llmArgs.Validate()
 			Expect(errors.Is(err, args.ErrLLMPromptRequired)).To(BeTrue())
 		})
-
-		It("should fail when max tokens is negative", func() {
-			llmArgs := &args.LLMProcessorArguments{
-				DatasetId:   "ds1",
-				Prompt:      "p",
-				MaxTokens:   -1,
-				Temperature: "0.2",
-			}
-			err := llmArgs.Validate()
-			Expect(errors.Is(err, args.ErrLLMMaxTokensNegative)).To(BeTrue())
-			Expect(err.Error()).To(ContainSubstring("got -1"))
-		})
 	})
 
 	Describe("ToLLMProcessorRequest", func() {
-		It("should map fields and defaults correctly", func() {
-			llmArgs := args.LLMProcessorArguments{
-				DatasetId:   "ds1",
-				Prompt:      "p",
-				MaxTokens:   0, // default applied in To*
-				Temperature: "",
-			}
-			req := llmArgs.ToLLMProcessorRequest()
-			Expect(req.InputDatasetId).To(Equal("ds1"))
-			Expect(req.Prompt).To(Equal("p"))
-			Expect(req.MaxTokens).To(Equal(0))
-			Expect(req.Temperature).To(Equal(""))
-			Expect(req.MultipleColumns).To(BeFalse())
-			Expect(req.Model).To(Equal("gemini-1.5-flash-8b"))
-		})
-
-		It("should map fields correctly when set", func() {
+		It("should map request fields to actor request fields", func() {
 			llmArgs := args.LLMProcessorArguments{
 				DatasetId:   "ds1",
 				Prompt:      "p",
 				MaxTokens:   42,
-				Temperature: "0.7",
+				Temperature: 0.7,
 			}
 			req := llmArgs.ToLLMProcessorRequest()
 			Expect(req.InputDatasetId).To(Equal("ds1"))
 			Expect(req.Prompt).To(Equal("p"))
-			Expect(req.MaxTokens).To(Equal(42))
+			Expect(req.MaxTokens).To(Equal(uint(42)))
 			Expect(req.Temperature).To(Equal("0.7"))
 			Expect(req.MultipleColumns).To(BeFalse())
 			Expect(req.Model).To(Equal("gemini-1.5-flash-8b"))
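The negative-max-tokens test disappears along with ErrLLMMaxTokensNegative because, with MaxTokens as a uint, encoding/json rejects a negative value before Validate ever runs. A sketch of that behavior, using a hypothetical stand-in struct (demoArgs is not the real type, which also has custom unmarshaling and validation):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// demoArgs is a hypothetical stand-in mirroring only the numeric field
// types of the new LLMProcessorArguments.
type demoArgs struct {
	MaxTokens   uint    `json:"max_tokens"`
	Temperature float64 `json:"temperature"`
}

func main() {
	var a demoArgs
	// A negative number cannot be decoded into a uint field, so the
	// error surfaces at unmarshal time rather than in Validate.
	err := json.Unmarshal([]byte(`{"max_tokens": -1, "temperature": 0.2}`), &a)
	fmt.Println(err)
	// json: cannot unmarshal number -1 into Go struct field demoArgs.max_tokens of type uint
}
```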
types/llm.go (3 additions, 3 deletions)

@@ -5,9 +5,9 @@ type LLMProcessorRequest struct {
 	LLMProviderApiKey string `json:"llmProviderApiKey"` // encrypted api key by miner
 	Model             string `json:"model"`
 	MultipleColumns   bool   `json:"multipleColumns"`
-	Prompt            string `json:"prompt"` // example: summarize the content of this webpage: ${markdown}
-	Temperature       string `json:"temperature"`
-	MaxTokens         int    `json:"maxTokens"`
+	Prompt            string `json:"prompt"`      // example: summarize the content of this webpage: ${markdown}
+	Temperature       string `json:"temperature"` // the actor expects a string
+	MaxTokens         uint   `json:"maxTokens"`
 }
 
 type LLMProcessorResult struct {
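On the wire, then, temperature stays a string while maxTokens becomes an unsigned number. A sketch of the resulting JSON, assuming a trimmed local copy of the struct (the real LLMProcessorRequest in tee-types/types also carries InputDatasetId and LLMProviderApiKey):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// llmRequest is a trimmed illustration-only copy of LLMProcessorRequest.
type llmRequest struct {
	Model           string `json:"model"`
	MultipleColumns bool   `json:"multipleColumns"`
	Prompt          string `json:"prompt"`
	Temperature     string `json:"temperature"` // stringified float, e.g. "0.1"
	MaxTokens       uint   `json:"maxTokens"`
}

func main() {
	b, err := json.Marshal(llmRequest{
		Model:       "gemini-1.5-flash-8b",
		Prompt:      "summarize: ${markdown}",
		Temperature: "0.1",
		MaxTokens:   300,
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
	// {"model":"gemini-1.5-flash-8b","multipleColumns":false,"prompt":"summarize: ${markdown}","temperature":"0.1","maxTokens":300}
}
```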