Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions args/llm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package args

import (
"encoding/json"
"errors"
"fmt"

teetypes "github.com/masa-finance/tee-types/types"
)

// Sentinel errors reported when LLM processor arguments fail validation.
// Callers should match them with errors.Is, since some are returned wrapped.
var (
	// ErrLLMDatasetIdRequired is returned when DatasetId is empty.
	ErrLLMDatasetIdRequired = errors.New("dataset id is required")

	// ErrLLMPromptRequired is returned when Prompt is empty.
	ErrLLMPromptRequired = errors.New("prompt is required")

	// ErrLLMMaxTokensNegative is returned, wrapped with the offending value,
	// when MaxTokens is below zero.
	ErrLLMMaxTokensNegative = errors.New("max tokens must be non-negative")
)

// Default values used for LLM processor jobs.
const (
	// LLMDefaultMaxTokens replaces a zero MaxTokens while unmarshaling.
	LLMDefaultMaxTokens = 300

	// LLMDefaultTemperature replaces an empty Temperature while unmarshaling.
	LLMDefaultTemperature = "0.1"

	// LLMDefaultMultipleColumns is the MultipleColumns value placed on every
	// request built by ToLLMProcessorRequest.
	LLMDefaultMultipleColumns = false

	// LLMDefaultModel is the Model value placed on every request built by
	// ToLLMProcessorRequest.
	LLMDefaultModel = "gemini-1.5-flash-8b"
)

// LLMProcessorArguments carries the caller-supplied arguments for an LLM
// processor job. Decoding it from JSON applies defaults and then validates
// the result (see UnmarshalJSON).
type LLMProcessorArguments struct {
	// DatasetId names the input dataset. Required.
	DatasetId string `json:"dataset_id"`

	// Prompt is the prompt text. Required.
	Prompt string `json:"prompt"`

	// MaxTokens must not be negative. A zero value is replaced by
	// LLMDefaultMaxTokens during unmarshaling.
	MaxTokens int `json:"max_tokens"`

	// Temperature is kept as a string and forwarded unchanged. An empty value
	// is replaced by LLMDefaultTemperature during unmarshaling.
	Temperature string `json:"temperature"`
}

// UnmarshalJSON decodes data into l, fills in defaults for unset optional
// fields, and then validates the result.
//
// A decoding failure is returned wrapped with a "failed to unmarshal llm
// arguments" prefix. Otherwise the return value is whatever Validate reports.
func (l *LLMProcessorArguments) UnmarshalJSON(data []byte) error {
	// Alias has the same fields as LLMProcessorArguments but none of its
	// methods, so decoding through it does not call back into this method
	// and recurse forever.
	type Alias LLMProcessorArguments

	var target struct {
		*Alias
	}
	target.Alias = (*Alias)(l)

	err := json.Unmarshal(data, &target)
	if err != nil {
		return fmt.Errorf("failed to unmarshal llm arguments: %w", err)
	}

	// Defaults go in before validation so that omitted optional fields are
	// checked with the values they will actually carry.
	l.setDefaultValues()
	return l.Validate()
}

// setDefaultValues fills in the optional fields that still hold their zero
// value: an empty Temperature becomes LLMDefaultTemperature and a zero
// MaxTokens becomes LLMDefaultMaxTokens. Any other value, including a negative
// MaxTokens, is left alone for Validate to judge.
func (l *LLMProcessorArguments) setDefaultValues() {
	if l.Temperature == "" {
		l.Temperature = LLMDefaultTemperature
	}
	if l.MaxTokens == 0 {
		l.MaxTokens = LLMDefaultMaxTokens
	}
}

// Validate checks the arguments and returns the first problem found, or nil.
//
// Checks run in this order: an empty DatasetId yields ErrLLMDatasetIdRequired,
// an empty Prompt yields ErrLLMPromptRequired, and a negative MaxTokens yields
// an error wrapping ErrLLMMaxTokensNegative that also reports the value.
// Temperature is not inspected.
func (l *LLMProcessorArguments) Validate() error {
	switch {
	case l.DatasetId == "":
		return ErrLLMDatasetIdRequired
	case l.Prompt == "":
		return ErrLLMPromptRequired
	case l.MaxTokens < 0:
		return fmt.Errorf("%w: got %v", ErrLLMMaxTokensNegative, l.MaxTokens)
	default:
		return nil
	}
}

// ToLLMProcessorRequest builds the teetypes.LLMProcessorRequest for these
// arguments.
//
// DatasetId, Prompt, MaxTokens and Temperature are copied across exactly as
// they are; no defaults are applied here, so zero values pass straight
// through. MultipleColumns and Model are not configurable through the
// arguments and always come from the package constants.
func (l LLMProcessorArguments) ToLLMProcessorRequest() teetypes.LLMProcessorRequest {
	var req teetypes.LLMProcessorRequest

	req.InputDatasetId = l.DatasetId
	req.Prompt = l.Prompt
	req.MaxTokens = l.MaxTokens
	req.Temperature = l.Temperature

	// Both of these override the default in the actor API.
	req.MultipleColumns = LLMDefaultMultipleColumns
	req.Model = LLMDefaultModel

	return req
}
136 changes: 136 additions & 0 deletions args/llm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package args_test

import (
"encoding/json"
"errors"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

"github.com/masa-finance/tee-types/args"
)

var _ = Describe("LLMProcessorArguments", func() {
// Drives the arguments through encoding/json so that the type's custom
// UnmarshalJSON (decode, then defaults, then validation) is what gets tested.
Describe("Marshalling and unmarshalling", func() {
It("should set default values", func() {
// MaxTokens and Temperature are left at their zero values. The struct tags
// have no omitempty, so the JSON carries 0 and "" and unmarshalling is
// expected to replace them with the defaults.
llmArgs := args.LLMProcessorArguments{
DatasetId: "ds1",
Prompt: "summarize: ${markdown}",
}
jsonData, err := json.Marshal(llmArgs)
Expect(err).ToNot(HaveOccurred())
err = json.Unmarshal([]byte(jsonData), &llmArgs)
Expect(err).ToNot(HaveOccurred())
Expect(llmArgs.MaxTokens).To(Equal(300))
Expect(llmArgs.Temperature).To(Equal("0.1"))
})

It("should override default values", func() {
// Non-zero values must survive the round trip untouched.
llmArgs := args.LLMProcessorArguments{
DatasetId: "ds1",
Prompt: "summarize: ${markdown}",
MaxTokens: 123,
Temperature: "0.7",
}
jsonData, err := json.Marshal(llmArgs)
Expect(err).ToNot(HaveOccurred())
err = json.Unmarshal([]byte(jsonData), &llmArgs)
Expect(err).ToNot(HaveOccurred())
Expect(llmArgs.MaxTokens).To(Equal(123))
Expect(llmArgs.Temperature).To(Equal("0.7"))
})

It("should fail unmarshal when dataset_id is missing", func() {
var llmArgs args.LLMProcessorArguments
// "type" has no matching field on LLMProcessorArguments; only the absent
// dataset_id is expected to cause the failure.
jsonData := []byte(`{"type":"datasetprocessor","prompt":"p"}`)
err := json.Unmarshal(jsonData, &llmArgs)
Expect(errors.Is(err, args.ErrLLMDatasetIdRequired)).To(BeTrue())
})

It("should fail unmarshal when prompt is missing", func() {
var llmArgs args.LLMProcessorArguments
// dataset_id is present here, so the prompt check is the one that fires.
jsonData := []byte(`{"type":"datasetprocessor","dataset_id":"ds1"}`)
err := json.Unmarshal(jsonData, &llmArgs)
Expect(errors.Is(err, args.ErrLLMPromptRequired)).To(BeTrue())
})
})

// Calls Validate directly on hand-built structs, so no JSON decoding and no
// default filling is involved in these cases.
Describe("Validation", func() {
It("should succeed with valid arguments", func() {
llmArgs := &args.LLMProcessorArguments{
DatasetId: "ds1",
Prompt: "p",
MaxTokens: 10,
Temperature: "0.2",
}
err := llmArgs.Validate()
Expect(err).ToNot(HaveOccurred())
})

It("should fail when dataset_id is missing", func() {
// DatasetId is deliberately omitted.
llmArgs := &args.LLMProcessorArguments{
Prompt: "p",
MaxTokens: 10,
Temperature: "0.2",
}
err := llmArgs.Validate()
Expect(errors.Is(err, args.ErrLLMDatasetIdRequired)).To(BeTrue())
})

It("should fail when prompt is missing", func() {
// Prompt is deliberately omitted.
llmArgs := &args.LLMProcessorArguments{
DatasetId: "ds1",
MaxTokens: 10,
Temperature: "0.2",
}
err := llmArgs.Validate()
Expect(errors.Is(err, args.ErrLLMPromptRequired)).To(BeTrue())
})

It("should fail when max tokens is negative", func() {
llmArgs := &args.LLMProcessorArguments{
DatasetId: "ds1",
Prompt: "p",
MaxTokens: -1,
Temperature: "0.2",
}
err := llmArgs.Validate()
// The sentinel is wrapped, so it is matched with errors.Is, and the message
// must also report the rejected value.
Expect(errors.Is(err, args.ErrLLMMaxTokensNegative)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("got -1"))
})
})

// Builds requests from hand-built structs. UnmarshalJSON is never involved
// here, so its default filling does not run before the conversion.
Describe("ToLLMProcessorRequest", func() {
It("should map fields and defaults correctly", func() {
llmArgs := args.LLMProcessorArguments{
DatasetId: "ds1",
Prompt: "p",
MaxTokens: 0, // zero is passed through as-is (asserted below); the conversion applies no default
Temperature: "",
}
req := llmArgs.ToLLMProcessorRequest()
Expect(req.InputDatasetId).To(Equal("ds1"))
Expect(req.Prompt).To(Equal("p"))
Expect(req.MaxTokens).To(Equal(0))
Expect(req.Temperature).To(Equal(""))
// MultipleColumns and Model are the only values the conversion fills in itself.
Expect(req.MultipleColumns).To(BeFalse())
Expect(req.Model).To(Equal("gemini-1.5-flash-8b"))
})

It("should map fields correctly when set", func() {
llmArgs := args.LLMProcessorArguments{
DatasetId: "ds1",
Prompt: "p",
MaxTokens: 42,
Temperature: "0.7",
}
req := llmArgs.ToLLMProcessorRequest()
Expect(req.InputDatasetId).To(Equal("ds1"))
Expect(req.Prompt).To(Equal("p"))
Expect(req.MaxTokens).To(Equal(42))
Expect(req.Temperature).To(Equal("0.7"))
// Same fixed MultipleColumns and Model as in the zero-value case above.
Expect(req.MultipleColumns).To(BeFalse())
Expect(req.Model).To(Equal("gemini-1.5-flash-8b"))
})
})
})
47 changes: 2 additions & 45 deletions args/unmarshaller.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,52 +10,9 @@ import (

// JobArguments defines the interface that all job arguments must implement.
// UnmarshalJobArguments returns its result as this interface.
type JobArguments interface {
// Validate returns an error; presumably non-nil when the arguments are not
// acceptable (confirm against the implementations).
Validate() error
// GetCapability returns the types.Capability for these arguments; how it is
// chosen is up to each implementation.
GetCapability() types.Capability
}

// TwitterJobArguments extends JobArguments for Twitter-specific methods.
// It embeds JobArguments, adds validation that takes the job type into
// account, and exposes boolean Is*Operation predicates.
// NOTE(review): what each predicate classifies is decided by the
// implementations, which are not visible here.
type TwitterJobArguments interface {
JobArguments
// ValidateForJobType returns an error for the given job type; presumably
// non-nil when the arguments do not suit that job type.
ValidateForJobType(jobType types.JobType) error
IsSingleTweetOperation() bool
IsMultipleTweetOperation() bool
IsSingleProfileOperation() bool
IsMultipleProfileOperation() bool
IsSingleSpaceOperation() bool
IsTrendsOperation() bool
}

// WebJobArguments extends JobArguments for Web-specific methods.
// It embeds JobArguments and adds job-type-aware validation plus three
// accessors.
// NOTE(review): the exact meaning of "deep scrape", "selector" and
// "effective max depth" is defined by the implementations, not here.
type WebJobArguments interface {
JobArguments
// ValidateForJobType returns an error for the given job type; presumably
// non-nil when the arguments do not suit that job type.
ValidateForJobType(jobType types.JobType) error
IsDeepScrape() bool
HasSelector() bool
GetEffectiveMaxDepth() int
}

// TikTokJobArguments extends JobArguments for TikTok-specific methods.
// It embeds JobArguments and adds job-type-aware validation plus accessors
// for the video URL and language settings.
// NOTE(review): the format of the returned URL and language code is not
// established by this interface; check the implementations.
type TikTokJobArguments interface {
JobArguments
// ValidateForJobType returns an error for the given job type; presumably
// non-nil when the arguments do not suit that job type.
ValidateForJobType(jobType types.JobType) error
HasLanguagePreference() bool
GetVideoURL() string
GetLanguageCode() string
}

// LinkedInJobArguments extends JobArguments for LinkedIn-specific methods.
// The only addition over JobArguments is job-type-aware validation.
type LinkedInJobArguments interface {
JobArguments
// ValidateForJobType returns an error for the given job type; presumably
// non-nil when the arguments do not suit that job type.
ValidateForJobType(jobType types.JobType) error
}

// RedditJobArguments extends JobArguments for Reddit-specific methods.
// The only addition over JobArguments is job-type-aware validation.
type RedditJobArguments interface {
JobArguments
// ValidateForJobType returns an error for the given job type; presumably
// non-nil when the arguments do not suit that job type.
ValidateForJobType(jobType types.JobType) error
}

// UnmarshalJobArguments unmarshals job arguments from a generic map into the appropriate typed struct
// This works with both tee-indexer and tee-worker JobArguments types
func UnmarshalJobArguments(jobType types.JobType, args map[string]any) (JobArguments, error) {
Expand Down Expand Up @@ -84,8 +41,8 @@ func UnmarshalJobArguments(jobType types.JobType, args map[string]any) (JobArgum
}

// Helper functions for unmarshaling specific argument types
func unmarshalWebArguments(args map[string]any) (*WebSearchArguments, error) {
webArgs := &WebSearchArguments{}
func unmarshalWebArguments(args map[string]any) (*WebArguments, error) {
webArgs := &WebArguments{}
if err := unmarshalToStruct(args, webArgs); err != nil {
return nil, fmt.Errorf("failed to unmarshal web job arguments: %w", err)
}
Expand Down
4 changes: 1 addition & 3 deletions args/unmarshaller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@ var _ = Describe("Unmarshaller", func() {
It("should unmarshal the arguments correctly", func() {
argsMap := map[string]any{
"url": "https://example.com",
"selector": "h1",
"max_depth": 2,
}
jobArgs, err := args.UnmarshalJobArguments(types.WebJob, argsMap)
Expect(err).ToNot(HaveOccurred())
webArgs, ok := jobArgs.(*args.WebSearchArguments)
webArgs, ok := jobArgs.(*args.WebArguments)
Expect(ok).To(BeTrue())
Expect(webArgs.URL).To(Equal("https://example.com"))
Expect(webArgs.Selector).To(Equal("h1"))
Expect(webArgs.MaxDepth).To(Equal(2))
})
})
Expand Down
Loading