diff --git a/.cursor/rules/collaboration.mdc b/.cursor/rules/collaboration.mdc new file mode 100644 index 00000000..fbc46e15 --- /dev/null +++ b/.cursor/rules/collaboration.mdc @@ -0,0 +1,30 @@ +--- +description: +globs: +alwaysApply: true +--- +# Collaboration Rules + +## Planning and Confirmation Rule + +**Before implementing any code changes, features, or modifications:** + +1. **Create a Plan**: Always develop a clear, detailed plan that outlines: + - What changes will be made + - Which files will be modified or created + - The approach and methodology + - Expected outcomes and impacts + +2. **Confirm with User**: Present the plan to the user and wait for explicit confirmation before: + - Making any file modifications + - Creating new files + - Running commands that modify the codebase + - Implementing any suggested changes + +3. **Get Approval**: Only proceed with implementation after receiving clear approval from the user. + +4. **No Assumptions**: Never assume the user wants changes implemented immediately, even if they seem obvious or beneficial. + +**Exception**: Read-only operations (viewing files, searching, analyzing) do not require prior confirmation. + +This rule ensures we maintain collaborative control over the codebase and prevents unwanted changes. diff --git a/.cursor/rules/testing.mdc b/.cursor/rules/testing.mdc new file mode 100644 index 00000000..4bacbbf0 --- /dev/null +++ b/.cursor/rules/testing.mdc @@ -0,0 +1,15 @@ +--- +alwaysApply: true +--- + +## Testing Rule + +**Before writing any test code:** + +1. **Create a Test Plan**: Always develop a clear, detailed plan that outlines: + - What tests will be written + - Which files will be modified or created + - The approach and methodology + - Expected outcomes and impacts + +2. **Prefer Ginkgo & Gomega**: When writing tests, use the Ginkgo and Gomega frameworks (BDD style) where possible to structure and assert tests in Go. Only use the built-in `testing` package for compatibility or legacy reasons, or if instructed. \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..cd885540 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. 
+# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "gomod" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/.gitignore b/.gitignore index 72e5a502..cdd02a24 100644 --- a/.gitignore +++ b/.gitignore @@ -81,6 +81,9 @@ bp-todo.md # TEE tee/private.pem -# LLML +# LLM-related files .aider* GEMINI.md + +# Examples from tee-types +/examples/*json diff --git a/Makefile b/Makefile index 99f57d4b..978ef8a1 100644 --- a/Makefile +++ b/Makefile @@ -86,7 +86,9 @@ test-tiktok: docker-build-test @docker run --user root $(ENV_FILE_ARG) -v $(PWD)/.masa:/home/masa -v $(PWD)/coverage:/app/coverage --rm --workdir /app -e DATA_DIR=/home/masa $(TEST_IMAGE) go test -v ./internal/jobs/tiktok_test.go ./internal/jobs/jobs_suite_test.go test-reddit: docker-build-test - @docker run --user root $(ENV_FILE_ARG) -v $(PWD)/.masa:/home/masa -v $(PWD)/coverage:/app/coverage --rm --workdir /app -e DATA_DIR=/home/masa $(TEST_IMAGE) go test -v ./internal/jobs/reddit_test.go ./internal/jobs/redditapify/client_test.go ./api/types/reddit/reddit_suite_test.go + @docker run --user root $(ENV_FILE_ARG) -v $(PWD)/.masa:/home/masa -v $(PWD)/coverage:/app/coverage --rm --workdir /app -e DATA_DIR=/home/masa $(TEST_IMAGE) sh -c "cd /app && go test -v ./internal/jobs/reddit_test.go ./internal/jobs/jobs_suite_test.go" + @docker run --user root $(ENV_FILE_ARG) -v $(PWD)/.masa:/home/masa -v $(PWD)/coverage:/app/coverage --rm --workdir /app -e DATA_DIR=/home/masa $(TEST_IMAGE) sh -c "cd /app && go test -v ./internal/jobs/redditapify" + @docker run --user root $(ENV_FILE_ARG) -v $(PWD)/.masa:/home/masa -v $(PWD)/coverage:/app/coverage --rm --workdir /app -e DATA_DIR=/home/masa $(TEST_IMAGE) sh -c "cd /app && go test -v ./api/types/reddit_test.go" test-web: docker-build-test @docker run --user root $(ENV_FILE_ARG) -v $(PWD)/.masa:/home/masa -v $(PWD)/coverage:/app/coverage --rm --workdir /app -e DATA_DIR=/home/masa $(TEST_IMAGE) sh -c "cd /app && go test -v ./internal/jobs/web_test.go ./internal/jobs/jobs_suite_test.go" diff --git a/api/args/base/base.go b/api/args/base/base.go new file mode 100644 index 00000000..01e3a220 --- /dev/null +++ b/api/args/base/base.go @@ -0,0 +1,49 @@ +package base + +import ( + "encoding/json" + "fmt" + + "github.com/masa-finance/tee-worker/api/types" +) + +// JobArgument defines the interface that all job arguments must implement +type JobArgument interface { + UnmarshalJSON([]byte) error + GetCapability() types.Capability + ValidateCapability(jobType types.JobType) error + SetDefaultValues() + Validate() error +} + +// Verify interface implementation +var _ JobArgument = (*Arguments)(nil) + +type Arguments struct { + Type types.Capability `json:"type"` +} + +func (t *Arguments) UnmarshalJSON(data []byte) error { + type Alias Arguments + aux := &struct{ *Alias }{Alias: (*Alias)(t)} + if err := json.Unmarshal(data, aux); err != nil { + return fmt.Errorf("%v: %w", "failed to unmarshal arguments", err) + } + t.SetDefaultValues() + return t.Validate() +} + +func (a *Arguments) GetCapability() types.Capability { + return a.Type +} + +func (a *Arguments) ValidateCapability(jobType types.JobType) error { + return jobType.ValidateCapability(&a.Type) +} + +func (a *Arguments) SetDefaultValues() { +} + +func (a *Arguments) Validate() error { + return 
nil
+}
diff --git a/api/args/linkedin/linkedin.go b/api/args/linkedin/linkedin.go
new file mode 100644
index 00000000..3a7205ff
--- /dev/null
+++ b/api/args/linkedin/linkedin.go
@@ -0,0 +1,7 @@
+package linkedin
+
+import (
+	"github.com/masa-finance/tee-worker/api/args/linkedin/profile"
+)
+
+type Profile = profile.Arguments
diff --git a/api/args/linkedin/profile/profile.go b/api/args/linkedin/profile/profile.go
new file mode 100644
index 00000000..9d22df60
--- /dev/null
+++ b/api/args/linkedin/profile/profile.go
@@ -0,0 +1,141 @@
+package profile
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+
+	"github.com/masa-finance/tee-worker/api/args/base"
+	"github.com/masa-finance/tee-worker/api/types"
+	"github.com/masa-finance/tee-worker/api/types/linkedin/experiences"
+	"github.com/masa-finance/tee-worker/api/types/linkedin/functions"
+	"github.com/masa-finance/tee-worker/api/types/linkedin/industries"
+	"github.com/masa-finance/tee-worker/api/types/linkedin/profile"
+	"github.com/masa-finance/tee-worker/api/types/linkedin/seniorities"
+)
+
+var (
+	ErrScraperModeNotSupported = errors.New("scraper mode not supported")
+	ErrMaxItemsTooLarge        = errors.New("max items must be less than or equal to 1000")
+	ErrExperienceNotSupported  = errors.New("years of experience not supported")
+	ErrSeniorityNotSupported   = errors.New("seniority level not supported")
+	ErrFunctionNotSupported    = errors.New("function not supported")
+	ErrIndustryNotSupported    = errors.New("industry not supported")
+	ErrUnmarshalling           = errors.New("failed to unmarshal LinkedIn profile arguments")
+)
+
+const (
+	DefaultMaxItems    = 10
+	DefaultScraperMode = profile.ScraperModeShort
+	MaxItems           = 1000 // 2500 on the actor, but we will run over 1MB memory limit on responses
+)
+
+// Verify interface implementation
+var _ base.JobArgument = (*Arguments)(nil)
+
+// Arguments defines args for LinkedIn profile operations
+type Arguments struct {
+	Type                  types.Capability    `json:"type"`
+	ScraperMode           profile.ScraperMode `json:"profileScraperMode"`
+	Query                 string              `json:"searchQuery"`
+	MaxItems              uint                `json:"maxItems"`
+	Locations             []string            `json:"locations,omitempty"`
+	CurrentCompanies      []string            `json:"currentCompanies,omitempty"`
+	PastCompanies         []string            `json:"pastCompanies,omitempty"`
+	CurrentJobTitles      []string            `json:"currentJobTitles,omitempty"`
+	PastJobTitles         []string            `json:"pastJobTitles,omitempty"`
+	Schools               []string            `json:"schools,omitempty"`
+	YearsOfExperience     []experiences.Id    `json:"yearsOfExperienceIds,omitempty"`
+	YearsAtCurrentCompany []experiences.Id    `json:"yearsAtCurrentCompanyIds,omitempty"`
+	SeniorityLevels       []seniorities.Id    `json:"seniorityLevelIds,omitempty"`
+	Functions             []functions.Id      `json:"functionIds,omitempty"`
+	Industries            []industries.Id     `json:"industryIds,omitempty"`
+	FirstNames            []string            `json:"firstNames,omitempty"`
+	LastNames             []string            `json:"lastNames,omitempty"`
+	RecentlyChangedJobs   bool                `json:"recentlyChangedJobs,omitempty"`
+	StartPage             uint                `json:"startPage,omitempty"`
+}
+
+func (t *Arguments) UnmarshalJSON(data []byte) error {
+	type Alias Arguments
+	aux := &struct{ *Alias }{Alias: (*Alias)(t)}
+	if err := json.Unmarshal(data, aux); err != nil {
+		return fmt.Errorf("%w: %w", ErrUnmarshalling, err)
+	}
+	t.SetDefaultValues()
+	return t.Validate()
+}
+
+func (a *Arguments) SetDefaultValues() {
+	if a.MaxItems == 0 {
+		a.MaxItems = DefaultMaxItems
+	}
+	if a.ScraperMode == "" {
+		a.ScraperMode = DefaultScraperMode
+	}
+}
+
+func (a *Arguments) Validate() error {
+	var errs []error
+
+	if a.MaxItems > MaxItems
{ + errs = append(errs, ErrMaxItemsTooLarge) + } + + err := a.ValidateCapability(types.LinkedInJob) + if err != nil { + errs = append(errs, err) + } + + if !profile.AllScraperModes.Contains(a.ScraperMode) { + errs = append(errs, ErrScraperModeNotSupported) + } + + for _, yoe := range a.YearsOfExperience { + if !experiences.All.Contains(yoe) { + errs = append(errs, fmt.Errorf("%w: %v", ErrExperienceNotSupported, yoe)) + } + } + for _, yac := range a.YearsAtCurrentCompany { + if !experiences.All.Contains(yac) { + errs = append(errs, fmt.Errorf("%w: %v", ErrExperienceNotSupported, yac)) + } + } + for _, sl := range a.SeniorityLevels { + if !seniorities.All.Contains(sl) { + errs = append(errs, fmt.Errorf("%w: %v", ErrSeniorityNotSupported, sl)) + } + } + for _, f := range a.Functions { + if !functions.All.Contains(f) { + errs = append(errs, fmt.Errorf("%w: %v", ErrFunctionNotSupported, f)) + } + } + for _, i := range a.Industries { + if !industries.All.Contains(i) { + errs = append(errs, fmt.Errorf("%w: %v", ErrIndustryNotSupported, i)) + } + } + + if len(errs) > 0 { + return errors.Join(errs...) + } + + return nil +} + +func (a *Arguments) GetCapability() types.Capability { + return a.Type +} + +func (a *Arguments) ValidateCapability(jobType types.JobType) error { + return jobType.ValidateCapability(&a.Type) +} + +// NewArguments creates a new Arguments instance and applies default values immediately +func NewArguments() Arguments { + args := Arguments{} + args.SetDefaultValues() + args.Validate() // This will set the default capability via ValidateCapability + return args +} diff --git a/api/types/reddit/reddit_suite_test.go b/api/args/linkedin/profile/profile_suite_test.go similarity index 57% rename from api/types/reddit/reddit_suite_test.go rename to api/args/linkedin/profile/profile_suite_test.go index 22d7fd62..713e96d2 100644 --- a/api/types/reddit/reddit_suite_test.go +++ b/api/args/linkedin/profile/profile_suite_test.go @@ -1,4 +1,4 @@ -package reddit_test +package profile_test import ( "testing" @@ -7,7 +7,7 @@ import ( . "github.com/onsi/gomega" ) -func TestReddit(t *testing.T) { +func TestArgs(t *testing.T) { RegisterFailHandler(Fail) - RunSpecs(t, "Reddit Suite") -} \ No newline at end of file + RunSpecs(t, "Args Suite") +} diff --git a/api/args/linkedin/profile/profile_test.go b/api/args/linkedin/profile/profile_test.go new file mode 100644 index 00000000..98718800 --- /dev/null +++ b/api/args/linkedin/profile/profile_test.go @@ -0,0 +1,192 @@ +package profile_test + +import ( + "encoding/json" + "errors" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/args/linkedin/profile" + "github.com/masa-finance/tee-worker/api/types" + "github.com/masa-finance/tee-worker/api/types/linkedin/experiences" + "github.com/masa-finance/tee-worker/api/types/linkedin/functions" + "github.com/masa-finance/tee-worker/api/types/linkedin/industries" + ptypes "github.com/masa-finance/tee-worker/api/types/linkedin/profile" + "github.com/masa-finance/tee-worker/api/types/linkedin/seniorities" +) + +var _ = Describe("LinkedIn Profile Arguments", func() { + Describe("Marshalling and unmarshalling", func() { + It("should set default values", func() { + args := profile.NewArguments() + args.Query = "software engineer" + jsonData, err := json.Marshal(args) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal([]byte(jsonData), &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.MaxItems).To(Equal(uint(10))) + Expect(args.ScraperMode).To(Equal(ptypes.ScraperModeShort)) + }) + + It("should override default values", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.MaxItems = 50 + args.ScraperMode = ptypes.ScraperModeFull + jsonData, err := json.Marshal(args) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal([]byte(jsonData), &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.MaxItems).To(Equal(uint(50))) + Expect(args.ScraperMode).To(Equal(ptypes.ScraperModeFull)) + }) + }) + + Describe("Validation", func() { + It("should succeed with valid arguments", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.ScraperMode = ptypes.ScraperModeShort + args.MaxItems = 10 + args.YearsOfExperience = []experiences.Id{experiences.ThreeToFiveYears} + args.SeniorityLevels = []seniorities.Id{seniorities.Senior} + args.Functions = []functions.Id{functions.Engineering} + args.Industries = []industries.Id{industries.SoftwareDevelopment} + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail with max items too large", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.ScraperMode = ptypes.ScraperModeShort + args.MaxItems = 1500 + err := args.Validate() + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, profile.ErrMaxItemsTooLarge)).To(BeTrue()) + }) + + It("should fail with invalid scraper mode", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.ScraperMode = "InvalidMode" + args.MaxItems = 10 + err := args.Validate() + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, profile.ErrScraperModeNotSupported)).To(BeTrue()) + }) + + It("should fail with invalid years of experience", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.ScraperMode = ptypes.ScraperModeShort + args.MaxItems = 10 + args.YearsOfExperience = []experiences.Id{"invalid"} + err := args.Validate() + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, profile.ErrExperienceNotSupported)).To(BeTrue()) + + }) + + It("should fail with invalid years at current company", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.ScraperMode = ptypes.ScraperModeShort + args.MaxItems = 10 + args.YearsAtCurrentCompany = []experiences.Id{"invalid"} + err := args.Validate() + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, profile.ErrExperienceNotSupported)).To(BeTrue()) + + }) + + It("should fail with invalid seniority level", func() { + args := profile.NewArguments() + args.Query = "software engineer" + 
args.ScraperMode = ptypes.ScraperModeShort + args.MaxItems = 10 + args.SeniorityLevels = []seniorities.Id{"invalid"} + err := args.Validate() + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, profile.ErrSeniorityNotSupported)).To(BeTrue()) + }) + + It("should fail with invalid function", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.ScraperMode = ptypes.ScraperModeShort + args.MaxItems = 10 + args.Functions = []functions.Id{"invalid"} + err := args.Validate() + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, profile.ErrFunctionNotSupported)).To(BeTrue()) + + }) + + It("should fail with invalid industry", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.ScraperMode = ptypes.ScraperModeShort + args.MaxItems = 10 + args.Industries = []industries.Id{"invalid"} + err := args.Validate() + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, profile.ErrIndustryNotSupported)).To(BeTrue()) + + }) + + It("should handle multiple validation errors", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.ScraperMode = "InvalidMode" + args.MaxItems = 1500 + args.YearsOfExperience = []experiences.Id{"invalid"} + args.SeniorityLevels = []seniorities.Id{"invalid"} + err := args.Validate() + Expect(err).To(HaveOccurred()) + // Should contain multiple error messages + Expect(errors.Is(err, profile.ErrMaxItemsTooLarge)).To(BeTrue()) + Expect(errors.Is(err, profile.ErrScraperModeNotSupported)).To(BeTrue()) + Expect(errors.Is(err, profile.ErrExperienceNotSupported)).To(BeTrue()) + Expect(errors.Is(err, profile.ErrSeniorityNotSupported)).To(BeTrue()) + }) + }) + + Describe("GetCapability", func() { + It("should return the query type", func() { + args := profile.NewArguments() + Expect(args.GetCapability()).To(Equal(types.CapSearchByProfile)) + }) + }) + + Describe("ValidateCapability", func() { + It("should succeed with valid job type and capability", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.ScraperMode = ptypes.ScraperModeShort + args.MaxItems = 10 + err := args.ValidateCapability(types.LinkedInJob) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail with invalid job type", func() { + args := profile.NewArguments() + args.Type = types.CapSearchByQuery // Override the default + args.Query = "software engineer" + args.ScraperMode = ptypes.ScraperModeShort + args.MaxItems = 10 + err := args.ValidateCapability(types.LinkedInJob) + Expect(err).To(HaveOccurred()) + }) + + It("should fail if profile validation fails", func() { + args := profile.NewArguments() + args.Query = "software engineer" + args.ScraperMode = "InvalidMode" + args.MaxItems = 10 + err := args.Validate() + Expect(err).To(HaveOccurred()) + }) + }) +}) diff --git a/api/args/llm/llm.go b/api/args/llm/llm.go new file mode 100644 index 00000000..6dae9294 --- /dev/null +++ b/api/args/llm/llm.go @@ -0,0 +1,7 @@ +package llm + +import ( + "github.com/masa-finance/tee-worker/api/args/llm/process" +) + +type Process = process.Arguments diff --git a/api/args/llm/process/process.go b/api/args/llm/process/process.go new file mode 100644 index 00000000..7d08d0de --- /dev/null +++ b/api/args/llm/process/process.go @@ -0,0 +1,108 @@ +package process + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + + "github.com/masa-finance/tee-worker/api/args/base" + "github.com/masa-finance/tee-worker/api/types" + "github.com/masa-finance/tee-worker/pkg/util" +) + +var ( + ErrDatasetIdRequired = 
errors.New("dataset id is required") + ErrPromptRequired = errors.New("prompt is required") + ErrUnmarshalling = errors.New("failed to unmarshal arguments") +) + +const ( + DefaultMaxTokens uint = 300 + DefaultTemperature float64 = 0.1 + DefaultMultipleColumns bool = false + DefaultGeminiModel string = "gemini-1.5-flash-8b" + DefaultClaudeModel string = "claude-3-5-haiku-latest" + DefaultItems uint = 1 +) + +var SupportedModels = util.NewSet(DefaultGeminiModel, DefaultClaudeModel) + +// Verify interface implementation +var _ base.JobArgument = (*Arguments)(nil) + +type Arguments struct { + Type types.Capability `json:"type"` + DatasetId string `json:"dataset_id"` + Prompt string `json:"prompt"` + MaxTokens uint `json:"max_tokens"` + Temperature float64 `json:"temperature"` + Items uint `json:"items"` +} + +func (t *Arguments) UnmarshalJSON(data []byte) error { + type Alias Arguments + aux := &struct{ *Alias }{Alias: (*Alias)(t)} + if err := json.Unmarshal(data, aux); err != nil { + return fmt.Errorf("%w: %w", ErrUnmarshalling, err) + } + t.SetDefaultValues() + return t.Validate() +} + +func (l *Arguments) SetDefaultValues() { + if l.Temperature == 0 { + l.Temperature = DefaultTemperature + } + if l.MaxTokens == 0 { + l.MaxTokens = DefaultMaxTokens + } + if l.Items == 0 { + l.Items = DefaultItems + } +} + +func (l *Arguments) Validate() error { + if l.DatasetId == "" { + return ErrDatasetIdRequired + } + if l.Prompt == "" { + return ErrPromptRequired + } + return nil +} + +func (l *Arguments) GetCapability() types.Capability { + return l.Type +} + +func (l *Arguments) ValidateCapability(jobType types.JobType) error { + return nil // is not yet a standalone job type +} + +// NewArguments creates a new Arguments instance and applies default values immediately +func NewArguments() Arguments { + args := Arguments{} + args.SetDefaultValues() + args.Validate() // This will set the default capability via ValidateCapability + return args +} + +func (l Arguments) ToProcessorRequest(model string, key string) (types.LLMProcessorRequest, error) { + if !SupportedModels.Contains(model) { + return types.LLMProcessorRequest{}, fmt.Errorf("model %s is not supported", model) + } + if key == "" { + return types.LLMProcessorRequest{}, fmt.Errorf("key is required") + } + + return types.LLMProcessorRequest{ + InputDatasetId: l.DatasetId, + LLMProviderApiKey: key, + Prompt: l.Prompt, + MaxTokens: l.MaxTokens, + Temperature: strconv.FormatFloat(l.Temperature, 'f', -1, 64), + MultipleColumns: DefaultMultipleColumns, // overrides default in actor API + Model: model, // overrides default in actor API + }, nil +} diff --git a/api/args/llm/process/process_suite_test.go b/api/args/llm/process/process_suite_test.go new file mode 100644 index 00000000..a485a7f6 --- /dev/null +++ b/api/args/llm/process/process_suite_test.go @@ -0,0 +1,13 @@ +package process_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestArgs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Args Suite") +} diff --git a/api/args/llm/process/process_test.go b/api/args/llm/process/process_test.go new file mode 100644 index 00000000..c63033cb --- /dev/null +++ b/api/args/llm/process/process_test.go @@ -0,0 +1,108 @@ +package process_test + +import ( + "encoding/json" + "errors" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/args/llm/process" +) + +var _ = Describe("LLMProcessorArguments", func() { + Describe("Marshalling and unmarshalling", func() { + It("should set default values", func() { + llmArgs := process.NewArguments() + llmArgs.DatasetId = "ds1" + llmArgs.Prompt = "summarize: ${markdown}" + jsonData, err := json.Marshal(llmArgs) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal([]byte(jsonData), &llmArgs) + Expect(err).ToNot(HaveOccurred()) + Expect(llmArgs.Temperature).To(Equal(0.1)) + Expect(llmArgs.MaxTokens).To(Equal(uint(300))) + Expect(llmArgs.Items).To(Equal(uint(1))) + }) + + It("should override default values", func() { + llmArgs := process.NewArguments() + llmArgs.DatasetId = "ds1" + llmArgs.Prompt = "summarize: ${markdown}" + llmArgs.MaxTokens = 123 + llmArgs.Temperature = 0.7 + llmArgs.Items = 3 + jsonData, err := json.Marshal(llmArgs) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal([]byte(jsonData), &llmArgs) + Expect(err).ToNot(HaveOccurred()) + Expect(llmArgs.Temperature).To(Equal(0.7)) + Expect(llmArgs.MaxTokens).To(Equal(uint(123))) + Expect(llmArgs.Items).To(Equal(uint(3))) + }) + + It("should fail unmarshal when dataset_id is missing", func() { + var llmArgs process.Arguments + jsonData := []byte(`{"type":"datasetprocessor","prompt":"p"}`) + err := json.Unmarshal(jsonData, &llmArgs) + Expect(errors.Is(err, process.ErrDatasetIdRequired)).To(BeTrue()) + }) + + It("should fail unmarshal when prompt is missing", func() { + var llmArgs process.Arguments + jsonData := []byte(`{"type":"datasetprocessor","dataset_id":"ds1"}`) + err := json.Unmarshal(jsonData, &llmArgs) + Expect(errors.Is(err, process.ErrPromptRequired)).To(BeTrue()) + }) + }) + + Describe("Validation", func() { + It("should succeed with valid arguments", func() { + llmArgs := process.NewArguments() + llmArgs.DatasetId = "ds1" + llmArgs.Prompt = "p" + llmArgs.MaxTokens = 10 + llmArgs.Temperature = 0.2 + llmArgs.Items = 1 + err := llmArgs.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail when dataset_id is missing", func() { + llmArgs := process.NewArguments() + llmArgs.Prompt = "p" + llmArgs.MaxTokens = 10 + llmArgs.Temperature = 0.2 + err := llmArgs.Validate() + Expect(errors.Is(err, process.ErrDatasetIdRequired)).To(BeTrue()) + }) + + It("should fail when prompt is missing", func() { + llmArgs := process.NewArguments() + llmArgs.DatasetId = "ds1" + llmArgs.MaxTokens = 10 + llmArgs.Temperature = 0.2 + err := llmArgs.Validate() + Expect(errors.Is(err, process.ErrPromptRequired)).To(BeTrue()) + }) + }) + + Describe("ToLLMProcessorRequest", func() { + It("should map request fields to actor request fields", func() { + llmArgs := process.NewArguments() + llmArgs.DatasetId = "ds1" + llmArgs.Prompt = "p" + llmArgs.MaxTokens = 42 + llmArgs.Temperature = 0.7 + req, err := llmArgs.ToProcessorRequest(process.DefaultGeminiModel, "api-key") + Expect(err).ToNot(HaveOccurred()) + Expect(req.InputDatasetId).To(Equal("ds1")) + Expect(req.Prompt).To(Equal("p")) + Expect(req.MaxTokens).To(Equal(uint(42))) + Expect(req.Temperature).To(Equal("0.7")) + Expect(req.MultipleColumns).To(BeFalse()) + Expect(req.Model).To(Equal(process.DefaultGeminiModel)) + Expect(req.LLMProviderApiKey).To(Equal("api-key")) + }) + }) +}) diff --git a/api/args/reddit/reddit.go b/api/args/reddit/reddit.go new file mode 100644 index 00000000..ec34b0d0 --- /dev/null +++ b/api/args/reddit/reddit.go @@ -0,0 +1,7 @@ +package reddit + +import ( + 
"github.com/masa-finance/tee-worker/api/args/reddit/search" +) + +type Search = search.Arguments diff --git a/api/args/reddit/search/search.go b/api/args/reddit/search/search.go new file mode 100644 index 00000000..686eb249 --- /dev/null +++ b/api/args/reddit/search/search.go @@ -0,0 +1,192 @@ +package search + +import ( + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/masa-finance/tee-worker/api/args/base" + "github.com/masa-finance/tee-worker/api/types" +) + +var ( + ErrInvalidType = errors.New("invalid type") + ErrInvalidSort = errors.New("invalid sort") + ErrTimeInTheFuture = errors.New("after field is in the future") + ErrNoQueries = errors.New("queries must be provided for all query types except scrapeurls") + ErrNoUrls = errors.New("urls must be provided for scrapeurls query type") + ErrQueriesNotAllowed = errors.New("the scrapeurls query type does not admit queries") + ErrUrlsNotAllowed = errors.New("urls can only be provided for the scrapeurls query type") + ErrUnmarshalling = errors.New("failed to unmarshal reddit search arguments") +) + +const ( + // These reflect the default values in https://apify.com/trudax/reddit-scraper/input-schema + DefaultMaxItems = 10 + DefaultMaxPosts = 10 + DefaultMaxComments = 10 + DefaultMaxCommunities = 2 + DefaultMaxUsers = 2 + DefaultSort = types.RedditSortNew +) + +const DomainSuffix = "reddit.com" + +// Verify interface implementation +var _ base.JobArgument = (*Arguments)(nil) + +// Arguments defines args for Reddit scrapes +// see https://apify.com/trudax/reddit-scraper +type Arguments struct { + Type types.Capability `json:"type"` + Queries []string `json:"queries"` + URLs []string `json:"urls"` + Sort types.RedditSortType `json:"sort"` + IncludeNSFW bool `json:"include_nsfw"` + SkipPosts bool `json:"skip_posts"` // Valid only for searchusers + After time.Time `json:"after"` // valid only for scrapeurls and searchposts + MaxItems uint `json:"max_items"` // Max number of items to scrape (total), default 10 + MaxResults uint `json:"max_results"` // Max number of results per page, default MaxItems + MaxPosts uint `json:"max_posts"` // Max number of posts per page, default 10 + MaxComments uint `json:"max_comments"` // Max number of comments per page, default 10 + MaxCommunities uint `json:"max_communities"` // Max number of communities per page, default 2 + MaxUsers uint `json:"max_users"` // Max number of users per page, default 2 + NextCursor string `json:"next_cursor"` +} + +func (t *Arguments) UnmarshalJSON(data []byte) error { + type Alias Arguments + aux := &struct{ *Alias }{Alias: (*Alias)(t)} + if err := json.Unmarshal(data, aux); err != nil { + return fmt.Errorf("%w: %w", ErrUnmarshalling, err) + } + t.SetDefaultValues() + return t.Validate() +} + +// SetDefaultValues sets the default values for the parameters that were not provided and canonicalizes the strings for later validation +func (r *Arguments) SetDefaultValues() { + if r.MaxItems == 0 { + r.MaxItems = DefaultMaxItems + } + if r.MaxPosts == 0 { + r.MaxPosts = DefaultMaxPosts + } + if r.MaxComments == 0 { + r.MaxComments = DefaultMaxComments + } + if r.MaxCommunities == 0 { + r.MaxCommunities = DefaultMaxCommunities + } + if r.MaxUsers == 0 { + r.MaxUsers = DefaultMaxUsers + } + if r.MaxItems != 0 { + r.MaxResults = r.MaxItems + } else if r.MaxResults == 0 { + r.MaxResults = DefaultMaxItems + } + if r.Sort == "" { + r.Sort = DefaultSort + } + + r.Sort = types.RedditSortType(strings.ToLower(string(r.Sort))) +} + +func (r *Arguments) 
Validate() error { + var errs []error + + if !types.AllRedditQueryTypes.Contains(r.Type) { + errs = append(errs, ErrInvalidType) + } + + if !types.AllRedditSortTypes.Contains(r.Sort) { + errs = append(errs, ErrInvalidSort) + } + + if time.Now().Before(r.After) { + errs = append(errs, ErrTimeInTheFuture) + } + + if len(errs) > 0 { + return errors.Join(errs...) + } + + if r.Type == types.CapScrapeUrls { + if len(r.URLs) == 0 { + errs = append(errs, ErrNoUrls) + } + if len(r.Queries) > 0 { + errs = append(errs, ErrQueriesNotAllowed) + } + + for _, u := range r.URLs { + u, err := url.Parse(u) + if err != nil { + errs = append(errs, fmt.Errorf("%s is not a valid URL", u)) + } else { + if !strings.HasSuffix(strings.ToLower(u.Host), DomainSuffix) { + errs = append(errs, fmt.Errorf("invalid Reddit URL %s", u)) + } + if !strings.HasPrefix(u.Path, "/r/") { + errs = append(errs, fmt.Errorf("%s is not a Reddit post or comment URL (missing /r/)", u)) + } + if !strings.Contains(u.Path, "/comments/") { + errs = append(errs, fmt.Errorf("%s is not a Reddit post or comment URL (missing /comments/)", u)) + } + } + } + } else { + if len(r.Queries) == 0 { + errs = append(errs, ErrNoQueries) + } + if len(r.URLs) > 0 { + errs = append(errs, ErrUrlsNotAllowed) + } + } + + return errors.Join(errs...) +} + +// GetCapability returns the capability of the arguments +func (r *Arguments) GetCapability() types.Capability { + return r.Type +} + +// ValidateCapability validates the capability of the arguments +func (r *Arguments) ValidateCapability(jobType types.JobType) error { + return jobType.ValidateCapability(&r.Type) +} + +// NewArguments creates a new Arguments instance with the specified capability +// and applies default values immediately +func NewArguments(capability types.Capability) Arguments { + args := Arguments{ + Type: capability, + } + args.SetDefaultValues() + return args +} + +// NewSearchPostsArguments creates a new Arguments instance for searching posts +func NewSearchPostsArguments() Arguments { + return NewArguments(types.CapSearchPosts) +} + +// NewSearchUsersArguments creates a new Arguments instance for searching users +func NewSearchUsersArguments() Arguments { + return NewArguments(types.CapSearchUsers) +} + +// NewSearchCommunitiesArguments creates a new Arguments instance for searching communities +func NewSearchCommunitiesArguments() Arguments { + return NewArguments(types.CapSearchCommunities) +} + +// NewScrapeUrlsArguments creates a new Arguments instance for scraping URLs +func NewScrapeUrlsArguments() Arguments { + return NewArguments(types.CapScrapeUrls) +} diff --git a/api/args/reddit/search/search_suite_test.go b/api/args/reddit/search/search_suite_test.go new file mode 100644 index 00000000..688d8ec8 --- /dev/null +++ b/api/args/reddit/search/search_suite_test.go @@ -0,0 +1,13 @@ +package search_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestArgs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Args Suite") +} diff --git a/api/args/reddit/search/search_test.go b/api/args/reddit/search/search_test.go new file mode 100644 index 00000000..6356649c --- /dev/null +++ b/api/args/reddit/search/search_test.go @@ -0,0 +1,162 @@ +package search_test + +import ( + "encoding/json" + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/args/reddit/search" + "github.com/masa-finance/tee-worker/api/types" +) + +var _ = Describe("RedditArguments", func() { + Describe("Marshalling and unmarshalling", func() { + It("should set default values", func() { + redditArgs := search.NewSearchPostsArguments() + redditArgs.Queries = []string{"Zaphod", "Ford"} + jsonData, err := json.Marshal(redditArgs) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal([]byte(jsonData), &redditArgs) + Expect(err).ToNot(HaveOccurred()) + Expect(redditArgs.MaxItems).To(Equal(uint(10))) + Expect(redditArgs.MaxPosts).To(Equal(uint(10))) + Expect(redditArgs.MaxComments).To(Equal(uint(10))) + Expect(redditArgs.MaxCommunities).To(Equal(uint(2))) + Expect(redditArgs.MaxUsers).To(Equal(uint(2))) + Expect(redditArgs.Sort).To(Equal(types.RedditSortNew)) + Expect(redditArgs.MaxResults).To(Equal(redditArgs.MaxItems)) + }) + + It("should override default values", func() { + redditArgs := search.NewSearchPostsArguments() + redditArgs.Queries = []string{"Zaphod", "Ford"} + redditArgs.MaxItems = 20 + redditArgs.MaxPosts = 21 + redditArgs.MaxComments = 22 + redditArgs.MaxCommunities = 23 + redditArgs.MaxUsers = 24 + redditArgs.Sort = types.RedditSortTop + jsonData, err := json.Marshal(redditArgs) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal([]byte(jsonData), &redditArgs) + Expect(err).ToNot(HaveOccurred()) + Expect(redditArgs.MaxItems).To(Equal(uint(20))) + Expect(redditArgs.MaxPosts).To(Equal(uint(21))) + Expect(redditArgs.MaxComments).To(Equal(uint(22))) + Expect(redditArgs.MaxCommunities).To(Equal(uint(23))) + Expect(redditArgs.MaxUsers).To(Equal(uint(24))) + Expect(redditArgs.MaxResults).To(Equal(uint(20))) + Expect(redditArgs.Sort).To(Equal(types.RedditSortTop)) + }) + + }) + + Describe("Validation", func() { + It("should succeed with valid arguments", func() { + redditArgs := search.NewSearchPostsArguments() + redditArgs.Queries = []string{"test"} + redditArgs.Sort = types.RedditSortNew + err := redditArgs.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should succeed with valid scrapeurls arguments", func() { + redditArgs := search.NewScrapeUrlsArguments() + redditArgs.URLs = []string{"https://www.reddit.com/r/golang/comments/foo/bar"} + redditArgs.Sort = types.RedditSortNew + err := redditArgs.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail with an invalid type", func() { + redditArgs := search.NewSearchPostsArguments() + redditArgs.Type = "invalidtype" // Override the default + redditArgs.Queries = []string{"test"} + redditArgs.Sort = types.RedditSortNew + err := redditArgs.Validate() + Expect(err).To(MatchError(search.ErrInvalidType)) + }) + + It("should fail with an invalid sort", func() { + redditArgs := search.NewSearchPostsArguments() + redditArgs.Queries = []string{"test"} + redditArgs.Sort = "invalidsort" + err := redditArgs.Validate() + Expect(err).To(MatchError(search.ErrInvalidSort)) + }) + + It("should fail if the after time is in the future", func() { + redditArgs := &search.Arguments{ + Type: types.CapSearchPosts, + Queries: []string{"test"}, + Sort: types.RedditSortNew, + After: time.Now().Add(24 * time.Hour), + } + err := redditArgs.Validate() + Expect(err).To(MatchError(search.ErrTimeInTheFuture)) + }) + + It("should fail if queries are not provided for searchposts", func() { + redditArgs := search.NewSearchPostsArguments() + redditArgs.Sort = types.RedditSortNew + err := redditArgs.Validate() + 
Expect(err).To(MatchError(search.ErrNoQueries)) + }) + + It("should fail if urls are not provided for scrapeurls", func() { + redditArgs := search.NewScrapeUrlsArguments() + redditArgs.Sort = types.RedditSortNew + err := redditArgs.Validate() + Expect(err).To(MatchError(search.ErrNoUrls)) + }) + + It("should fail if queries are provided for scrapeurls", func() { + redditArgs := search.NewScrapeUrlsArguments() + redditArgs.Queries = []string{"test"} + redditArgs.URLs = []string{"https://www.reddit.com/r/golang/comments/foo/bar/"} + redditArgs.Sort = types.RedditSortNew + err := redditArgs.Validate() + Expect(err).To(MatchError(search.ErrQueriesNotAllowed)) + }) + + It("should fail if urls are provided for searchposts", func() { + redditArgs := search.NewSearchPostsArguments() + redditArgs.Queries = []string{"test"} + redditArgs.URLs = []string{"https://www.reddit.com/r/golang/comments/foo/bar"} + redditArgs.Sort = types.RedditSortNew + err := redditArgs.Validate() + Expect(err).To(MatchError(search.ErrUrlsNotAllowed)) + }) + + It("should fail with an invalid URL", func() { + redditArgs := &search.Arguments{ + Type: types.CapScrapeUrls, + URLs: []string{"ht tp://invalid-url.com"}, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("is not a valid URL")) + }) + + It("should fail with an invalid domain", func() { + redditArgs := search.NewScrapeUrlsArguments() + redditArgs.URLs = []string{"https://www.google.com"} + redditArgs.Sort = types.RedditSortNew + err := redditArgs.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("invalid Reddit URL")) + }) + + It("should fail if the URL is not a post or comment", func() { + redditArgs := search.NewScrapeUrlsArguments() + redditArgs.URLs = []string{"https://www.reddit.com/r/golang/"} + redditArgs.Sort = types.RedditSortNew + err := redditArgs.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not a Reddit post or comment URL")) + }) + }) +}) diff --git a/api/args/telemetry/telemetry.go b/api/args/telemetry/telemetry.go new file mode 100644 index 00000000..4cc2eb6a --- /dev/null +++ b/api/args/telemetry/telemetry.go @@ -0,0 +1,63 @@ +package telemetry + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/masa-finance/tee-worker/api/args/base" + "github.com/masa-finance/tee-worker/api/types" +) + +var ( + ErrUnmarshalling = errors.New("failed to unmarshal telemetry arguments") +) + +type Telemetry = Arguments + +// Verify interface implementation +var _ base.JobArgument = (*Arguments)(nil) + +// Arguments defines args for Telemetry jobs +type Arguments struct { + Type types.Capability `json:"type"` +} + +func (t *Arguments) UnmarshalJSON(data []byte) error { + type Alias Arguments + aux := &struct{ *Alias }{Alias: (*Alias)(t)} + if err := json.Unmarshal(data, aux); err != nil { + return fmt.Errorf("%w: %w", ErrUnmarshalling, err) + } + t.SetDefaultValues() + return t.Validate() +} + +func (t *Arguments) SetDefaultValues() { +} + +func (t *Arguments) Validate() error { + err := t.ValidateCapability(types.TelemetryJob) + if err != nil { + return err + } + return nil +} + +// GetCapability returns the capability of the arguments +func (t *Arguments) GetCapability() types.Capability { + return t.Type +} + +// ValidateCapability validates the capability of the arguments +func (t *Arguments) ValidateCapability(jobType types.JobType) error { + return jobType.ValidateCapability(&t.Type) +} + 
+// NewArguments creates a new Arguments instance and applies default values immediately +func NewArguments() Arguments { + args := Arguments{} + args.SetDefaultValues() + args.Validate() // This will set the default capability via ValidateCapability + return args +} diff --git a/api/args/telemetry/telemetry_suite_test.go b/api/args/telemetry/telemetry_suite_test.go new file mode 100644 index 00000000..9daa685c --- /dev/null +++ b/api/args/telemetry/telemetry_suite_test.go @@ -0,0 +1,13 @@ +package telemetry_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestArgs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Args Suite") +} diff --git a/api/args/telemetry/telemetry_test.go b/api/args/telemetry/telemetry_test.go new file mode 100644 index 00000000..28dd580b --- /dev/null +++ b/api/args/telemetry/telemetry_test.go @@ -0,0 +1,113 @@ +package telemetry_test + +import ( + "encoding/json" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/args/telemetry" + "github.com/masa-finance/tee-worker/api/types" +) + +var _ = Describe("Telemetry Arguments", func() { + Describe("Marshalling and unmarshalling", func() { + It("should set default values", func() { + args := &telemetry.Arguments{ + Type: types.CapTelemetry, + } + jsonData, err := json.Marshal(args) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal([]byte(jsonData), &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.Type).To(Equal(types.CapTelemetry)) + }) + + It("should preserve custom values", func() { + args := &telemetry.Arguments{ + Type: types.CapTelemetry, + } + jsonData, err := json.Marshal(args) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal([]byte(jsonData), &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.Type).To(Equal(types.CapTelemetry)) + }) + + It("should handle invalid JSON", func() { + args := &telemetry.Arguments{} + invalidJSON := `{"type": "telemetry", "invalid": }` + err := json.Unmarshal([]byte(invalidJSON), args) + Expect(err).To(HaveOccurred()) + // The error should be a JSON syntax error, not our custom error + Expect(err).To(BeAssignableToTypeOf(&json.SyntaxError{})) + }) + }) + + Describe("Validation", func() { + It("should succeed with valid arguments", func() { + args := &telemetry.Arguments{ + Type: types.CapTelemetry, + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should succeed with empty arguments", func() { + args := &telemetry.Arguments{} + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Describe("GetCapability", func() { + It("should return the telemetry capability", func() { + args := &telemetry.Arguments{ + Type: types.CapTelemetry, + } + Expect(args.GetCapability()).To(Equal(types.CapTelemetry)) + }) + + It("should return empty capability for uninitialized arguments", func() { + args := &telemetry.Arguments{} + Expect(args.GetCapability()).To(Equal(types.Capability(""))) + }) + }) + + Describe("ValidateCapability", func() { + It("should succeed with valid job type and capability", func() { + args := &telemetry.Arguments{ + Type: types.CapTelemetry, + } + err := args.ValidateCapability(types.TelemetryJob) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail with invalid job type", func() { + args := &telemetry.Arguments{ + Type: types.CapTelemetry, + } + err := args.ValidateCapability(types.LinkedInJob) + Expect(err).To(HaveOccurred()) + }) + + It("should fail with invalid capability", 
func() { + args := &telemetry.Arguments{ + Type: types.CapSearchPosts, // Wrong capability + } + err := args.ValidateCapability(types.TelemetryJob) + Expect(err).To(HaveOccurred()) + }) + }) + + Describe("SetDefaultValues", func() { + It("should not modify arguments", func() { + args := &telemetry.Arguments{ + Type: types.CapTelemetry, + } + originalType := args.Type + args.SetDefaultValues() + Expect(args.Type).To(Equal(originalType)) + }) + }) +}) diff --git a/api/args/tiktok/query/query.go b/api/args/tiktok/query/query.go new file mode 100644 index 00000000..3733997b --- /dev/null +++ b/api/args/tiktok/query/query.go @@ -0,0 +1,74 @@ +package query + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/masa-finance/tee-worker/api/args/base" + "github.com/masa-finance/tee-worker/api/types" +) + +var ( + ErrSearchOrUrlsRequired = errors.New("either 'search' or 'start_urls' are required") + ErrUnmarshalling = errors.New("failed to unmarshal TikTok searchbyquery arguments") +) + +const ( + DefaultMaxItems = 10 + DefaultType = types.CapSearchByQuery +) + +// Verify interface implementation +var _ base.JobArgument = (*Arguments)(nil) + +type Arguments struct { + Type types.Capability `json:"type"` + Search []string `json:"search,omitempty"` + StartUrls []string `json:"start_urls,omitempty"` + MaxItems uint `json:"max_items,omitempty"` + EndPage uint `json:"end_page,omitempty"` +} + +func (t *Arguments) UnmarshalJSON(data []byte) error { + type Alias Arguments + aux := &struct{ *Alias }{Alias: (*Alias)(t)} + if err := json.Unmarshal(data, aux); err != nil { + return fmt.Errorf("%w: %w", ErrUnmarshalling, err) + } + t.SetDefaultValues() + return t.Validate() +} + +func (t *Arguments) SetDefaultValues() { + if t.MaxItems == 0 { + t.MaxItems = DefaultMaxItems + } +} + +func (t *Arguments) GetCapability() types.Capability { + return t.Type +} + +func (t *Arguments) ValidateCapability(jobType types.JobType) error { + return jobType.ValidateCapability(&t.Type) +} + +func (t *Arguments) Validate() error { + err := t.ValidateCapability(types.TiktokJob) + if err != nil { + return err + } + if len(t.Search) == 0 && len(t.StartUrls) == 0 { + return ErrSearchOrUrlsRequired + } + return nil +} + +func NewArguments() Arguments { + args := Arguments{ + Type: types.CapSearchByQuery, + } + args.SetDefaultValues() + return args +} diff --git a/api/args/tiktok/query/query_suite_test.go b/api/args/tiktok/query/query_suite_test.go new file mode 100644 index 00000000..484865cc --- /dev/null +++ b/api/args/tiktok/query/query_suite_test.go @@ -0,0 +1,13 @@ +package query_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestArgs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Args Suite") +} diff --git a/api/args/tiktok/query/query_test.go b/api/args/tiktok/query/query_test.go new file mode 100644 index 00000000..454c76f6 --- /dev/null +++ b/api/args/tiktok/query/query_test.go @@ -0,0 +1,205 @@ +package query_test + +import ( + "encoding/json" + "errors" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/args/tiktok/query" + "github.com/masa-finance/tee-worker/api/types" +) + +var _ = Describe("TikTokQueryArguments", func() { + Describe("Marshalling and unmarshalling", func() { + It("should unmarshal valid arguments with search", func() { + var args query.Arguments + jsonData := []byte(`{"type":"searchbyquery","search":["test query","another query"],"max_items":20}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.Search).To(Equal([]string{"test query", "another query"})) + Expect(args.MaxItems).To(Equal(uint(20))) + }) + + It("should unmarshal valid arguments with start_urls", func() { + var args query.Arguments + jsonData := []byte(`{"type":"searchbyquery","start_urls":["https://tiktok.com/@user1","https://tiktok.com/@user2"],"max_items":15}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.StartUrls).To(Equal([]string{"https://tiktok.com/@user1", "https://tiktok.com/@user2"})) + Expect(args.MaxItems).To(Equal(uint(15))) + }) + + It("should unmarshal valid arguments with both search and start_urls", func() { + var args query.Arguments + jsonData := []byte(`{"type":"searchbyquery","search":["test"],"start_urls":["https://tiktok.com/@user"],"max_items":5}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.Search).To(Equal([]string{"test"})) + Expect(args.StartUrls).To(Equal([]string{"https://tiktok.com/@user"})) + Expect(args.MaxItems).To(Equal(uint(5))) + }) + + It("should unmarshal valid arguments without max_items (should use default)", func() { + var args query.Arguments + jsonData := []byte(`{"type":"searchbyquery","search":["test query"]}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.Search).To(Equal([]string{"test query"})) + Expect(args.MaxItems).To(Equal(uint(10))) // Default value + }) + + It("should fail unmarshal with invalid JSON", func() { + var args query.Arguments + jsonData := []byte(`{"type":"searchbyquery","search":["test query"`) + err := json.Unmarshal(jsonData, &args) + Expect(err).To(HaveOccurred()) + }) + + It("should fail unmarshal when neither search nor start_urls are provided", func() { + var args query.Arguments + jsonData := []byte(`{"type":"searchbyquery","max_items":10}`) + err := json.Unmarshal(jsonData, &args) + Expect(errors.Is(err, query.ErrSearchOrUrlsRequired)).To(BeTrue()) + }) + }) + + Describe("Validation", func() { + It("should succeed with valid search arguments", func() { + args := query.NewArguments() + args.Search = []string{"test query", "another query"} + args.MaxItems = 20 + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should succeed with valid start_urls arguments", func() { + args := query.NewArguments() + args.StartUrls = []string{"https://tiktok.com/@user1", "https://tiktok.com/@user2"} + args.MaxItems = 15 + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should succeed with both search and start_urls", func() { + args := query.NewArguments() + args.Search = []string{"test"} + args.StartUrls = []string{"https://tiktok.com/@user"} + args.MaxItems = 5 + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail when both search and start_urls are empty", func() { + args := query.NewArguments() + args.MaxItems = 10 + err := args.Validate() + Expect(errors.Is(err, query.ErrSearchOrUrlsRequired)).To(BeTrue()) + }) + + 
It("should fail when search is empty slice", func() { + args := query.NewArguments() + args.Search = []string{} + args.MaxItems = 10 + err := args.Validate() + Expect(errors.Is(err, query.ErrSearchOrUrlsRequired)).To(BeTrue()) + }) + + It("should fail when start_urls is empty slice", func() { + args := query.NewArguments() + args.StartUrls = []string{} + args.MaxItems = 10 + err := args.Validate() + Expect(errors.Is(err, query.ErrSearchOrUrlsRequired)).To(BeTrue()) + }) + }) + + Describe("Default values", func() { + It("should set default max_items when not provided", func() { + args := query.NewArguments() + args.Search = []string{"test"} + args.SetDefaultValues() + Expect(args.MaxItems).To(Equal(uint(10))) + }) + + It("should not override existing max_items", func() { + args := query.NewArguments() + args.Search = []string{"test"} + args.MaxItems = 25 + args.SetDefaultValues() + Expect(args.MaxItems).To(Equal(uint(25))) + }) + + It("should not override zero max_items if explicitly set", func() { + args := query.NewArguments() + args.Search = []string{"test"} + args.MaxItems = 0 + args.SetDefaultValues() + Expect(args.MaxItems).To(Equal(uint(10))) // Should set default + }) + }) + + Describe("Job capability", func() { + It("should return the searchbyquery capability", func() { + args := query.NewArguments() + Expect(args.GetCapability()).To(Equal(types.CapSearchByQuery)) + }) + + It("should validate capability for TiktokJob", func() { + args := query.NewArguments() + args.Search = []string{"test query"} + args.MaxItems = 10 + err := args.ValidateCapability(types.TiktokJob) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail validation for incompatible job type", func() { + args := query.NewArguments() + args.Search = []string{"test query"} + args.MaxItems = 10 + // Set a different capability first + args.Type = types.CapTranscription + err := args.ValidateCapability(types.TwitterJob) + Expect(err).To(HaveOccurred()) + // The capability should remain unchanged + Expect(args.Type).To(Equal(types.CapTranscription)) + }) + }) + + Describe("Edge cases", func() { + It("should handle empty search strings", func() { + args := query.NewArguments() + args.Search = []string{"", "valid query", ""} + args.MaxItems = 10 + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should handle empty start_urls strings", func() { + args := query.NewArguments() + args.StartUrls = []string{"", "https://tiktok.com/@user", ""} + args.MaxItems = 10 + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should handle large max_items values", func() { + args := query.NewArguments() + args.Search = []string{"test"} + args.MaxItems = 1000 + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should handle end_page field", func() { + args := query.NewArguments() + args.Search = []string{"test"} + args.MaxItems = 10 + args.EndPage = 5 + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + Expect(args.EndPage).To(Equal(uint(5))) + }) + }) +}) diff --git a/api/args/tiktok/tiktok.go b/api/args/tiktok/tiktok.go new file mode 100644 index 00000000..3779e448 --- /dev/null +++ b/api/args/tiktok/tiktok.go @@ -0,0 +1,11 @@ +package tiktok + +import ( + "github.com/masa-finance/tee-worker/api/args/tiktok/query" + "github.com/masa-finance/tee-worker/api/args/tiktok/transcription" + "github.com/masa-finance/tee-worker/api/args/tiktok/trending" +) + +type Transcription = transcription.Arguments +type Query = query.Arguments +type Trending = trending.Arguments 
diff --git a/api/args/tiktok/transcription/transcription.go b/api/args/tiktok/transcription/transcription.go
new file mode 100644
index 00000000..03356fab
--- /dev/null
+++ b/api/args/tiktok/transcription/transcription.go
@@ -0,0 +1,127 @@
+package transcription
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/url"
+	"strings"
+
+	"github.com/masa-finance/tee-worker/api/args/base"
+	"github.com/masa-finance/tee-worker/api/types"
+)
+
+var (
+	ErrVideoURLRequired    = errors.New("video_url is required")
+	ErrInvalidVideoURL     = errors.New("invalid video_url format")
+	ErrInvalidTikTokURL    = errors.New("url must be a valid TikTok video URL")
+	ErrInvalidLanguageCode = errors.New("invalid language code")
+	ErrUnmarshalling       = errors.New("failed to unmarshal TikTok transcription arguments")
+)
+
+const (
+	DefaultLanguage = "eng-US"
+)
+
+// Verify interface implementation
+var _ base.JobArgument = (*Arguments)(nil)
+
+// Arguments defines args for TikTok transcriptions
+type Arguments struct {
+	Type     types.Capability `json:"type"`
+	VideoURL string           `json:"video_url"`
+	Language string           `json:"language,omitempty"`
+}
+
+func (a *Arguments) UnmarshalJSON(data []byte) error {
+	type Alias Arguments
+	aux := &struct{ *Alias }{Alias: (*Alias)(a)}
+	if err := json.Unmarshal(data, aux); err != nil {
+		return fmt.Errorf("%w: %w", ErrUnmarshalling, err)
+	}
+	a.SetDefaultValues()
+	return a.Validate()
+}
+
+func (a *Arguments) SetDefaultValues() {
+	if a.Language == "" {
+		a.Language = DefaultLanguage
+	}
+}
+
+// Validate validates the TikTok arguments
+func (t *Arguments) Validate() error {
+	err := t.ValidateCapability(types.TiktokJob)
+	if err != nil {
+		return err
+	}
+	if t.VideoURL == "" {
+		return ErrVideoURLRequired
+	}
+
+	// Validate URL format
+	parsedURL, err := url.Parse(t.VideoURL)
+	if err != nil {
+		return fmt.Errorf("%w: %v", ErrInvalidVideoURL, err)
+	}
+
+	// Basic TikTok URL validation
+	if !t.IsTikTokURL(parsedURL) {
+		return ErrInvalidTikTokURL
+	}
+
+	// Validate language format if provided
+	if t.Language != "" {
+		if err := t.validateLanguageCode(); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (t *Arguments) GetCapability() types.Capability {
+	return t.Type
+}
+
+func (t *Arguments) ValidateCapability(jobType types.JobType) error {
+	return jobType.ValidateCapability(&t.Type)
+}
+
+// IsTikTokURL validates if the URL is a TikTok URL
+func (t *Arguments) IsTikTokURL(parsedURL *url.URL) bool {
+	host := strings.ToLower(parsedURL.Host)
+	return host == "tiktok.com" || strings.HasSuffix(host, ".tiktok.com")
+}
+
+// HasLanguagePreference returns true if a language preference is specified
+func (t *Arguments) HasLanguagePreference() bool {
+	return t.Language != ""
+}
+
+// GetVideoURL returns the source video URL
+func (t *Arguments) GetVideoURL() string {
+	return t.VideoURL
+}
+
+// GetLanguageCode returns the language code; SetDefaultValues fills it with
+// DefaultLanguage ("eng-US") when none was specified
+func (t *Arguments) GetLanguageCode() string {
+	return t.Language
+}
+
+// validateLanguageCode validates the language code format (lang-REGION, with a
+// 2- or 3-letter language part and a 2-letter region part)
+func (t *Arguments) validateLanguageCode() error {
+	parts := strings.Split(t.Language, "-")
+	if len(parts) != 2 || (len(parts[0]) != 2 && len(parts[0]) != 3) || len(parts[1]) != 2 {
+		return fmt.Errorf("%w: %s", ErrInvalidLanguageCode, t.Language)
+	}
+	return nil
+}
+
+func NewArguments() Arguments {
+	args := Arguments{
+		Type: types.CapTranscription,
+	}
+	args.SetDefaultValues()
+	return args
+}
diff --git a/api/args/tiktok/transcription/transcription_suite_test.go
b/api/args/tiktok/transcription/transcription_suite_test.go new file mode 100644 index 00000000..a9d7709b --- /dev/null +++ b/api/args/tiktok/transcription/transcription_suite_test.go @@ -0,0 +1,13 @@ +package transcription_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestArgs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Args Suite") +} diff --git a/api/args/tiktok/transcription/transcription_test.go b/api/args/tiktok/transcription/transcription_test.go new file mode 100644 index 00000000..50f5d6ea --- /dev/null +++ b/api/args/tiktok/transcription/transcription_test.go @@ -0,0 +1,242 @@ +package transcription_test + +import ( + "encoding/json" + "errors" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/args/tiktok/transcription" + "github.com/masa-finance/tee-worker/api/types" +) + +var _ = Describe("TikTokTranscriptionArguments", func() { + Describe("Marshalling and unmarshalling", func() { + It("should unmarshal valid arguments", func() { + var args transcription.Arguments + jsonData := []byte(`{"type":"transcription","video_url":"https://tiktok.com/@user/video/123","language":"en-us"}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.VideoURL).To(Equal("https://tiktok.com/@user/video/123")) + Expect(args.Language).To(Equal("en-us")) + }) + + It("should unmarshal valid arguments without language", func() { + var args transcription.Arguments + jsonData := []byte(`{"type":"transcription","video_url":"https://tiktok.com/@user/video/123"}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.VideoURL).To(Equal("https://tiktok.com/@user/video/123")) + Expect(args.Language).To(Equal("eng-US")) // Default language should be set + }) + + It("should fail unmarshal with invalid JSON", func() { + var args transcription.Arguments + jsonData := []byte(`{"type":"transcription","video_url":"https://tiktok.com/@user/video/123"`) + err := json.Unmarshal(jsonData, &args) + Expect(err).To(HaveOccurred()) + }) + + It("should fail unmarshal when video_url is missing", func() { + var args transcription.Arguments + jsonData := []byte(`{"type":"transcription","language":"en-us"}`) + err := json.Unmarshal(jsonData, &args) + Expect(errors.Is(err, transcription.ErrVideoURLRequired)).To(BeTrue()) + }) + }) + + Describe("Validation", func() { + It("should succeed with valid arguments", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "en-us" + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should succeed with valid arguments without language", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "" + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail when video_url is missing", func() { + args := transcription.NewArguments() + args.Language = "en-us" + err := args.Validate() + Expect(errors.Is(err, transcription.ErrVideoURLRequired)).To(BeTrue()) + }) + + It("should fail with an invalid URL format", func() { + args := transcription.NewArguments() + args.VideoURL = "not-a-url" + args.Language = "en-us" + err := args.Validate() + Expect(errors.Is(err, transcription.ErrInvalidTikTokURL)).To(BeTrue()) + }) + + It("should fail with non-TikTok URL", func() { + args := transcription.NewArguments() + args.VideoURL = 
"https://youtube.com/watch?v=123" + args.Language = "en-us" + err := args.Validate() + Expect(errors.Is(err, transcription.ErrInvalidTikTokURL)).To(BeTrue()) + }) + + It("should fail with invalid language code format", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "invalid" + err := args.Validate() + Expect(errors.Is(err, transcription.ErrInvalidLanguageCode)).To(BeTrue()) + }) + }) + + Describe("TikTok URL validation", func() { + It("should accept tiktok.com URLs", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should accept www.tiktok.com URLs", func() { + args := transcription.NewArguments() + args.VideoURL = "https://www.tiktok.com/@user/video/123" + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should accept m.tiktok.com URLs", func() { + args := transcription.NewArguments() + args.VideoURL = "https://m.tiktok.com/@user/video/123" + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should reject non-TikTok URLs", func() { + args := transcription.NewArguments() + args.VideoURL = "https://youtube.com/watch?v=123" + err := args.Validate() + Expect(errors.Is(err, transcription.ErrInvalidTikTokURL)).To(BeTrue()) + }) + }) + + Describe("Language code validation", func() { + It("should accept valid 2-letter language codes", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "en-us" + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should accept valid 3-letter language codes", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "eng-us" + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should accept mixed case language codes", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "EN-US" + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should reject invalid language format", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "english" + err := args.Validate() + Expect(errors.Is(err, transcription.ErrInvalidLanguageCode)).To(BeTrue()) + }) + + It("should reject too many parts", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "en-us-extra" + err := args.Validate() + Expect(errors.Is(err, transcription.ErrInvalidLanguageCode)).To(BeTrue()) + }) + + It("should reject too few parts", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "en" + err := args.Validate() + Expect(errors.Is(err, transcription.ErrInvalidLanguageCode)).To(BeTrue()) + }) + + It("should reject invalid region length", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "en-usa" + err := args.Validate() + Expect(errors.Is(err, transcription.ErrInvalidLanguageCode)).To(BeTrue()) + }) + + It("should reject invalid language length", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "english-us" + err := args.Validate() + Expect(errors.Is(err, 
transcription.ErrInvalidLanguageCode)).To(BeTrue()) + }) + }) + + Describe("Job capability", func() { + It("should return the transcription capability", func() { + args := transcription.NewArguments() + Expect(args.GetCapability()).To(Equal(types.CapTranscription)) + }) + + It("should validate capability for TiktokJob", func() { + args := transcription.NewArguments() + args.VideoURL = "https://tiktok.com/@user/video/123" + args.Language = "en-us" + err := args.ValidateCapability(types.TiktokJob) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Describe("Helper methods", func() { + It("should return true when language preference is set", func() { + args := transcription.NewArguments() + args.Language = "en-us" + Expect(args.HasLanguagePreference()).To(BeTrue()) + }) + + It("should return false when language preference is not set", func() { + args := transcription.NewArguments() + args.Language = "" + Expect(args.HasLanguagePreference()).To(BeFalse()) + }) + + It("should return the language code when set", func() { + args := transcription.NewArguments() + args.Language = "en-us" + Expect(args.GetLanguageCode()).To(Equal("en-us")) + }) + + It("should return default language code when not set", func() { + args := transcription.NewArguments() + args.Language = "" + args.SetDefaultValues() + Expect(args.GetLanguageCode()).To(Equal("eng-US")) + }) + + It("should return the video URL", func() { + expected := "https://tiktok.com/@user/video/123" + args := transcription.NewArguments() + args.VideoURL = expected + Expect(args.GetVideoURL()).To(Equal(expected)) + }) + }) +}) diff --git a/api/args/tiktok/trending/trending.go b/api/args/tiktok/trending/trending.go new file mode 100644 index 00000000..00435e75 --- /dev/null +++ b/api/args/tiktok/trending/trending.go @@ -0,0 +1,122 @@ +package trending + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/masa-finance/tee-worker/api/args/base" + "github.com/masa-finance/tee-worker/api/types" +) + +var ( + ErrTrendingCountryCodeRequired = errors.New("country_code is required") + ErrTrendingSortByRequired = errors.New("sort_by is required") + ErrTrendingPeriodRequired = errors.New("period is required") + ErrTrendingMaxItemsNegative = errors.New("max_items must be non-negative") + ErrUnmarshalling = errors.New("failed to unmarshal TikTok searchbytrending arguments") +) + +// Period constants for TikTok trending search +const ( + periodWeek string = "7" + periodMonth string = "30" +) + +const ( + sortTrending string = "vv" + sortLike string = "like" + sortComment string = "comment" + sortRepost string = "repost" +) + +// Verify interface implementation +var _ base.JobArgument = (*Arguments)(nil) + +// Arguments defines args for lexis-solutions/tiktok-trending-videos-scraper +type Arguments struct { + Type types.Capability `json:"type"` + CountryCode string `json:"country_code,omitempty"` + SortBy string `json:"sort_by,omitempty"` + MaxItems int `json:"max_items,omitempty"` + Period string `json:"period,omitempty"` +} + +func (t *Arguments) UnmarshalJSON(data []byte) error { + type Alias Arguments + aux := &struct{ *Alias }{Alias: (*Alias)(t)} + if err := json.Unmarshal(data, aux); err != nil { + return fmt.Errorf("%w: %w", ErrUnmarshalling, err) + } + t.SetDefaultValues() + return t.Validate() +} + +func (a *Arguments) SetDefaultValues() { + if a.CountryCode == "" { + a.CountryCode = "US" + } + if a.SortBy == "" { + a.SortBy = sortTrending + } + if a.Period == "" { + a.Period = periodWeek + } +} + +func (t *Arguments) GetCapability() 
types.Capability { + return t.Type +} + +func (t *Arguments) ValidateCapability(jobType types.JobType) error { + return jobType.ValidateCapability(&t.Type) +} + +func (t *Arguments) Validate() error { + err := t.ValidateCapability(types.TiktokJob) + if err != nil { + return err + } + allowedSorts := map[string]struct{}{ + sortTrending: {}, sortLike: {}, sortComment: {}, sortRepost: {}, + } + + allowedPeriods := map[string]struct{}{ + periodWeek: {}, + periodMonth: {}, + } + + allowedCountries := map[string]struct{}{ + "AU": {}, "BR": {}, "CA": {}, "EG": {}, "FR": {}, "DE": {}, "ID": {}, "IL": {}, "IT": {}, "JP": {}, + "MY": {}, "PH": {}, "RU": {}, "SA": {}, "SG": {}, "KR": {}, "ES": {}, "TW": {}, "TH": {}, "TR": {}, + "AE": {}, "GB": {}, "US": {}, "VN": {}, + } + + if _, ok := allowedCountries[strings.ToUpper(t.CountryCode)]; !ok { + return fmt.Errorf("%w: '%s'", ErrTrendingCountryCodeRequired, t.CountryCode) + } + if _, ok := allowedSorts[strings.ToLower(t.SortBy)]; !ok { + return fmt.Errorf("%w: '%s'", ErrTrendingSortByRequired, t.SortBy) + } + if _, ok := allowedPeriods[t.Period]; !ok { + // Extract keys for error message + var validKeys []string + for key := range allowedPeriods { + validKeys = append(validKeys, key) + } + return fmt.Errorf("%w: '%s' (allowed: %s)", ErrTrendingPeriodRequired, t.Period, strings.Join(validKeys, ", ")) + } + if t.MaxItems < 0 { + return fmt.Errorf("%w, got: %d", ErrTrendingMaxItemsNegative, t.MaxItems) + } + return nil +} + +func NewArguments() Arguments { + args := Arguments{ + Type: types.CapSearchByTrending, + } + args.SetDefaultValues() + return args +} diff --git a/api/args/tiktok/trending/trending_suite_test.go b/api/args/tiktok/trending/trending_suite_test.go new file mode 100644 index 00000000..29eee1ec --- /dev/null +++ b/api/args/tiktok/trending/trending_suite_test.go @@ -0,0 +1,13 @@ +package trending_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestArgs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Args Suite") +} diff --git a/api/args/tiktok/trending/trending_test.go b/api/args/tiktok/trending/trending_test.go new file mode 100644 index 00000000..cf95b07d --- /dev/null +++ b/api/args/tiktok/trending/trending_test.go @@ -0,0 +1,422 @@ +package trending_test + +import ( + "encoding/json" + "errors" + "strings" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/args/tiktok/trending" + "github.com/masa-finance/tee-worker/api/types" +) + +var _ = Describe("TikTokTrendingArguments", func() { + Describe("Marshalling and unmarshalling", func() { + It("should unmarshal valid arguments with all fields", func() { + var args trending.Arguments + jsonData := []byte(`{"type":"searchbytrending","country_code":"US","sort_by":"vv","max_items":50,"period":"7"}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.Type).To(Equal(types.CapSearchByTrending)) + Expect(args.CountryCode).To(Equal("US")) + Expect(args.SortBy).To(Equal("vv")) + Expect(args.MaxItems).To(Equal(50)) + Expect(args.Period).To(Equal("7")) + }) + + It("should unmarshal valid arguments with minimal fields", func() { + var args trending.Arguments + jsonData := []byte(`{"type":"searchbytrending"}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.Type).To(Equal(types.CapSearchByTrending)) + Expect(args.CountryCode).To(Equal("US")) // Default + Expect(args.SortBy).To(Equal("vv")) // Default + Expect(args.Period).To(Equal("7")) // Default + Expect(args.MaxItems).To(Equal(0)) // No default for MaxItems + }) + + It("should fail unmarshal with invalid JSON", func() { + var args trending.Arguments + jsonData := []byte(`{"type":"searchbytrending","country_code":"US"`) + err := json.Unmarshal(jsonData, &args) + Expect(err).To(HaveOccurred()) + }) + + It("should fail unmarshal with invalid country code", func() { + var args trending.Arguments + jsonData := []byte(`{"type":"searchbytrending","country_code":"INVALID"}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).To(HaveOccurred()) + Expect(strings.Contains(err.Error(), "country_code is required")).To(BeTrue()) + }) + + It("should fail unmarshal with invalid sort_by", func() { + var args trending.Arguments + jsonData := []byte(`{"type":"searchbytrending","sort_by":"invalid"}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).To(HaveOccurred()) + Expect(strings.Contains(err.Error(), "sort_by is required")).To(BeTrue()) + }) + + It("should fail unmarshal with invalid period", func() { + var args trending.Arguments + jsonData := []byte(`{"type":"searchbytrending","period":"invalid"}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).To(HaveOccurred()) + Expect(strings.Contains(err.Error(), "period is required")).To(BeTrue()) + }) + + It("should fail unmarshal with negative max_items", func() { + var args trending.Arguments + jsonData := []byte(`{"type":"searchbytrending","max_items":-1}`) + err := json.Unmarshal(jsonData, &args) + Expect(err).To(HaveOccurred()) + Expect(strings.Contains(err.Error(), "max_items must be non-negative")).To(BeTrue()) + }) + }) + + Describe("Validation", func() { + It("should succeed with valid arguments", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + MaxItems: 50, + Period: "7", + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail with invalid country code", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "INVALID", + SortBy: "vv", + MaxItems: 50, + Period: "7", + } + err := args.Validate() + Expect(errors.Is(err, trending.ErrTrendingCountryCodeRequired)).To(BeTrue()) + }) + + It("should fail with invalid sort_by", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + 
SortBy: "invalid", + MaxItems: 50, + Period: "7", + } + err := args.Validate() + Expect(errors.Is(err, trending.ErrTrendingSortByRequired)).To(BeTrue()) + }) + + It("should fail with invalid period", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + MaxItems: 50, + Period: "invalid", + } + err := args.Validate() + Expect(errors.Is(err, trending.ErrTrendingPeriodRequired)).To(BeTrue()) + }) + + It("should fail with negative max_items", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + MaxItems: -1, + Period: "7", + } + err := args.Validate() + Expect(errors.Is(err, trending.ErrTrendingMaxItemsNegative)).To(BeTrue()) + }) + }) + + Describe("Default values", func() { + It("should set default country_code when not provided", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + SortBy: "vv", + Period: "7", + } + args.SetDefaultValues() + Expect(args.CountryCode).To(Equal("US")) + }) + + It("should set default sort_by when not provided", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + Period: "7", + } + args.SetDefaultValues() + Expect(args.SortBy).To(Equal("vv")) + }) + + It("should set default period when not provided", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + } + args.SetDefaultValues() + Expect(args.Period).To(Equal("7")) + }) + + It("should not override existing values", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "CA", + SortBy: "like", + Period: "30", + } + args.SetDefaultValues() + Expect(args.CountryCode).To(Equal("CA")) + Expect(args.SortBy).To(Equal("like")) + Expect(args.Period).To(Equal("30")) + }) + }) + + Describe("Country code validation", func() { + It("should accept valid country codes", func() { + validCountries := []string{"US", "CA", "GB", "AU", "DE", "FR", "JP", "KR", "BR"} + for _, country := range validCountries { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: country, + SortBy: "vv", + Period: "7", + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred(), "Country %s should be valid", country) + } + }) + + It("should accept lowercase country codes", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "us", + SortBy: "vv", + Period: "7", + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should reject invalid country codes", func() { + invalidCountries := []string{"INVALID", "XX", "123", ""} + for _, country := range invalidCountries { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: country, + SortBy: "vv", + Period: "7", + } + err := args.Validate() + Expect(err).To(HaveOccurred(), "Country %s should be invalid", country) + } + }) + }) + + Describe("Sort by validation", func() { + It("should accept valid sort options", func() { + validSorts := []string{"vv", "like", "comment", "repost"} + for _, sort := range validSorts { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: sort, + Period: "7", + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred(), "Sort %s should be valid", sort) + } + }) + + It("should accept uppercase sort options", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "LIKE", + Period: "7", + } 
+ err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should reject invalid sort options", func() { + invalidSorts := []string{"invalid", "views", "likes", ""} + for _, sort := range invalidSorts { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: sort, + Period: "7", + } + err := args.Validate() + Expect(err).To(HaveOccurred(), "Sort %s should be invalid", sort) + } + }) + }) + + Describe("Period validation", func() { + It("should accept valid periods", func() { + validPeriods := []string{"7", "30"} + for _, period := range validPeriods { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + Period: period, + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred(), "Period %s should be valid", period) + } + }) + + It("should reject invalid periods", func() { + invalidPeriods := []string{"1", "14", "60", "invalid", ""} + for _, period := range invalidPeriods { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + Period: period, + } + err := args.Validate() + Expect(err).To(HaveOccurred(), "Period %s should be invalid", period) + } + }) + }) + + Describe("MaxItems validation", func() { + It("should accept zero max_items", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + MaxItems: 0, + Period: "7", + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should accept positive max_items", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + MaxItems: 100, + Period: "7", + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should reject negative max_items", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + MaxItems: -1, + Period: "7", + } + err := args.Validate() + Expect(errors.Is(err, trending.ErrTrendingMaxItemsNegative)).To(BeTrue()) + }) + }) + + Describe("Job capability", func() { + It("should return the searchbytrending capability", func() { + args := trending.NewArguments() + Expect(args.GetCapability()).To(Equal(types.CapSearchByTrending)) + }) + + It("should validate capability for TiktokJob", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + MaxItems: 50, + Period: "7", + } + err := args.ValidateCapability(types.TiktokJob) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail validation for incompatible job type", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "vv", + MaxItems: 50, + Period: "7", + } + err := args.ValidateCapability(types.TwitterJob) + Expect(err).To(HaveOccurred()) + }) + }) + + Describe("Edge cases", func() { + It("should handle mixed case country codes", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "us", + SortBy: "vv", + Period: "7", + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should handle mixed case sort options", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", + SortBy: "LIKE", + Period: "7", + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should handle large max_items values", func() { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: "US", 
+ SortBy: "vv", + MaxItems: 10000, + Period: "7", + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should handle all supported countries", func() { + supportedCountries := []string{ + "AU", "BR", "CA", "EG", "FR", "DE", "ID", "IL", "IT", "JP", + "MY", "PH", "RU", "SA", "SG", "KR", "ES", "TW", "TH", "TR", + "AE", "GB", "US", "VN", + } + for _, country := range supportedCountries { + args := &trending.Arguments{ + Type: types.CapSearchByTrending, + CountryCode: country, + SortBy: "vv", + Period: "7", + } + err := args.Validate() + Expect(err).ToNot(HaveOccurred(), "Country %s should be supported", country) + } + }) + }) +}) diff --git a/api/args/twitter/search/search.go b/api/args/twitter/search/search.go new file mode 100644 index 00000000..ef9643d0 --- /dev/null +++ b/api/args/twitter/search/search.go @@ -0,0 +1,125 @@ +package search + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/masa-finance/tee-worker/api/args/base" + "github.com/masa-finance/tee-worker/api/types" +) + +var ( + ErrCountNegative = errors.New("count must be non-negative") + ErrCountTooLarge = errors.New("count must be less than or equal to 1000") + ErrMaxResultsTooLarge = errors.New("max_results must be less than or equal to 1000") + ErrMaxResultsNegative = errors.New("max_results must be non-negative") + ErrUnmarshalling = errors.New("failed to unmarshal twitter search arguments") +) + +const ( + MaxResults = 1000 +) + +// Verify interface implementation +var _ base.JobArgument = (*Arguments)(nil) + +// Arguments defines args for Twitter searches +type Arguments struct { + Type types.Capability `json:"type"` + Query string `json:"query"` // Username or search query + Count int `json:"count"` + StartTime string `json:"start_time"` // Optional ISO timestamp + EndTime string `json:"end_time"` // Optional ISO timestamp + MaxResults int `json:"max_results"` // Optional, max number of results + NextCursor string `json:"next_cursor"` +} + +func (t *Arguments) UnmarshalJSON(data []byte) error { + type Alias Arguments + aux := &struct{ *Alias }{Alias: (*Alias)(t)} + if err := json.Unmarshal(data, aux); err != nil { + return fmt.Errorf("%w: %w", ErrUnmarshalling, err) + } + t.SetDefaultValues() + return t.Validate() +} + +// SetDefaultValues sets default values for the arguments +func (t *Arguments) SetDefaultValues() { + if t.MaxResults == 0 { + t.MaxResults = MaxResults + } +} + +// Validate validates the arguments (general validation) +func (t *Arguments) Validate() error { + // note, query is not required for all capabilities + err := t.ValidateCapability(types.TwitterJob) + if err != nil { + return err + } + if t.Count < 0 { + return fmt.Errorf("%w, got: %d", ErrCountNegative, t.Count) + } + if t.Count > MaxResults { + return fmt.Errorf("%w, got: %d", ErrCountTooLarge, t.Count) + } + if t.MaxResults < 0 { + return fmt.Errorf("%w, got: %d", ErrMaxResultsNegative, t.MaxResults) + } + if t.MaxResults > MaxResults { + return fmt.Errorf("%w, got: %d", ErrMaxResultsTooLarge, t.MaxResults) + } + + return nil +} + +func (t *Arguments) GetCapability() types.Capability { + return t.Type +} + +func (t *Arguments) ValidateCapability(jobType types.JobType) error { + return jobType.ValidateCapability(&t.Type) +} + +func (t *Arguments) IsSingleTweetOperation() bool { + return t.GetCapability() == types.CapGetById +} + +func (t *Arguments) IsMultipleTweetOperation() bool { + c := t.GetCapability() + return c == types.CapSearchByQuery || + c == types.CapSearchByFullArchive || + c == 
types.CapGetTweets || + c == types.CapGetReplies || + c == types.CapGetMedia +} + +func (t *Arguments) IsSingleProfileOperation() bool { + c := t.GetCapability() + return c == types.CapGetProfileById || + c == types.CapSearchByProfile +} + +func (t *Arguments) IsMultipleProfileOperation() bool { + c := t.GetCapability() + return c == types.CapGetFollowing || + c == types.CapGetFollowers || + c == types.CapGetRetweeters +} + +func (t *Arguments) IsSingleSpaceOperation() bool { + return t.GetCapability() == types.CapGetSpace +} + +func (t *Arguments) IsTrendsOperation() bool { + return t.GetCapability() == types.CapGetTrends +} + +func NewArguments() Arguments { + args := Arguments{} + args.SetDefaultValues() + args.Validate() // This will set the default capability via ValidateCapability + return args +} diff --git a/api/args/twitter/search/search_suite_test.go b/api/args/twitter/search/search_suite_test.go new file mode 100644 index 00000000..688d8ec8 --- /dev/null +++ b/api/args/twitter/search/search_suite_test.go @@ -0,0 +1,13 @@ +package search_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestArgs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Args Suite") +} diff --git a/api/args/twitter/search/search_test.go b/api/args/twitter/search/search_test.go new file mode 100644 index 00000000..2be7f9bb --- /dev/null +++ b/api/args/twitter/search/search_test.go @@ -0,0 +1,297 @@ +package search_test + +import ( + "encoding/json" + "errors" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/args/twitter/search" + "github.com/masa-finance/tee-worker/api/types" +) + +var _ = Describe("TwitterSearchArguments", func() { + Describe("Marshalling and unmarshalling", func() { + It("should unmarshal valid arguments with all fields", func() { + var args search.Arguments + jsonData := []byte(`{ + "type": "searchbyquery", + "query": "test query", + "count": 50, + "start_time": "2023-01-01T00:00:00Z", + "end_time": "2023-12-31T23:59:59Z", + "max_results": 100, + "next_cursor": "cursor123" + }`) + err := json.Unmarshal(jsonData, &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.Query).To(Equal("test query")) + Expect(args.Count).To(Equal(50)) + Expect(args.StartTime).To(Equal("2023-01-01T00:00:00Z")) + Expect(args.EndTime).To(Equal("2023-12-31T23:59:59Z")) + Expect(args.MaxResults).To(Equal(100)) + Expect(args.NextCursor).To(Equal("cursor123")) + }) + + It("should unmarshal valid arguments with minimal fields", func() { + var args search.Arguments + jsonData := []byte(`{ + "type": "searchbyquery", + "query": "minimal test" + }`) + err := json.Unmarshal(jsonData, &args) + Expect(err).ToNot(HaveOccurred()) + Expect(args.Query).To(Equal("minimal test")) + Expect(args.Count).To(Equal(0)) + Expect(args.MaxResults).To(Equal(1000)) // SetDefaultValues() sets this to MaxResults + }) + + It("should fail unmarshal with invalid JSON", func() { + var args search.Arguments + jsonData := []byte(`{"type":"searchbyquery","query":"test"`) + err := json.Unmarshal(jsonData, &args) + Expect(err).To(HaveOccurred()) + // The error is a JSON syntax error, not wrapped with ErrUnmarshalling + // since the JSON is malformed before reaching the custom UnmarshalJSON method + }) + + It("should set default values after unmarshalling", func() { + var args search.Arguments + jsonData := []byte(`{"type":"searchbyquery","query":"test"}`) + err := json.Unmarshal(jsonData, &args) + 
Expect(err).ToNot(HaveOccurred()) + // Default values should be set by SetDefaultValues() + Expect(args.GetCapability()).To(Equal(types.CapSearchByQuery)) + }) + }) + + Describe("Validation", func() { + It("should succeed with valid arguments", func() { + args := search.NewArguments() + args.Query = "test query" + args.Count = 50 + args.MaxResults = 100 + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail when count is negative", func() { + args := search.NewArguments() + args.Query = "test query" + args.Count = -1 + err := args.Validate() + Expect(errors.Is(err, search.ErrCountNegative)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("got: -1")) + }) + + It("should fail when count exceeds maximum", func() { + args := search.NewArguments() + args.Query = "test query" + args.Count = 1001 + err := args.Validate() + Expect(errors.Is(err, search.ErrCountTooLarge)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("got: 1001")) + }) + + It("should fail when max_results is negative", func() { + args := search.NewArguments() + args.Query = "test query" + args.MaxResults = -1 + err := args.Validate() + Expect(errors.Is(err, search.ErrMaxResultsNegative)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("got: -1")) + }) + + It("should fail when max_results exceeds maximum", func() { + args := search.NewArguments() + args.Query = "test query" + args.MaxResults = 1001 + err := args.Validate() + Expect(errors.Is(err, search.ErrMaxResultsTooLarge)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("got: 1001")) + }) + + It("should succeed with count at maximum boundary", func() { + args := search.NewArguments() + args.Query = "test query" + args.Count = 1000 + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should succeed with max_results at maximum boundary", func() { + args := search.NewArguments() + args.Query = "test query" + args.MaxResults = 1000 + err := args.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Describe("Operation Type Detection", func() { + Context("Single Tweet Operations", func() { + It("should identify getbyid as single tweet operation", func() { + args := search.NewArguments() + args.Type = types.CapGetById + Expect(args.IsSingleTweetOperation()).To(BeTrue()) + }) + + It("should not identify searchbyquery as single tweet operation", func() { + args := search.NewArguments() + // Type is already CapSearchByQuery from NewArguments() + Expect(args.IsSingleTweetOperation()).To(BeFalse()) + }) + }) + + Context("Multiple Tweet Operations", func() { + It("should identify searchbyquery as multiple tweet operation", func() { + args := search.NewArguments() + // Type is already CapSearchByQuery from NewArguments() + Expect(args.IsMultipleTweetOperation()).To(BeTrue()) + }) + + It("should identify searchbyfullarchive as multiple tweet operation", func() { + args := search.NewArguments() + args.Type = types.CapSearchByFullArchive + Expect(args.IsMultipleTweetOperation()).To(BeTrue()) + }) + + It("should identify gettweets as multiple tweet operation", func() { + args := search.NewArguments() + args.Type = types.CapGetTweets + Expect(args.IsMultipleTweetOperation()).To(BeTrue()) + }) + + It("should identify getreplies as multiple tweet operation", func() { + args := search.NewArguments() + args.Type = types.CapGetReplies + Expect(args.IsMultipleTweetOperation()).To(BeTrue()) + }) + + It("should identify getmedia as multiple tweet operation", func() { + args := search.NewArguments() + args.Type = 
types.CapGetMedia + Expect(args.IsMultipleTweetOperation()).To(BeTrue()) + }) + + It("should not identify getbyid as multiple tweet operation", func() { + args := search.NewArguments() + args.Type = types.CapGetById + Expect(args.IsMultipleTweetOperation()).To(BeFalse()) + }) + }) + + Context("Single Profile Operations", func() { + It("should identify getprofilebyid as single profile operation", func() { + args := search.NewArguments() + args.Type = types.CapGetProfileById + Expect(args.IsSingleProfileOperation()).To(BeTrue()) + }) + + It("should identify searchbyprofile as single profile operation", func() { + args := search.NewArguments() + args.Type = types.CapSearchByProfile + Expect(args.IsSingleProfileOperation()).To(BeTrue()) + }) + + It("should not identify getfollowers as single profile operation", func() { + args := search.NewArguments() + args.Type = types.CapGetFollowers + Expect(args.IsSingleProfileOperation()).To(BeFalse()) + }) + }) + + Context("Multiple Profile Operations", func() { + It("should identify getfollowing as multiple profile operation", func() { + args := search.NewArguments() + args.Type = types.CapGetFollowing + Expect(args.IsMultipleProfileOperation()).To(BeTrue()) + }) + + It("should identify getfollowers as multiple profile operation", func() { + args := search.NewArguments() + args.Type = types.CapGetFollowers + Expect(args.IsMultipleProfileOperation()).To(BeTrue()) + }) + + It("should identify getretweeters as multiple profile operation", func() { + args := search.NewArguments() + args.Type = types.CapGetRetweeters + Expect(args.IsMultipleProfileOperation()).To(BeTrue()) + }) + + It("should not identify getprofilebyid as multiple profile operation", func() { + args := search.NewArguments() + args.Type = types.CapGetProfileById + Expect(args.IsMultipleProfileOperation()).To(BeFalse()) + }) + }) + + Context("Single Space Operations", func() { + It("should identify getspace as single space operation", func() { + args := search.NewArguments() + args.Type = types.CapGetSpace + Expect(args.IsSingleSpaceOperation()).To(BeTrue()) + }) + + It("should not identify searchbyquery as single space operation", func() { + args := search.NewArguments() + // Type is already CapSearchByQuery from NewArguments() + Expect(args.IsSingleSpaceOperation()).To(BeFalse()) + }) + }) + + Context("Trends Operations", func() { + It("should identify gettrends as trends operation", func() { + args := search.NewArguments() + args.Type = types.CapGetTrends + Expect(args.IsTrendsOperation()).To(BeTrue()) + }) + + It("should not identify searchbyquery as trends operation", func() { + args := search.NewArguments() + // Type is already CapSearchByQuery from NewArguments() + Expect(args.IsTrendsOperation()).To(BeFalse()) + }) + }) + }) + + Describe("Constants and Error Values", func() { + It("should have correct MaxResults constant", func() { + Expect(search.MaxResults).To(Equal(1000)) + }) + + It("should have correct error messages", func() { + Expect(search.ErrCountNegative.Error()).To(Equal("count must be non-negative")) + Expect(search.ErrCountTooLarge.Error()).To(Equal("count must be less than or equal to 1000")) + Expect(search.ErrMaxResultsTooLarge.Error()).To(Equal("max_results must be less than or equal to 1000")) + Expect(search.ErrMaxResultsNegative.Error()).To(Equal("max_results must be non-negative")) + Expect(search.ErrUnmarshalling.Error()).To(Equal("failed to unmarshal twitter search arguments")) + }) + }) + + Describe("JSON Marshalling", func() { + It("should marshal arguments 
correctly", func() { + args := search.NewArguments() + args.Query = "test query" + args.Count = 50 + args.StartTime = "2023-01-01T00:00:00Z" + args.EndTime = "2023-12-31T23:59:59Z" + args.MaxResults = 100 + args.NextCursor = "cursor123" + jsonData, err := json.Marshal(args) + Expect(err).ToNot(HaveOccurred()) + + var unmarshalled search.Arguments + err = json.Unmarshal(jsonData, &unmarshalled) + Expect(err).ToNot(HaveOccurred()) + Expect(unmarshalled.Query).To(Equal(args.Query)) + Expect(unmarshalled.Count).To(Equal(args.Count)) + Expect(unmarshalled.StartTime).To(Equal(args.StartTime)) + Expect(unmarshalled.EndTime).To(Equal(args.EndTime)) + Expect(unmarshalled.MaxResults).To(Equal(args.MaxResults)) + Expect(unmarshalled.NextCursor).To(Equal(args.NextCursor)) + }) + }) +}) diff --git a/api/args/twitter/twitter.go b/api/args/twitter/twitter.go new file mode 100644 index 00000000..819c7da5 --- /dev/null +++ b/api/args/twitter/twitter.go @@ -0,0 +1,7 @@ +package twitter + +import ( + "github.com/masa-finance/tee-worker/api/args/twitter/search" +) + +type Search = search.Arguments diff --git a/api/args/unmarshaller.go b/api/args/unmarshaller.go new file mode 100644 index 00000000..a8bf65fd --- /dev/null +++ b/api/args/unmarshaller.go @@ -0,0 +1,138 @@ +package args + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/masa-finance/tee-worker/api/args/base" + "github.com/masa-finance/tee-worker/api/args/linkedin" + "github.com/masa-finance/tee-worker/api/args/reddit" + "github.com/masa-finance/tee-worker/api/args/telemetry" + "github.com/masa-finance/tee-worker/api/args/tiktok" + "github.com/masa-finance/tee-worker/api/args/twitter" + "github.com/masa-finance/tee-worker/api/args/web" + "github.com/masa-finance/tee-worker/api/types" +) + +var ( + ErrUnknownJobType = errors.New("unknown job type") + ErrUnknownCapability = errors.New("unknown capability") + ErrFailedToUnmarshal = errors.New("failed to unmarshal job arguments") + ErrFailedToMarshal = errors.New("failed to marshal job arguments") +) + +type Args = map[string]any + +// UnmarshalJobArguments unmarshals job arguments from a generic map into the appropriate typed struct +// This works with both tee-indexer and tee-worker JobArgument types +func UnmarshalJobArguments(jobType types.JobType, args Args) (base.JobArgument, error) { + switch jobType { + case types.WebJob: + return unmarshalWebArguments(args) + + case types.TiktokJob: + return unmarshalTikTokArguments(args) + + case types.TwitterJob: + return unmarshalTwitterArguments(args) + + case types.LinkedInJob: + return unmarshalLinkedInArguments(args) + + case types.RedditJob: + return unmarshalRedditArguments(args) + + case types.TelemetryJob: + return unmarshalTelemetryArguments(args) + + default: + return nil, fmt.Errorf("%w: %s", ErrUnknownJobType, jobType) + } +} + +// Helper functions for unmarshaling specific argument types +func unmarshalWebArguments(args Args) (*web.Page, error) { + webArgs := &web.Page{} + if err := unmarshalToStruct(args, webArgs); err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToUnmarshal, err) + } + return webArgs, nil +} + +func unmarshalTikTokArguments(args Args) (base.JobArgument, error) { + minimal := base.Arguments{} + if err := unmarshalToStruct(args, &minimal); err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToUnmarshal, err) + } + switch minimal.Type { + case types.CapSearchByQuery: + searchArgs := &tiktok.Query{} + if err := unmarshalToStruct(args, searchArgs); err != nil { + return nil, fmt.Errorf("%w: %w", 
ErrFailedToUnmarshal, err) + } + return searchArgs, nil + case types.CapSearchByTrending: + searchArgs := &tiktok.Trending{} + if err := unmarshalToStruct(args, searchArgs); err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToUnmarshal, err) + } + return searchArgs, nil + case types.CapTranscription: + transcriptionArgs := &tiktok.Transcription{} + if err := unmarshalToStruct(args, transcriptionArgs); err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToUnmarshal, err) + } + return transcriptionArgs, nil + default: + return nil, fmt.Errorf("%w: %s", ErrUnknownCapability, minimal.Type) + } +} + +func unmarshalTwitterArguments(args Args) (*twitter.Search, error) { + twitterArgs := &twitter.Search{} + if err := unmarshalToStruct(args, twitterArgs); err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToUnmarshal, err) + } + return twitterArgs, nil +} + +func unmarshalLinkedInArguments(args Args) (*linkedin.Profile, error) { + linkedInArgs := &linkedin.Profile{} + if err := unmarshalToStruct(args, linkedInArgs); err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToUnmarshal, err) + } + return linkedInArgs, nil +} + +func unmarshalRedditArguments(args Args) (*reddit.Search, error) { + redditArgs := &reddit.Search{} + if err := unmarshalToStruct(args, redditArgs); err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToUnmarshal, err) + } + return redditArgs, nil +} + +func unmarshalTelemetryArguments(args Args) (*telemetry.Telemetry, error) { + telemetryArgs := &telemetry.Telemetry{} + if err := unmarshalToStruct(args, telemetryArgs); err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToUnmarshal, err) + } + return telemetryArgs, nil +} + +// unmarshalToStruct converts a map[string]any to a struct using JSON marshal/unmarshal +// This provides the same functionality as the existing JobArgument.Unmarshal methods +func unmarshalToStruct(args Args, target any) error { + // Use JSON marshal/unmarshal for conversion - this triggers our custom UnmarshalJSON methods + data, err := json.Marshal(args) + if err != nil { + return fmt.Errorf("%w: %w", ErrFailedToMarshal, err) + } + + if err := json.Unmarshal(data, target); err != nil { + return fmt.Errorf("%w: %w", ErrFailedToUnmarshal, err) + } + + return nil +} diff --git a/api/args/unmarshaller_test.go b/api/args/unmarshaller_test.go new file mode 100644 index 00000000..b250cc72 --- /dev/null +++ b/api/args/unmarshaller_test.go @@ -0,0 +1,100 @@ +package args_test + +import ( + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/args" + "github.com/masa-finance/tee-worker/api/args/reddit" + "github.com/masa-finance/tee-worker/api/args/telemetry" + "github.com/masa-finance/tee-worker/api/args/tiktok" + "github.com/masa-finance/tee-worker/api/args/twitter" + "github.com/masa-finance/tee-worker/api/args/web" + "github.com/masa-finance/tee-worker/api/types" +) + +var _ = Describe("Unmarshaller", func() { + Describe("UnmarshalJobArguments", func() { + Context("with a WebJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "url": "https://example.com", + "max_depth": 2, + } + jobArgs, err := args.UnmarshalJobArguments(types.WebJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + webArgs, ok := jobArgs.(*web.Page) + Expect(ok).To(BeTrue()) + Expect(webArgs.URL).To(Equal("https://example.com")) + Expect(webArgs.MaxDepth).To(Equal(2)) + }) + }) + + Context("with a TiktokJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "type": "transcription", + "video_url": "https://www.tiktok.com/@user/video/123", + "language": "en-us", + } + jobArgs, err := args.UnmarshalJobArguments(types.TiktokJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + tiktokArgs, ok := jobArgs.(*tiktok.Transcription) + Expect(ok).To(BeTrue()) + Expect(tiktokArgs.VideoURL).To(Equal("https://www.tiktok.com/@user/video/123")) + Expect(tiktokArgs.Language).To(Equal("en-us")) + }) + }) + + Context("with a TwitterJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "type": "searchbyquery", + "query": "golang", + "count": 10, + } + jobArgs, err := args.UnmarshalJobArguments(types.TwitterJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + twitterArgs, ok := jobArgs.(*twitter.Search) + Expect(ok).To(BeTrue()) + Expect(twitterArgs.Type).To(Equal(types.CapSearchByQuery)) + Expect(twitterArgs.Query).To(Equal("golang")) + Expect(twitterArgs.Count).To(Equal(10)) + }) + }) + + Context("with a RedditJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "type": "searchposts", + "queries": []string{"golang"}, + "sort": "new", + } + jobArgs, err := args.UnmarshalJobArguments(types.RedditJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + redditArgs, ok := jobArgs.(*reddit.Search) + Expect(ok).To(BeTrue()) + Expect(redditArgs.Type).To(Equal(types.CapSearchPosts)) + }) + }) + + Context("with a TelemetryJob", func() { + It("should return a TelemetryArguments struct", func() { + argsMap := map[string]any{} + jobArgs, err := args.UnmarshalJobArguments(types.TelemetryJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + _, ok := jobArgs.(*telemetry.Arguments) + Expect(ok).To(BeTrue()) + }) + }) + + Context("with an unknown job type", func() { + It("should return an error", func() { + argsMap := map[string]any{} + _, err := args.UnmarshalJobArguments("unknown", argsMap) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("unknown job type")) + }) + }) + }) +}) diff --git a/api/args/web/page/page.go b/api/args/web/page/page.go new file mode 100644 index 00000000..8356a285 --- /dev/null +++ b/api/args/web/page/page.go @@ -0,0 +1,113 @@ +package page + +import ( + "encoding/json" + "errors" + "fmt" + "net/url" + + "github.com/masa-finance/tee-worker/api/args/base" + "github.com/masa-finance/tee-worker/api/types" +) + +var ( + ErrURLRequired = errors.New("url is required") + ErrURLInvalid = 
errors.New("invalid URL format") + ErrURLSchemeMissing = errors.New("url must include a scheme (http:// or https://)") + ErrMaxDepth = errors.New("max depth must be non-negative") + ErrMaxPages = errors.New("max pages must be at least 1") + ErrUnmarshalling = errors.New("failed to unmarshal web page arguments") +) + +const ( + DefaultMaxPages = 1 + DefaultMethod = "GET" + DefaultRespectRobotsTxtFile = false + DefaultSaveMarkdown = true +) + +// Verify interface implementation +var _ base.JobArgument = (*Arguments)(nil) + +type Arguments struct { + Type types.Capability `json:"type"` + URL string `json:"url"` + MaxDepth int `json:"max_depth"` + MaxPages int `json:"max_pages"` +} + +func (t *Arguments) UnmarshalJSON(data []byte) error { + type Alias Arguments + aux := &struct{ *Alias }{Alias: (*Alias)(t)} + if err := json.Unmarshal(data, aux); err != nil { + return fmt.Errorf("%w: %w", ErrUnmarshalling, err) + } + t.SetDefaultValues() + return t.Validate() +} + +func (w *Arguments) SetDefaultValues() { + if w.MaxPages == 0 { + w.MaxPages = DefaultMaxPages + } +} + +// Validate validates the arguments +func (w *Arguments) Validate() error { + err := w.ValidateCapability(types.WebJob) + if err != nil { + return err + } + + if w.URL == "" { + return ErrURLRequired + } + + // Validate URL format + parsedURL, err := url.Parse(w.URL) + if err != nil { + return fmt.Errorf("%w: %v", ErrURLInvalid, err) + } + + // Ensure URL has a scheme + if parsedURL.Scheme == "" { + return ErrURLSchemeMissing + } + + if w.MaxDepth < 0 { + return fmt.Errorf("%w: got %v", ErrMaxDepth, w.MaxDepth) + } + + if w.MaxPages < 1 { + return fmt.Errorf("%w: got %v", ErrMaxPages, w.MaxPages) + } + + return nil +} + +func (w *Arguments) GetCapability() types.Capability { + return w.Type +} + +func (w *Arguments) ValidateCapability(jobType types.JobType) error { + return jobType.ValidateCapability(&w.Type) +} + +func (w Arguments) ToScraperRequest() types.WebScraperRequest { + return types.WebScraperRequest{ + StartUrls: []types.WebStartURL{ + {URL: w.URL, Method: DefaultMethod}, + }, + MaxCrawlDepth: w.MaxDepth, + MaxCrawlPages: w.MaxPages, + RespectRobotsTxtFile: DefaultRespectRobotsTxtFile, + SaveMarkdown: DefaultSaveMarkdown, + } +} + +func NewArguments() Arguments { + args := Arguments{} + args.SetDefaultValues() + args.Validate() // This will set the default capability via ValidateCapability + return args +} diff --git a/api/args/web/page/page_suite_test.go b/api/args/web/page/page_suite_test.go new file mode 100644 index 00000000..8da7b908 --- /dev/null +++ b/api/args/web/page/page_suite_test.go @@ -0,0 +1,13 @@ +package page_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestArgs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Args Suite") +} diff --git a/api/args/web/page/page_test.go b/api/args/web/page/page_test.go new file mode 100644 index 00000000..523c647d --- /dev/null +++ b/api/args/web/page/page_test.go @@ -0,0 +1,138 @@ +package page_test + +import ( + "encoding/json" + "errors" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/args/web/page" + "github.com/masa-finance/tee-worker/api/types" +) + +var _ = Describe("WebArguments", func() { + Describe("Marshalling and unmarshalling", func() { + It("should set default values", func() { + webArgs := page.NewArguments() + webArgs.URL = "https://example.com" + webArgs.MaxDepth = 0 + webArgs.MaxPages = 0 + jsonData, err := json.Marshal(webArgs) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal([]byte(jsonData), &webArgs) + Expect(err).ToNot(HaveOccurred()) + Expect(webArgs.MaxPages).To(Equal(1)) + }) + + It("should override default values", func() { + webArgs := page.NewArguments() + webArgs.URL = "https://example.com" + webArgs.MaxDepth = 2 + webArgs.MaxPages = 5 + jsonData, err := json.Marshal(webArgs) + Expect(err).ToNot(HaveOccurred()) + err = json.Unmarshal([]byte(jsonData), &webArgs) + Expect(err).ToNot(HaveOccurred()) + Expect(webArgs.MaxPages).To(Equal(5)) + }) + + It("should fail unmarshal when url is missing", func() { + var webArgs page.Arguments + jsonData := []byte(`{"type":"scraper","max_depth":1,"max_pages":1}`) + err := json.Unmarshal(jsonData, &webArgs) + Expect(errors.Is(err, page.ErrURLRequired)).To(BeTrue()) + }) + }) + + Describe("Validation", func() { + It("should succeed with valid arguments", func() { + webArgs := page.NewArguments() + webArgs.URL = "https://example.com" + webArgs.MaxDepth = 2 + webArgs.MaxPages = 3 + err := webArgs.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail when url is missing", func() { + webArgs := page.NewArguments() + webArgs.MaxDepth = 0 + webArgs.MaxPages = 1 + err := webArgs.Validate() + Expect(errors.Is(err, page.ErrURLRequired)).To(BeTrue()) + }) + + It("should fail with an invalid URL format", func() { + webArgs := page.NewArguments() + webArgs.URL = "http:// invalid.com" + webArgs.MaxDepth = 0 + webArgs.MaxPages = 1 + err := webArgs.Validate() + Expect(errors.Is(err, page.ErrURLInvalid)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("invalid URL format")) + }) + + It("should fail when scheme is missing", func() { + webArgs := page.NewArguments() + webArgs.URL = "example.com" + webArgs.MaxDepth = 0 + webArgs.MaxPages = 1 + err := webArgs.Validate() + Expect(errors.Is(err, page.ErrURLSchemeMissing)).To(BeTrue()) + }) + + It("should fail when max depth is negative", func() { + webArgs := page.NewArguments() + webArgs.URL = "https://example.com" + webArgs.MaxDepth = -1 + webArgs.MaxPages = 1 + err := webArgs.Validate() + Expect(errors.Is(err, page.ErrMaxDepth)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("got -1")) + }) + + It("should fail when max pages is less than 1", func() { + webArgs := page.NewArguments() + webArgs.URL = "https://example.com" + webArgs.MaxDepth = 0 + webArgs.MaxPages = 0 + err := webArgs.Validate() + Expect(errors.Is(err, page.ErrMaxPages)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("got 0")) + }) + }) + + Describe("Job capability", func() { + It("should return the scraper capability", func() { + webArgs := page.NewArguments() + Expect(webArgs.GetCapability()).To(Equal(types.CapScraper)) + }) + + It("should validate capability for WebJob", func() { + webArgs := page.NewArguments() + webArgs.URL = "https://example.com" + webArgs.MaxDepth = 1 + webArgs.MaxPages = 1 + err := webArgs.ValidateCapability(types.WebJob) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Describe("ToWebScraperRequest", func() { + It("should map fields correctly", func() { + webArgs := 
page.NewArguments() + webArgs.URL = "https://example.com" + webArgs.MaxDepth = 2 + webArgs.MaxPages = 3 + req := webArgs.ToScraperRequest() + Expect(req.StartUrls).To(HaveLen(1)) + Expect(req.StartUrls[0].URL).To(Equal("https://example.com")) + Expect(req.StartUrls[0].Method).To(Equal("GET")) + Expect(req.MaxCrawlDepth).To(Equal(2)) + Expect(req.MaxCrawlPages).To(Equal(3)) + Expect(req.RespectRobotsTxtFile).To(BeFalse()) + Expect(req.SaveMarkdown).To(BeTrue()) + }) + }) +}) diff --git a/api/args/web/web.go b/api/args/web/web.go new file mode 100644 index 00000000..5f897201 --- /dev/null +++ b/api/args/web/web.go @@ -0,0 +1,7 @@ +package web + +import ( + "github.com/masa-finance/tee-worker/api/args/web/page" +) + +type Page = page.Arguments diff --git a/api/types/encrypted.go b/api/tee/encrypted.go similarity index 79% rename from api/types/encrypted.go rename to api/tee/encrypted.go index 1d5f36ad..aab4ecb1 100644 --- a/api/types/encrypted.go +++ b/api/tee/encrypted.go @@ -1,24 +1,27 @@ -package types +package tee import ( "encoding/json" "fmt" + "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/pkg/tee" ) +// EncryptedRequest represents an encrypted request/response pair type EncryptedRequest struct { EncryptedResult string `json:"encrypted_result"` EncryptedRequest string `json:"encrypted_request"` } +// Unseal decrypts the encrypted request and result func (payload EncryptedRequest) Unseal() (string, error) { jobRequest, err := tee.Unseal(payload.EncryptedRequest) if err != nil { return "", fmt.Errorf("error while unsealing the encrypted request: %w", err) } - job := Job{} + job := types.Job{} if err := json.Unmarshal(jobRequest, &job); err != nil { return "", fmt.Errorf("error while unmarshalling the job request: %w", err) } @@ -30,7 +33,3 @@ func (payload EncryptedRequest) Unseal() (string, error) { return string(dat), nil } - -type JobError struct { - Error string `json:"error"` -} diff --git a/api/tee/job.go b/api/tee/job.go new file mode 100644 index 00000000..21db0077 --- /dev/null +++ b/api/tee/job.go @@ -0,0 +1,62 @@ +package tee + +import ( + "crypto/sha256" + "encoding/json" + "fmt" + "math/rand/v2" + + "github.com/masa-finance/tee-worker/api/types" + "github.com/masa-finance/tee-worker/pkg/tee" +) + +var letterRunes = []rune("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+") + +func randStringRunes(n int) string { + b := make([]rune, n) + for i := range b { + // TODO: Move xcrypt from indexer to tee-types, and use RandomString here (although we'll need a different alphabet) + b[i] = letterRunes[rand.IntN(len(letterRunes))] + } + return string(b) +} + +// GenerateJobSignature generates a signature for the job. +func GenerateJobSignature(job *types.Job) (string, error) { + dat, err := json.Marshal(job) + if err != nil { + return "", err + } + + checksum := sha256.New() + checksum.Write(dat) + + job.Nonce = fmt.Sprintf("%s-%s", string(checksum.Sum(nil)), randStringRunes(99)) + + dat, err = json.Marshal(job) + if err != nil { + return "", err + } + + return tee.Seal(dat) +} + +// SealJobResult seals a job result with the job's nonce. +func SealJobResult(jr *types.JobResult) (string, error) { + return tee.SealWithKey(jr.Job.Nonce, jr.Data) +} + +// DecryptJob decrypts the job request.
+func DecryptJob(jobRequest *types.JobRequest) (*types.Job, error) { + dat, err := tee.Unseal(jobRequest.EncryptedJob) + if err != nil { + return nil, err + } + + job := types.Job{} + if err := json.Unmarshal(dat, &job); err != nil { + return nil, err + } + + return &job, nil +} diff --git a/api/types/job.go b/api/types/job.go deleted file mode 100644 index 9010c3b5..00000000 --- a/api/types/job.go +++ /dev/null @@ -1,113 +0,0 @@ -package types - -import ( - "crypto/sha256" - "encoding/json" - "fmt" - "time" - - teetypes "github.com/masa-finance/tee-types/types" - "github.com/masa-finance/tee-worker/pkg/tee" - "golang.org/x/exp/rand" -) - -type JobArguments map[string]interface{} - -func (ja JobArguments) Unmarshal(i interface{}) error { - dat, err := json.Marshal(ja) - if err != nil { - return err - } - return json.Unmarshal(dat, i) -} - -type Job struct { - Type teetypes.JobType `json:"type"` - Arguments JobArguments `json:"arguments"` - UUID string `json:"-"` - Nonce string `json:"quote"` - WorkerID string `json:"worker_id"` - TargetWorker string `json:"target_worker"` - Timeout time.Duration `json:"timeout"` -} - -func (j Job) String() string { - return fmt.Sprintf("UUID: %s Type: %s Arguments: %s", j.UUID, j.Type, j.Arguments) -} - -var letterRunes = []rune("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+") - -func randStringRunes(n int) string { - b := make([]rune, n) - for i := range b { - // TODO: Move xcrypt from indexer to tee-types, and use RandomString here (although we'll need a different alpahbet) - b[i] = letterRunes[rand.Intn(len(letterRunes))] - } - return string(b) -} - -// GenerateJobSignature generates a signature for the job. -func (job *Job) GenerateJobSignature() (string, error) { - - dat, err := json.Marshal(job) - if err != nil { - return "", err - } - - checksum := sha256.New() - checksum.Write(dat) - - job.Nonce = fmt.Sprintf("%s-%s", string(checksum.Sum(nil)), randStringRunes(99)) - - dat, err = json.Marshal(job) - if err != nil { - return "", err - } - - return tee.Seal(dat) -} - -type JobResponse struct { - UID string `json:"uid"` -} - -type JobResult struct { - Error string `json:"error"` - Data []byte `json:"data"` - Job Job `json:"job"` - NextCursor string `json:"next_cursor"` -} - -// Success returns true if the job was successful. -func (jr JobResult) Success() bool { - return jr.Error == "" -} - -// Seal returns the sealed job result. -func (jr JobResult) Seal() (string, error) { - return tee.SealWithKey(jr.Job.Nonce, jr.Data) -} - -// Unmarshal unmarshals the job result data. -func (jr JobResult) Unmarshal(i interface{}) error { - return json.Unmarshal(jr.Data, i) -} - -type JobRequest struct { - EncryptedJob string `json:"encrypted_job"` -} - -// DecryptJob decrypts the job request. 
-func (jobRequest JobRequest) DecryptJob() (*Job, error) { - dat, err := tee.Unseal(jobRequest.EncryptedJob) - if err != nil { - return nil, err - } - - job := Job{} - if err := json.Unmarshal(dat, &job); err != nil { - return nil, err - } - - return &job, nil -} diff --git a/api/types/jobs.go b/api/types/jobs.go new file mode 100644 index 00000000..c28dff77 --- /dev/null +++ b/api/types/jobs.go @@ -0,0 +1,235 @@ +package types + +import ( + "encoding/json" + "fmt" + "slices" + "time" + + "github.com/masa-finance/tee-worker/pkg/util" +) + +type JobType string +type Capability string + +type WorkerCapabilities map[JobType][]Capability + +type JobArguments map[string]any + +func (ja JobArguments) Unmarshal(i any) error { + dat, err := json.Marshal(ja) + if err != nil { + return err + } + return json.Unmarshal(dat, i) +} + +type Job struct { + Type JobType `json:"type"` + Arguments JobArguments `json:"arguments"` + UUID string `json:"-"` + Nonce string `json:"quote"` + WorkerID string `json:"worker_id"` + TargetWorker string `json:"target_worker"` + Timeout time.Duration `json:"timeout"` +} + +func (j Job) String() string { + return fmt.Sprintf("UUID: %s Type: %s Arguments: %s", j.UUID, j.Type, j.Arguments) +} + +// String returns the string representation of the JobType +func (j JobType) String() string { + return string(j) +} + +// ValidateCapability validates that a capability is supported for this job type +// If the capability is CapEmpty, it will be set to the default capability for the job type +func (j JobType) ValidateCapability(capability *Capability) error { + // Set default capability if empty + if *capability == CapEmpty { + defaultCap, exists := JobDefaultCapabilityMap[j] + if !exists { + return fmt.Errorf("no default capability configured for job type: %s", j) + } + *capability = defaultCap + } + + // Validate the capability + validCaps, exists := JobCapabilityMap[j] + if !exists { + return fmt.Errorf("unknown job type: %s", j) + } + + if !slices.Contains(validCaps, *capability) { + return fmt.Errorf("capability '%s' is not valid for job type '%s'. valid capabilities: %v", + *capability, j, validCaps) + } + + return nil +} + +// combineCapabilities combines multiple capability slices and ensures uniqueness +func combineCapabilities(capSlices ...[]Capability) []Capability { + caps := util.NewSet[Capability]() + for _, capSlice := range capSlices { + caps.Add(capSlice...) 
+ } + return caps.Items() +} + +// Job type constants - centralized from tee-indexer and tee-worker +const ( + WebJob JobType = "web" + TelemetryJob JobType = "telemetry" + TiktokJob JobType = "tiktok" + TwitterJob JobType = "twitter" + LinkedInJob JobType = "linkedin" + RedditJob JobType = "reddit" +) + +// Capability constants - typed to prevent typos and enable discoverability +const ( + + // Twitter (credential-based) capabilities + CapSearchByQuery Capability = "searchbyquery" + CapSearchByProfile Capability = "searchbyprofile" + CapGetById Capability = "getbyid" + CapGetReplies Capability = "getreplies" + CapGetRetweeters Capability = "getretweeters" + CapGetMedia Capability = "getmedia" + CapGetProfileById Capability = "getprofilebyid" + CapGetTrends Capability = "gettrends" + CapGetSpace Capability = "getspace" + CapGetProfile Capability = "getprofile" + CapGetTweets Capability = "gettweets" + + // Twitter (apify-based) capabilities + CapGetFollowing Capability = "getfollowing" + CapGetFollowers Capability = "getfollowers" + + // Twitter (api-based) capabilities + CapSearchByFullArchive Capability = "searchbyfullarchive" + + CapScraper Capability = "scraper" + CapSearchByTrending Capability = "searchbytrending" + CapTelemetry Capability = "telemetry" + CapTranscription Capability = "transcription" + + // Reddit capabilities + CapScrapeUrls Capability = "scrapeurls" + CapSearchPosts Capability = "searchposts" + CapSearchUsers Capability = "searchusers" + CapSearchCommunities Capability = "searchcommunities" + + CapEmpty Capability = "" +) + +// Capability group constants for easy reuse +var ( + AlwaysAvailableTelemetryCaps = []Capability{CapTelemetry, CapEmpty} + AlwaysAvailableTiktokCaps = []Capability{CapTranscription, CapEmpty} + + // AlwaysAvailableCapabilities defines the job capabilities that are always available regardless of configuration + AlwaysAvailableCapabilities = WorkerCapabilities{ + TelemetryJob: AlwaysAvailableTelemetryCaps, + TiktokJob: AlwaysAvailableTiktokCaps, + } + + // Twitter capabilities + TwitterCaps = []Capability{ + CapSearchByQuery, CapSearchByProfile, CapSearchByFullArchive, + CapGetById, CapGetReplies, CapGetRetweeters, CapGetTweets, CapGetMedia, CapGetProfileById, + CapGetTrends, CapGetFollowing, CapGetFollowers, CapGetSpace, CapEmpty, + } + + // TiktokSearchCaps are Tiktok capabilities available with Apify + TiktokSearchCaps = []Capability{CapSearchByQuery, CapSearchByTrending} + + // RedditCaps are all the Reddit capabilities (only available with Apify) + RedditCaps = []Capability{CapScrapeUrls, CapSearchPosts, CapSearchUsers, CapSearchCommunities} + + // WebCaps are all the Web capabilities (only available with Apify) + WebCaps = []Capability{CapScraper, CapEmpty} + + // LinkedInCaps are all the LinkedIn capabilities (only available with Apify) + LinkedInCaps = []Capability{CapSearchByProfile} +) + +// JobCapabilityMap defines which capabilities are valid for each job type +var JobCapabilityMap = map[JobType][]Capability{ + // Twitter job capabilities + TwitterJob: TwitterCaps, + + // Web job capabilities + WebJob: WebCaps, + + // LinkedIn job capabilities + LinkedInJob: LinkedInCaps, + + // TikTok job capabilities + TiktokJob: combineCapabilities( + AlwaysAvailableTiktokCaps, + TiktokSearchCaps, + ), + + // Reddit job capabilities + RedditJob: RedditCaps, + + // Telemetry job capabilities + TelemetryJob: AlwaysAvailableTelemetryCaps, +} + +// if no capability is specified, use the default capability for the job type +var JobDefaultCapabilityMap = 
map[JobType]Capability{ + TwitterJob: CapSearchByQuery, + WebJob: CapScraper, + TiktokJob: CapTranscription, + RedditJob: CapScrapeUrls, + TelemetryJob: CapTelemetry, + LinkedInJob: CapSearchByProfile, +} + +// JobResponse represents a response to a job submission +type JobResponse struct { + UID string `json:"uid"` +} + +// JobResult represents the result of a job execution +type JobResult struct { + Error string `json:"error"` + Data []byte `json:"data"` + Job Job `json:"job"` + NextCursor string `json:"next_cursor"` +} + +// Success returns true if the job was successful. +func (jr JobResult) Success() bool { + return jr.Error == "" +} + +// Unmarshal unmarshals the job result data. +func (jr JobResult) Unmarshal(i interface{}) error { + return json.Unmarshal(jr.Data, i) +} + +// JobRequest represents a request to execute a job +type JobRequest struct { + EncryptedJob string `json:"encrypted_job"` +} + +// JobError represents an error in job execution +type JobError struct { + Error string `json:"error"` +} + +// Key represents a key request +type Key struct { + Key string `json:"key"` + Signature string `json:"signature"` +} + +// KeyResponse represents a response to a key operation +type KeyResponse struct { + Status string `json:"status"` +} diff --git a/api/types/key.go b/api/types/key.go deleted file mode 100644 index 8691eae9..00000000 --- a/api/types/key.go +++ /dev/null @@ -1,11 +0,0 @@ -package types - -type Key struct { - Key string `json:"key"` - - Signature string `json:"signature"` -} - -type KeyResponse struct { - Status string `json:"status"` -} diff --git a/api/types/linkedin/experiences/experiences.go b/api/types/linkedin/experiences/experiences.go new file mode 100644 index 00000000..b339230e --- /dev/null +++ b/api/types/linkedin/experiences/experiences.go @@ -0,0 +1,41 @@ +package experiences + +import "github.com/masa-finance/tee-worker/pkg/util" + +// id represents a LinkedIn experience level identifier +type Id string + +// Experience level constants +const ( + LessThanAYear Id = "1" + OneToTwoYears Id = "2" + ThreeToFiveYears Id = "3" + SixToTenYears Id = "4" + MoreThanTenYears Id = "5" +) + +var All = util.NewSet( + LessThanAYear, + OneToTwoYears, + ThreeToFiveYears, + SixToTenYears, + MoreThanTenYears, +) + +type ExperiencesConfig struct { + All util.Set[Id] + LessThanAYear Id + OneToTwoYears Id + ThreeToFiveYears Id + SixToTenYears Id + MoreThanTenYears Id +} + +var Experiences = ExperiencesConfig{ + All: *All, + LessThanAYear: LessThanAYear, + OneToTwoYears: OneToTwoYears, + ThreeToFiveYears: ThreeToFiveYears, + SixToTenYears: SixToTenYears, + MoreThanTenYears: MoreThanTenYears, +} diff --git a/api/types/linkedin/functions/functions.go b/api/types/linkedin/functions/functions.go new file mode 100644 index 00000000..7dd715c3 --- /dev/null +++ b/api/types/linkedin/functions/functions.go @@ -0,0 +1,121 @@ +package functions + +import "github.com/masa-finance/tee-worker/pkg/util" + +// id represents a LinkedIn function identifier +type Id string + +// Function constants +const ( + Accounting Id = "1" + Administrative Id = "2" + ArtsAndDesign Id = "3" + BusinessDevelopment Id = "4" + CommunityAndSocialServices Id = "5" + Consulting Id = "6" + Education Id = "7" + Engineering Id = "8" + Entrepreneurship Id = "9" + Finance Id = "10" + HealthcareServices Id = "11" + HumanResources Id = "12" + InformationTechnology Id = "13" + Legal Id = "14" + Marketing Id = "15" + MediaAndCommunication Id = "16" + MilitaryAndProtectiveServices Id = "17" + Operations Id = "18" + 
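
Returning to jobs.go above: a short sketch of how a handler might normalize a job's capability and decode its loosely typed arguments. The inline argument struct is hypothetical, standing in for the typed packages under api/args:

    package main

    import (
    	"fmt"

    	"github.com/masa-finance/tee-worker/api/types"
    )

    func main() {
    	// An empty capability is rewritten to the job type's default...
    	capability := types.CapEmpty
    	if err := types.TwitterJob.ValidateCapability(&capability); err != nil {
    		panic(err)
    	}
    	fmt.Println(capability) // searchbyquery (CapSearchByQuery, TwitterJob's default)

    	// ...while a capability outside JobCapabilityMap[TwitterJob] is rejected.
    	bad := types.CapTranscription
    	fmt.Println(types.TwitterJob.ValidateCapability(&bad) != nil) // true

    	// JobArguments round-trips a loosely typed map into a concrete struct.
    	args := types.JobArguments{"query": "tee workers", "max_results": 10}
    	var parsed struct {
    		Query      string `json:"query"`
    		MaxResults int    `json:"max_results"`
    	}
    	if err := args.Unmarshal(&parsed); err != nil {
    		panic(err)
    	}
    	fmt.Println(parsed.Query, parsed.MaxResults)
    }
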
ProductManagement Id = "19" + ProgramAndProjectManagement Id = "20" + Purchasing Id = "21" + QualityAssurance Id = "22" + RealEstate Id = "23" + Research Id = "24" + Sales Id = "25" +) + +var All = util.NewSet( + Accounting, + Administrative, + ArtsAndDesign, + BusinessDevelopment, + CommunityAndSocialServices, + Consulting, + Education, + Engineering, + Entrepreneurship, + Finance, + HealthcareServices, + HumanResources, + InformationTechnology, + Legal, + Marketing, + MediaAndCommunication, + MilitaryAndProtectiveServices, + Operations, + ProductManagement, + ProgramAndProjectManagement, + Purchasing, + QualityAssurance, + RealEstate, + Research, + Sales, +) + +type FunctionsConfig struct { + All util.Set[Id] + Accounting Id + Administrative Id + ArtsAndDesign Id + BusinessDevelopment Id + CommunityAndSocialServices Id + Consulting Id + Education Id + Engineering Id + Entrepreneurship Id + Finance Id + HealthcareServices Id + HumanResources Id + InformationTechnology Id + Legal Id + Marketing Id + MediaAndCommunication Id + MilitaryAndProtectiveServices Id + Operations Id + ProductManagement Id + ProgramAndProjectManagement Id + Purchasing Id + QualityAssurance Id + RealEstate Id + Research Id + Sales Id +} + +var Functions = FunctionsConfig{ + All: *All, + Accounting: Accounting, + Administrative: Administrative, + ArtsAndDesign: ArtsAndDesign, + BusinessDevelopment: BusinessDevelopment, + CommunityAndSocialServices: CommunityAndSocialServices, + Consulting: Consulting, + Education: Education, + Engineering: Engineering, + Entrepreneurship: Entrepreneurship, + Finance: Finance, + HealthcareServices: HealthcareServices, + HumanResources: HumanResources, + InformationTechnology: InformationTechnology, + Legal: Legal, + Marketing: Marketing, + MediaAndCommunication: MediaAndCommunication, + MilitaryAndProtectiveServices: MilitaryAndProtectiveServices, + Operations: Operations, + ProductManagement: ProductManagement, + ProgramAndProjectManagement: ProgramAndProjectManagement, + Purchasing: Purchasing, + QualityAssurance: QualityAssurance, + RealEstate: RealEstate, + Research: Research, + Sales: Sales, +} diff --git a/api/types/linkedin/industries/industries.go b/api/types/linkedin/industries/industries.go new file mode 100644 index 00000000..666df4d5 --- /dev/null +++ b/api/types/linkedin/industries/industries.go @@ -0,0 +1,717 @@ +package industries + +import "github.com/masa-finance/tee-worker/pkg/util" + +// Id represents a LinkedIn industry identifier +type Id string + +// Industry constants +const ( + // Technology & Software + SoftwareDevelopment Id = "4" + ComputerHardwareManufacturing Id = "3" + ComputerNetworkingProducts Id = "5" + ItServicesAndItConsulting Id = "96" + ComputerAndNetworkSecurity Id = "118" + Telecommunications Id = "8" + WirelessServices Id = "119" + TechnologyInformationAndInternet Id = "6" + DataInfrastructureAndAnalytics Id = "2458" + InformationServices Id = "84" + InternetPublishing Id = "3132" + SocialNetworkingPlatforms Id = "3127" + ComputerGames Id = "109" + MobileGamingApps Id = "3131" + BlockchainServices Id = "3134" + BusinessIntelligencePlatforms Id = "3128" + + // Financial Services + FinancialServices Id = "43" + Banking Id = "41" + Insurance Id = "42" + InvestmentBanking Id = "45" + CapitalMarkets Id = "129" + VentureCapitalAndPrivateEquityPrincipals Id = "106" + SecuritiesAndCommodityExchanges Id = "1713" + FundsAndTrusts Id = "1742" + + // Healthcare & Medical + Hospitals Id = "2081" + MedicalPractices Id = "13" + MedicalEquipmentManufacturing Id 
= "17" + PublicHealth Id = "2358" + VeterinaryServices Id = "16" + BiotechnologyResearch Id = "12" + + // Manufacturing + Manufacturing Id = "25" + ComputersAndElectronicsManufacturing Id = "24" + SemiconductorManufacturing Id = "7" + MachineryManufacturing Id = "55" + IndustrialMachineryManufacturing Id = "135" + FoodAndBeverageManufacturing Id = "23" + TextileManufacturing Id = "60" + MotorVehicleManufacturing Id = "53" + MotorVehiclePartsManufacturing Id = "1042" + AviationAndAerospaceComponentManufacturing Id = "52" + DefenseAndSpaceManufacturing Id = "1" + PlasticsManufacturing Id = "117" + RubberProductsManufacturing Id = "763" + PaperAndForestProductManufacturing Id = "61" + WoodProductManufacturing Id = "784" + FurnitureAndHomeFurnishingsManufacturing Id = "26" + SportingGoodsManufacturing Id = "20" + PrintingServices Id = "83" + + // Retail & Consumer Goods + Retail Id = "27" + RetailGroceries Id = "22" + OnlineAndMailOrderRetail Id = "1445" + RetailApparelAndFashion Id = "19" + RetailAppliancesElectricalAndElectronicEquipment Id = "1319" + RetailBooksAndPrintedNews Id = "1409" + RetailBuildingMaterialsAndGardenEquipment Id = "1324" + RetailFurnitureAndHomeFurnishings Id = "1309" + RetailHealthAndPersonalCareProducts Id = "1359" + RetailLuxuryGoodsAndJewelry Id = "143" + RetailMotorVehicles Id = "1292" + RetailOfficeEquipment Id = "138" + RetailOfficeSuppliesAndGifts Id = "1424" + + // Professional Services + ProfessionalServices Id = "1810" + Accounting Id = "47" + LegalServices Id = "10" + LawPractice Id = "9" + BusinessConsultingAndServices Id = "11" + StrategicManagementServices Id = "102" + HumanResourcesServices Id = "137" + MarketingServices Id = "1862" + AdvertisingServices Id = "80" + PublicRelationsAndCommunicationsServices Id = "98" + MarketResearch Id = "97" + ArchitectureAndPlanning Id = "50" + DesignServices Id = "99" + GraphicDesign Id = "140" + InteriorDesign Id = "3126" + EngineeringServices Id = "3242" + EnvironmentalServices Id = "86" + ResearchServices Id = "70" + ThinkTanks Id = "130" + Photography Id = "136" + TranslationAndLocalization Id = "108" + WritingAndEditing Id = "103" + + // Education + Education Id = "1999" + HigherEducation Id = "68" + ProfessionalTrainingAndCoaching Id = "105" + SportsAndRecreationInstruction Id = "2027" + + // Transportation & Logistics + TransportationLogisticsSupplyChainAndStorage Id = "116" + AirlinesAndAviation Id = "94" + FreightAndPackageTransportation Id = "87" + MaritimeTransportation Id = "95" + RailTransportation Id = "1481" + TruckTransportation Id = "92" + WarehousingAndStorage Id = "93" + PostalServices Id = "1573" + + // Energy & Utilities + Utilities Id = "59" + ElectricPowerGeneration Id = "383" + RenewableEnergyPowerGeneration Id = "3240" + OilAndGas Id = "57" + Mining Id = "56" + OilGasAndMining Id = "332" + + // Media & Entertainment + TechnologyInformationAndMedia Id = "1594" + BroadcastMediaProductionAndDistribution Id = "36" + RadioAndTelevisionBroadcasting Id = "1633" + MoviesVideosAndSound Id = "35" + MediaProduction Id = "126" + SoundRecording Id = "1623" + BookAndPeriodicalPublishing Id = "82" + NewspaperPublishing Id = "81" + PeriodicalPublishing Id = "1600" + EntertainmentProviders Id = "28" + ArtistsAndWriters Id = "38" + Musicians Id = "115" + + // Construction & Real Estate + Construction Id = "48" + CivilEngineering Id = "51" + RealEstate Id = "44" + RealEstateAgentsAndBrokers Id = "1770" + + // Hospitality & Services + Hospitality Id = "31" + HotelsAndMotels Id = "2194" + Restaurants Id = "32" + 
FoodAndBeverageServices Id = "34" + TravelArrangements Id = "30" + EventsServices Id = "110" + WellnessAndFitnessServices Id = "124" + ConsumerServices Id = "91" + + // Government & Non-Profit + ArmedForces Id = "71" + GovernmentRelationsServices Id = "148" + NonProfitOrganizations Id = "100" + CivicAndSocialOrganizations Id = "90" + PoliticalOrganizations Id = "107" + ProfessionalOrganizations Id = "1911" + Fundraising Id = "101" + + // Wholesale & Distribution + Wholesale Id = "133" + WholesaleImportAndExport Id = "134" + WholesaleComputerEquipment Id = "1157" + WholesaleFoodAndBeverage Id = "1231" + WholesaleBuildingMaterials Id = "49" + WholesaleMachinery Id = "1187" + WholesaleMotorVehiclesAndParts Id = "1128" + + // Other Services + StaffingAndRecruiting Id = "104" + ExecutiveSearchServices Id = "1923" + OfficeAdministration Id = "1916" + SecurityAndInvestigations Id = "121" + EquipmentRentalServices Id = "1779" + Libraries Id = "85" +) + +var All = util.NewSet( + // Technology & Software + SoftwareDevelopment, + ComputerHardwareManufacturing, + ComputerNetworkingProducts, + ItServicesAndItConsulting, + ComputerAndNetworkSecurity, + Telecommunications, + WirelessServices, + TechnologyInformationAndInternet, + DataInfrastructureAndAnalytics, + InformationServices, + InternetPublishing, + SocialNetworkingPlatforms, + ComputerGames, + MobileGamingApps, + BlockchainServices, + BusinessIntelligencePlatforms, + + // Financial Services + FinancialServices, + Banking, + Insurance, + InvestmentBanking, + CapitalMarkets, + VentureCapitalAndPrivateEquityPrincipals, + SecuritiesAndCommodityExchanges, + FundsAndTrusts, + + // Healthcare & Medical + Hospitals, + MedicalPractices, + MedicalEquipmentManufacturing, + PublicHealth, + VeterinaryServices, + BiotechnologyResearch, + + // Manufacturing + Manufacturing, + ComputersAndElectronicsManufacturing, + SemiconductorManufacturing, + MachineryManufacturing, + IndustrialMachineryManufacturing, + FoodAndBeverageManufacturing, + TextileManufacturing, + MotorVehicleManufacturing, + MotorVehiclePartsManufacturing, + AviationAndAerospaceComponentManufacturing, + DefenseAndSpaceManufacturing, + PlasticsManufacturing, + RubberProductsManufacturing, + PaperAndForestProductManufacturing, + WoodProductManufacturing, + FurnitureAndHomeFurnishingsManufacturing, + SportingGoodsManufacturing, + PrintingServices, + + // Retail & Consumer Goods + Retail, + RetailGroceries, + OnlineAndMailOrderRetail, + RetailApparelAndFashion, + RetailAppliancesElectricalAndElectronicEquipment, + RetailBooksAndPrintedNews, + RetailBuildingMaterialsAndGardenEquipment, + RetailFurnitureAndHomeFurnishings, + RetailHealthAndPersonalCareProducts, + RetailLuxuryGoodsAndJewelry, + RetailMotorVehicles, + RetailOfficeEquipment, + RetailOfficeSuppliesAndGifts, + + // Professional Services + ProfessionalServices, + Accounting, + LegalServices, + LawPractice, + BusinessConsultingAndServices, + StrategicManagementServices, + HumanResourcesServices, + MarketingServices, + AdvertisingServices, + PublicRelationsAndCommunicationsServices, + MarketResearch, + ArchitectureAndPlanning, + DesignServices, + GraphicDesign, + InteriorDesign, + EngineeringServices, + EnvironmentalServices, + ResearchServices, + ThinkTanks, + Photography, + TranslationAndLocalization, + WritingAndEditing, + + // Education + Education, + HigherEducation, + ProfessionalTrainingAndCoaching, + SportsAndRecreationInstruction, + + // Transportation & Logistics + TransportationLogisticsSupplyChainAndStorage, + AirlinesAndAviation, 
+ FreightAndPackageTransportation, + MaritimeTransportation, + RailTransportation, + TruckTransportation, + WarehousingAndStorage, + PostalServices, + + // Energy & Utilities + Utilities, + ElectricPowerGeneration, + RenewableEnergyPowerGeneration, + OilAndGas, + Mining, + OilGasAndMining, + + // Media & Entertainment + TechnologyInformationAndMedia, + BroadcastMediaProductionAndDistribution, + RadioAndTelevisionBroadcasting, + MoviesVideosAndSound, + MediaProduction, + SoundRecording, + BookAndPeriodicalPublishing, + NewspaperPublishing, + PeriodicalPublishing, + EntertainmentProviders, + ArtistsAndWriters, + Musicians, + + // Construction & Real Estate + Construction, + CivilEngineering, + RealEstate, + RealEstateAgentsAndBrokers, + + // Hospitality & Services + Hospitality, + HotelsAndMotels, + Restaurants, + FoodAndBeverageServices, + TravelArrangements, + EventsServices, + WellnessAndFitnessServices, + ConsumerServices, + + // Government & Non-Profit + ArmedForces, + GovernmentRelationsServices, + NonProfitOrganizations, + CivicAndSocialOrganizations, + PoliticalOrganizations, + ProfessionalOrganizations, + Fundraising, + + // Wholesale & Distribution + Wholesale, + WholesaleImportAndExport, + WholesaleComputerEquipment, + WholesaleFoodAndBeverage, + WholesaleBuildingMaterials, + WholesaleMachinery, + WholesaleMotorVehiclesAndParts, + + // Other Services + StaffingAndRecruiting, + ExecutiveSearchServices, + OfficeAdministration, + SecurityAndInvestigations, + EquipmentRentalServices, + Libraries, +) + +type IndustriesConfig struct { + All util.Set[Id] + // Technology & Software + SoftwareDevelopment Id + ComputerHardwareManufacturing Id + ComputerNetworkingProducts Id + ItServicesAndItConsulting Id + ComputerAndNetworkSecurity Id + Telecommunications Id + WirelessServices Id + TechnologyInformationAndInternet Id + DataInfrastructureAndAnalytics Id + InformationServices Id + InternetPublishing Id + SocialNetworkingPlatforms Id + ComputerGames Id + MobileGamingApps Id + BlockchainServices Id + BusinessIntelligencePlatforms Id + + // Financial Services + FinancialServices Id + Banking Id + Insurance Id + InvestmentBanking Id + CapitalMarkets Id + VentureCapitalAndPrivateEquityPrincipals Id + SecuritiesAndCommodityExchanges Id + FundsAndTrusts Id + + // Healthcare & Medical + Hospitals Id + MedicalPractices Id + MedicalEquipmentManufacturing Id + PublicHealth Id + VeterinaryServices Id + BiotechnologyResearch Id + + // Manufacturing + Manufacturing Id + ComputersAndElectronicsManufacturing Id + SemiconductorManufacturing Id + MachineryManufacturing Id + IndustrialMachineryManufacturing Id + FoodAndBeverageManufacturing Id + TextileManufacturing Id + MotorVehicleManufacturing Id + MotorVehiclePartsManufacturing Id + AviationAndAerospaceComponentManufacturing Id + DefenseAndSpaceManufacturing Id + PlasticsManufacturing Id + RubberProductsManufacturing Id + PaperAndForestProductManufacturing Id + WoodProductManufacturing Id + FurnitureAndHomeFurnishingsManufacturing Id + SportingGoodsManufacturing Id + PrintingServices Id + + // Retail & Consumer Goods + Retail Id + RetailGroceries Id + OnlineAndMailOrderRetail Id + RetailApparelAndFashion Id + RetailAppliancesElectricalAndElectronicEquipment Id + RetailBooksAndPrintedNews Id + RetailBuildingMaterialsAndGardenEquipment Id + RetailFurnitureAndHomeFurnishings Id + RetailHealthAndPersonalCareProducts Id + RetailLuxuryGoodsAndJewelry Id + RetailMotorVehicles Id + RetailOfficeEquipment Id + RetailOfficeSuppliesAndGifts Id + + // Professional 
Services + ProfessionalServices Id + Accounting Id + LegalServices Id + LawPractice Id + BusinessConsultingAndServices Id + StrategicManagementServices Id + HumanResourcesServices Id + MarketingServices Id + AdvertisingServices Id + PublicRelationsAndCommunicationsServices Id + MarketResearch Id + ArchitectureAndPlanning Id + DesignServices Id + GraphicDesign Id + InteriorDesign Id + EngineeringServices Id + EnvironmentalServices Id + ResearchServices Id + ThinkTanks Id + Photography Id + TranslationAndLocalization Id + WritingAndEditing Id + + // Education + Education Id + HigherEducation Id + ProfessionalTrainingAndCoaching Id + SportsAndRecreationInstruction Id + + // Transportation & Logistics + TransportationLogisticsSupplyChainAndStorage Id + AirlinesAndAviation Id + FreightAndPackageTransportation Id + MaritimeTransportation Id + RailTransportation Id + TruckTransportation Id + WarehousingAndStorage Id + PostalServices Id + + // Energy & Utilities + Utilities Id + ElectricPowerGeneration Id + RenewableEnergyPowerGeneration Id + OilAndGas Id + Mining Id + OilGasAndMining Id + + // Media & Entertainment + TechnologyInformationAndMedia Id + BroadcastMediaProductionAndDistribution Id + RadioAndTelevisionBroadcasting Id + MoviesVideosAndSound Id + MediaProduction Id + SoundRecording Id + BookAndPeriodicalPublishing Id + NewspaperPublishing Id + PeriodicalPublishing Id + EntertainmentProviders Id + ArtistsAndWriters Id + Musicians Id + + // Construction & Real Estate + Construction Id + CivilEngineering Id + RealEstate Id + RealEstateAgentsAndBrokers Id + + // Hospitality & Services + Hospitality Id + HotelsAndMotels Id + Restaurants Id + FoodAndBeverageServices Id + TravelArrangements Id + EventsServices Id + WellnessAndFitnessServices Id + ConsumerServices Id + + // Government & Non-Profit + ArmedForces Id + GovernmentRelationsServices Id + NonProfitOrganizations Id + CivicAndSocialOrganizations Id + PoliticalOrganizations Id + ProfessionalOrganizations Id + Fundraising Id + + // Wholesale & Distribution + Wholesale Id + WholesaleImportAndExport Id + WholesaleComputerEquipment Id + WholesaleFoodAndBeverage Id + WholesaleBuildingMaterials Id + WholesaleMachinery Id + WholesaleMotorVehiclesAndParts Id + + // Other Services + StaffingAndRecruiting Id + ExecutiveSearchServices Id + OfficeAdministration Id + SecurityAndInvestigations Id + EquipmentRentalServices Id + Libraries Id +} + +var Industries = IndustriesConfig{ + All: *All, + // Technology & Software + SoftwareDevelopment: SoftwareDevelopment, + ComputerHardwareManufacturing: ComputerHardwareManufacturing, + ComputerNetworkingProducts: ComputerNetworkingProducts, + ItServicesAndItConsulting: ItServicesAndItConsulting, + ComputerAndNetworkSecurity: ComputerAndNetworkSecurity, + Telecommunications: Telecommunications, + WirelessServices: WirelessServices, + TechnologyInformationAndInternet: TechnologyInformationAndInternet, + DataInfrastructureAndAnalytics: DataInfrastructureAndAnalytics, + InformationServices: InformationServices, + InternetPublishing: InternetPublishing, + SocialNetworkingPlatforms: SocialNetworkingPlatforms, + ComputerGames: ComputerGames, + MobileGamingApps: MobileGamingApps, + BlockchainServices: BlockchainServices, + BusinessIntelligencePlatforms: BusinessIntelligencePlatforms, + + // Financial Services + FinancialServices: FinancialServices, + Banking: Banking, + Insurance: Insurance, + InvestmentBanking: InvestmentBanking, + CapitalMarkets: CapitalMarkets, + VentureCapitalAndPrivateEquityPrincipals: 
VentureCapitalAndPrivateEquityPrincipals, + SecuritiesAndCommodityExchanges: SecuritiesAndCommodityExchanges, + FundsAndTrusts: FundsAndTrusts, + + // Healthcare & Medical + Hospitals: Hospitals, + MedicalPractices: MedicalPractices, + MedicalEquipmentManufacturing: MedicalEquipmentManufacturing, + PublicHealth: PublicHealth, + VeterinaryServices: VeterinaryServices, + BiotechnologyResearch: BiotechnologyResearch, + + // Manufacturing + Manufacturing: Manufacturing, + ComputersAndElectronicsManufacturing: ComputersAndElectronicsManufacturing, + SemiconductorManufacturing: SemiconductorManufacturing, + MachineryManufacturing: MachineryManufacturing, + IndustrialMachineryManufacturing: IndustrialMachineryManufacturing, + FoodAndBeverageManufacturing: FoodAndBeverageManufacturing, + TextileManufacturing: TextileManufacturing, + MotorVehicleManufacturing: MotorVehicleManufacturing, + MotorVehiclePartsManufacturing: MotorVehiclePartsManufacturing, + AviationAndAerospaceComponentManufacturing: AviationAndAerospaceComponentManufacturing, + DefenseAndSpaceManufacturing: DefenseAndSpaceManufacturing, + PlasticsManufacturing: PlasticsManufacturing, + RubberProductsManufacturing: RubberProductsManufacturing, + PaperAndForestProductManufacturing: PaperAndForestProductManufacturing, + WoodProductManufacturing: WoodProductManufacturing, + FurnitureAndHomeFurnishingsManufacturing: FurnitureAndHomeFurnishingsManufacturing, + SportingGoodsManufacturing: SportingGoodsManufacturing, + PrintingServices: PrintingServices, + + // Retail & Consumer Goods + Retail: Retail, + RetailGroceries: RetailGroceries, + OnlineAndMailOrderRetail: OnlineAndMailOrderRetail, + RetailApparelAndFashion: RetailApparelAndFashion, + RetailAppliancesElectricalAndElectronicEquipment: RetailAppliancesElectricalAndElectronicEquipment, + RetailBooksAndPrintedNews: RetailBooksAndPrintedNews, + RetailBuildingMaterialsAndGardenEquipment: RetailBuildingMaterialsAndGardenEquipment, + RetailFurnitureAndHomeFurnishings: RetailFurnitureAndHomeFurnishings, + RetailHealthAndPersonalCareProducts: RetailHealthAndPersonalCareProducts, + RetailLuxuryGoodsAndJewelry: RetailLuxuryGoodsAndJewelry, + RetailMotorVehicles: RetailMotorVehicles, + RetailOfficeEquipment: RetailOfficeEquipment, + RetailOfficeSuppliesAndGifts: RetailOfficeSuppliesAndGifts, + + // Professional Services + ProfessionalServices: ProfessionalServices, + Accounting: Accounting, + LegalServices: LegalServices, + LawPractice: LawPractice, + BusinessConsultingAndServices: BusinessConsultingAndServices, + StrategicManagementServices: StrategicManagementServices, + HumanResourcesServices: HumanResourcesServices, + MarketingServices: MarketingServices, + AdvertisingServices: AdvertisingServices, + PublicRelationsAndCommunicationsServices: PublicRelationsAndCommunicationsServices, + MarketResearch: MarketResearch, + ArchitectureAndPlanning: ArchitectureAndPlanning, + DesignServices: DesignServices, + GraphicDesign: GraphicDesign, + InteriorDesign: InteriorDesign, + EngineeringServices: EngineeringServices, + EnvironmentalServices: EnvironmentalServices, + ResearchServices: ResearchServices, + ThinkTanks: ThinkTanks, + Photography: Photography, + TranslationAndLocalization: TranslationAndLocalization, + WritingAndEditing: WritingAndEditing, + + // Education + Education: Education, + HigherEducation: HigherEducation, + ProfessionalTrainingAndCoaching: ProfessionalTrainingAndCoaching, + SportsAndRecreationInstruction: SportsAndRecreationInstruction, + + // Transportation & Logistics + 
TransportationLogisticsSupplyChainAndStorage: TransportationLogisticsSupplyChainAndStorage, + AirlinesAndAviation: AirlinesAndAviation, + FreightAndPackageTransportation: FreightAndPackageTransportation, + MaritimeTransportation: MaritimeTransportation, + RailTransportation: RailTransportation, + TruckTransportation: TruckTransportation, + WarehousingAndStorage: WarehousingAndStorage, + PostalServices: PostalServices, + + // Energy & Utilities + Utilities: Utilities, + ElectricPowerGeneration: ElectricPowerGeneration, + RenewableEnergyPowerGeneration: RenewableEnergyPowerGeneration, + OilAndGas: OilAndGas, + Mining: Mining, + OilGasAndMining: OilGasAndMining, + + // Media & Entertainment + TechnologyInformationAndMedia: TechnologyInformationAndMedia, + BroadcastMediaProductionAndDistribution: BroadcastMediaProductionAndDistribution, + RadioAndTelevisionBroadcasting: RadioAndTelevisionBroadcasting, + MoviesVideosAndSound: MoviesVideosAndSound, + MediaProduction: MediaProduction, + SoundRecording: SoundRecording, + BookAndPeriodicalPublishing: BookAndPeriodicalPublishing, + NewspaperPublishing: NewspaperPublishing, + PeriodicalPublishing: PeriodicalPublishing, + EntertainmentProviders: EntertainmentProviders, + ArtistsAndWriters: ArtistsAndWriters, + Musicians: Musicians, + + // Construction & Real Estate + Construction: Construction, + CivilEngineering: CivilEngineering, + RealEstate: RealEstate, + RealEstateAgentsAndBrokers: RealEstateAgentsAndBrokers, + + // Hospitality & Services + Hospitality: Hospitality, + HotelsAndMotels: HotelsAndMotels, + Restaurants: Restaurants, + FoodAndBeverageServices: FoodAndBeverageServices, + TravelArrangements: TravelArrangements, + EventsServices: EventsServices, + WellnessAndFitnessServices: WellnessAndFitnessServices, + ConsumerServices: ConsumerServices, + + // Government & Non-Profit + ArmedForces: ArmedForces, + GovernmentRelationsServices: GovernmentRelationsServices, + NonProfitOrganizations: NonProfitOrganizations, + CivicAndSocialOrganizations: CivicAndSocialOrganizations, + PoliticalOrganizations: PoliticalOrganizations, + ProfessionalOrganizations: ProfessionalOrganizations, + Fundraising: Fundraising, + + // Wholesale & Distribution + Wholesale: Wholesale, + WholesaleImportAndExport: WholesaleImportAndExport, + WholesaleComputerEquipment: WholesaleComputerEquipment, + WholesaleFoodAndBeverage: WholesaleFoodAndBeverage, + WholesaleBuildingMaterials: WholesaleBuildingMaterials, + WholesaleMachinery: WholesaleMachinery, + WholesaleMotorVehiclesAndParts: WholesaleMotorVehiclesAndParts, + + // Other Services + StaffingAndRecruiting: StaffingAndRecruiting, + ExecutiveSearchServices: ExecutiveSearchServices, + OfficeAdministration: OfficeAdministration, + SecurityAndInvestigations: SecurityAndInvestigations, + EquipmentRentalServices: EquipmentRentalServices, + Libraries: Libraries, +} diff --git a/api/types/linkedin/linkedin.go b/api/types/linkedin/linkedin.go new file mode 100644 index 00000000..c0d3b226 --- /dev/null +++ b/api/types/linkedin/linkedin.go @@ -0,0 +1,27 @@ +package linkedin + +import ( + "github.com/masa-finance/tee-worker/api/types/linkedin/experiences" + "github.com/masa-finance/tee-worker/api/types/linkedin/functions" + "github.com/masa-finance/tee-worker/api/types/linkedin/industries" + "github.com/masa-finance/tee-worker/api/types/linkedin/profile" + "github.com/masa-finance/tee-worker/api/types/linkedin/seniorities" +) + +type LinkedInConfig struct { + Experiences *experiences.ExperiencesConfig + Seniorities 
*seniorities.SenioritiesConfig + Functions *functions.FunctionsConfig + Industries *industries.IndustriesConfig + Profile *profile.Profile +} + +var LinkedIn = LinkedInConfig{ + Experiences: &experiences.Experiences, + Seniorities: &seniorities.Seniorities, + Functions: &functions.Functions, + Industries: &industries.Industries, + Profile: &profile.Profile{}, +} + +type Profile = profile.Profile diff --git a/api/types/linkedin/linkedin_suite_test.go b/api/types/linkedin/linkedin_suite_test.go new file mode 100644 index 00000000..8f46e3dc --- /dev/null +++ b/api/types/linkedin/linkedin_suite_test.go @@ -0,0 +1,13 @@ +package linkedin_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestTypes(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Types Suite") +} diff --git a/api/types/linkedin/linkedin_test.go b/api/types/linkedin/linkedin_test.go new file mode 100644 index 00000000..88c26f0f --- /dev/null +++ b/api/types/linkedin/linkedin_test.go @@ -0,0 +1,135 @@ +package linkedin_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/types" + "github.com/masa-finance/tee-worker/api/types/linkedin/experiences" + "github.com/masa-finance/tee-worker/api/types/linkedin/functions" + "github.com/masa-finance/tee-worker/api/types/linkedin/industries" + "github.com/masa-finance/tee-worker/api/types/linkedin/seniorities" +) + +var _ = Describe("LinkedIn Types", func() { + Describe("LinkedIn Package", func() { + It("should have all required fields", func() { + linkedin := types.LinkedIn + + Expect(linkedin.Seniorities).ToNot(BeNil()) + Expect(linkedin.Experiences).ToNot(BeNil()) + Expect(linkedin.Functions).ToNot(BeNil()) + Expect(linkedin.Industries).ToNot(BeNil()) + }) + }) + + Describe("Seniorities", func() { + It("should have all seniority levels", func() { + s := types.LinkedIn.Seniorities + + Expect(s.InTraining).To(Equal(seniorities.InTraining)) + Expect(s.EntryLevel).To(Equal(seniorities.EntryLevel)) + Expect(s.Senior).To(Equal(seniorities.Senior)) + Expect(s.Strategic).To(Equal(seniorities.Strategic)) + Expect(s.EntryLevelManager).To(Equal(seniorities.EntryLevelManager)) + Expect(s.ExperiencedManager).To(Equal(seniorities.ExperiencedManager)) + Expect(s.Director).To(Equal(seniorities.Director)) + Expect(s.VicePresident).To(Equal(seniorities.VicePresident)) + Expect(s.CXO).To(Equal(seniorities.CXO)) + Expect(s.Partner).To(Equal(seniorities.Partner)) + }) + + It("should have All set containing all seniorities", func() { + all := types.LinkedIn.Seniorities.All + + Expect(all.Contains(seniorities.InTraining)).To(BeTrue()) + Expect(all.Contains(seniorities.EntryLevel)).To(BeTrue()) + Expect(all.Contains(seniorities.Senior)).To(BeTrue()) + Expect(all.Contains(seniorities.Strategic)).To(BeTrue()) + Expect(all.Contains(seniorities.EntryLevelManager)).To(BeTrue()) + Expect(all.Contains(seniorities.ExperiencedManager)).To(BeTrue()) + Expect(all.Contains(seniorities.Director)).To(BeTrue()) + Expect(all.Contains(seniorities.VicePresident)).To(BeTrue()) + Expect(all.Contains(seniorities.CXO)).To(BeTrue()) + Expect(all.Contains(seniorities.Partner)).To(BeTrue()) + + Expect(all.Length()).To(Equal(10)) + }) + }) + + Describe("Experiences", func() { + It("should have all experience levels", func() { + e := types.LinkedIn.Experiences + + Expect(e.LessThanAYear).To(Equal(experiences.LessThanAYear)) + Expect(e.OneToTwoYears).To(Equal(experiences.OneToTwoYears)) + 
Expect(e.ThreeToFiveYears).To(Equal(experiences.ThreeToFiveYears)) + Expect(e.SixToTenYears).To(Equal(experiences.SixToTenYears)) + Expect(e.MoreThanTenYears).To(Equal(experiences.MoreThanTenYears)) + }) + + It("should have All set containing all experiences", func() { + all := types.LinkedIn.Experiences.All + + Expect(all.Contains(experiences.LessThanAYear)).To(BeTrue()) + Expect(all.Contains(experiences.OneToTwoYears)).To(BeTrue()) + Expect(all.Contains(experiences.ThreeToFiveYears)).To(BeTrue()) + Expect(all.Contains(experiences.SixToTenYears)).To(BeTrue()) + Expect(all.Contains(experiences.MoreThanTenYears)).To(BeTrue()) + + Expect(all.Length()).To(Equal(5)) + }) + }) + + Describe("Functions", func() { + It("should have all function types", func() { + f := types.LinkedIn.Functions + + Expect(f.Accounting).To(Equal(functions.Accounting)) + Expect(f.Engineering).To(Equal(functions.Engineering)) + Expect(f.Marketing).To(Equal(functions.Marketing)) + Expect(f.Sales).To(Equal(functions.Sales)) + Expect(f.HumanResources).To(Equal(functions.HumanResources)) + }) + + It("should have All set containing all functions", func() { + all := types.LinkedIn.Functions.All + + Expect(all.Contains(functions.Accounting)).To(BeTrue()) + Expect(all.Contains(functions.Engineering)).To(BeTrue()) + Expect(all.Contains(functions.Marketing)).To(BeTrue()) + Expect(all.Contains(functions.Sales)).To(BeTrue()) + Expect(all.Contains(functions.HumanResources)).To(BeTrue()) + Expect(all.Contains(functions.InformationTechnology)).To(BeTrue()) + Expect(all.Contains(functions.Finance)).To(BeTrue()) + + Expect(all.Length()).To(Equal(25)) + }) + }) + + Describe("Industries", func() { + It("should have all industry types", func() { + i := types.LinkedIn.Industries + + Expect(i.SoftwareDevelopment).To(Equal(industries.SoftwareDevelopment)) + Expect(i.FinancialServices).To(Equal(industries.FinancialServices)) + Expect(i.Manufacturing).To(Equal(industries.Manufacturing)) + Expect(i.Retail).To(Equal(industries.Retail)) + Expect(i.Education).To(Equal(industries.Education)) + }) + + It("should have All set containing all industries", func() { + all := types.LinkedIn.Industries.All + + Expect(all.Contains(industries.SoftwareDevelopment)).To(BeTrue()) + Expect(all.Contains(industries.FinancialServices)).To(BeTrue()) + Expect(all.Contains(industries.Manufacturing)).To(BeTrue()) + Expect(all.Contains(industries.Retail)).To(BeTrue()) + Expect(all.Contains(industries.Education)).To(BeTrue()) + Expect(all.Contains(industries.Hospitals)).To(BeTrue()) + Expect(all.Contains(industries.ProfessionalServices)).To(BeTrue()) + + Expect(all.Length()).To(BeNumerically(">=", 100)) // Should have many industries + }) + }) +}) diff --git a/api/types/linkedin/profile/profile.go b/api/types/linkedin/profile/profile.go new file mode 100644 index 00000000..024151d2 --- /dev/null +++ b/api/types/linkedin/profile/profile.go @@ -0,0 +1,261 @@ +package profile + +import ( + "time" + + "github.com/masa-finance/tee-worker/pkg/util" +) + +type ScraperMode string + +const ( + ScraperModeShort ScraperMode = "Short" + ScraperModeFull ScraperMode = "Full" + ScraperModeFullEmail ScraperMode = "Full + email search" +) + +var AllScraperModes = util.NewSet(ScraperModeShort, ScraperModeFull, ScraperModeFullEmail) + +// Profile represents a complete profile response +type Profile struct { + ID string `json:"id"` + PublicIdentifier string `json:"publicIdentifier,omitempty"` + URL string `json:"linkedinUrl"` + FirstName string `json:"firstName"` + LastName string 
`json:"lastName"` + Headline *string `json:"headline,omitempty"` + About *string `json:"about,omitempty"` + Summary *string `json:"summary,omitempty"` + OpenToWork bool `json:"openToWork,omitempty"` + OpenProfile bool `json:"openProfile,omitempty"` + Hiring bool `json:"hiring,omitempty"` + Photo *string `json:"photo,omitempty"` + PictureUrl *string `json:"pictureUrl,omitempty"` + Premium bool `json:"premium,omitempty"` + Influencer bool `json:"influencer,omitempty"` + Location Location `json:"location,omitempty"` + Verified bool `json:"verified,omitempty"` + RegisteredAt time.Time `json:"registeredAt,omitempty"` + TopSkills *string `json:"topSkills,omitempty"` + ConnectionsCount int `json:"connectionsCount,omitempty"` + FollowerCount int `json:"followerCount,omitempty"` + ComposeOptionType *string `json:"composeOptionType,omitempty"` + + // Full mode + CurrentPosition []CurrentPosition `json:"currentPosition,omitempty"` + + // Short mode + CurrentPositions []ShortCurrentPosition `json:"currentPositions,omitempty"` + + Experience []Experience `json:"experience,omitempty"` + Education []Education `json:"education,omitempty"` + Certifications []Certification `json:"certifications,omitempty"` + Projects []Project `json:"projects,omitempty"` + Volunteering []Volunteering `json:"volunteering,omitempty"` + ReceivedRecommendations []Recommendation `json:"receivedRecommendations,omitempty"` + Skills []Skill `json:"skills,omitempty"` + Courses []Course `json:"courses,omitempty"` + Publications []Publication `json:"publications,omitempty"` + Patents []Patent `json:"patents,omitempty"` + HonorsAndAwards []HonorAndAward `json:"honorsAndAwards,omitempty"` + Languages []Language `json:"languages,omitempty"` + Featured any `json:"featured,omitempty"` + MoreProfiles []MoreProfile `json:"moreProfiles,omitempty"` + + // Email mode + Emails []string `json:"emails,omitempty"` + CompanyWebsites []CompanyWebsite `json:"companyWebsites,omitempty"` +} + +// Location represents the location information +type Location struct { + Text string `json:"linkedinText"` + CountryCode string `json:"countryCode,omitempty"` + Parsed ParsedLocation `json:"parsed,omitempty"` +} + +// ParsedLocation represents the parsed location details +type ParsedLocation struct { + Text string `json:"text,omitempty"` + CountryCode string `json:"countryCode,omitempty"` + RegionCode *string `json:"regionCode,omitempty"` + Country string `json:"country,omitempty"` + CountryFull string `json:"countryFull,omitempty"` + State string `json:"state,omitempty"` + City string `json:"city,omitempty"` +} + +// CurrentPosition represents current position information +type CurrentPosition struct { + CompanyID *string `json:"companyId,omitempty"` + CompanyLinkedinUrl *string `json:"companyLinkedinUrl,omitempty"` + CompanyName string `json:"companyName"` + DateRange *DatePeriod `json:"dateRange,omitempty"` +} + +// Experience represents work experience +type Experience struct { + Position string `json:"position"` + Location *string `json:"location,omitempty"` + EmploymentType *string `json:"employmentType,omitempty"` + WorkplaceType *string `json:"workplaceType,omitempty"` + CompanyName string `json:"companyName"` + CompanyURL *string `json:"companyUrl,omitempty"` + CompanyID *string `json:"companyId,omitempty"` + CompanyUniversalName *string `json:"companyUniversalName,omitempty"` + Duration string `json:"duration"` + Description *string `json:"description,omitempty"` + Skills []string `json:"skills,omitempty"` + StartDate DateRange `json:"startDate"` + 
EndDate DateRange `json:"endDate"` +} + +// DateRange represents a date range with month, year, and text +type DateRange struct { + Month *string `json:"month,omitempty"` + Year *int `json:"year,omitempty"` + Text string `json:"text"` +} + +// Education represents educational background +type Education struct { + SchoolName string `json:"schoolName,omitempty"` + SchoolURL string `json:"schoolUrl,omitempty"` + Degree string `json:"degree,omitempty"` + FieldOfStudy *string `json:"fieldOfStudy,omitempty"` + Skills []string `json:"skills,omitempty"` + StartDate DateRange `json:"startDate,omitempty"` + EndDate DateRange `json:"endDate,omitempty"` + Period string `json:"period,omitempty"` +} + +// Certification represents a certification +type Certification struct { + Title string `json:"title,omitempty"` + IssuedAt string `json:"issuedAt,omitempty"` + IssuedBy string `json:"issuedBy,omitempty"` + IssuedByLink string `json:"issuedByLink,omitempty"` +} + +// Project represents a project +type Project struct { + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Duration string `json:"duration,omitempty"` + StartDate DateRange `json:"startDate,omitempty"` + EndDate DateRange `json:"endDate,omitempty"` +} + +// Volunteering represents volunteer experience +type Volunteering struct { + Role string `json:"role,omitempty"` + Duration string `json:"duration,omitempty"` + StartDate *DateRange `json:"startDate,omitempty"` + EndDate *DateRange `json:"endDate,omitempty"` + OrganizationName string `json:"organizationName,omitempty"` + OrganizationURL *string `json:"organizationUrl,omitempty"` + Cause string `json:"cause,omitempty"` +} + +// Skill represents a skill with optional positions and endorsements +type Skill struct { + Name string `json:"name,omitempty"` + Positions []string `json:"positions,omitempty"` + Endorsements string `json:"endorsements,omitempty"` +} + +// Course represents a course +type Course struct { + Title string `json:"title,omitempty"` + AssociatedWith string `json:"associatedWith,omitempty"` + AssociatedWithLink string `json:"associatedWithLink,omitempty"` +} + +// Publication represents a publication +type Publication struct { + Title string `json:"title,omitempty"` + PublishedAt string `json:"publishedAt,omitempty"` + Link string `json:"link,omitempty"` +} + +// HonorAndAward represents an honor or award +type HonorAndAward struct { + Title string `json:"title,omitempty"` + IssuedBy string `json:"issuedBy,omitempty"` + IssuedAt string `json:"issuedAt,omitempty"` + Description string `json:"description,omitempty"` + AssociatedWith string `json:"associatedWith,omitempty"` + AssociatedWithLink string `json:"associatedWithLink,omitempty"` +} + +// Language represents a language with proficiency level +type Language struct { + Name string `json:"name,omitempty"` + Proficiency string `json:"proficiency,omitempty"` +} + +// MoreProfile represents a related profile +type MoreProfile struct { + ID string `json:"id,omitempty"` + FirstName string `json:"firstName,omitempty"` + LastName string `json:"lastName,omitempty"` + Position *string `json:"position,omitempty"` + PublicIdentifier string `json:"publicIdentifier,omitempty"` + URL string `json:"linkedinUrl,omitempty"` +} + +// ShortCurrentPosition represents the short profile current positions array +type ShortCurrentPosition struct { + TenureAtPosition *Tenure `json:"tenureAtPosition,omitempty"` + CompanyName string `json:"companyName,omitempty"` + Title *string `json:"title,omitempty"` + Current 
*bool `json:"current,omitempty"` + TenureAtCompany *Tenure `json:"tenureAtCompany,omitempty"` + StartedOn *StartedOn `json:"startedOn,omitempty"` + CompanyID *string `json:"companyId,omitempty"` + CompanyLinkedinUrl *string `json:"companyLinkedinUrl,omitempty"` +} + +type Tenure struct { + NumYears *int `json:"numYears,omitempty"` + NumMonths *int `json:"numMonths,omitempty"` +} + +type StartedOn struct { + Month int `json:"month,omitempty"` + Year int `json:"year,omitempty"` +} + +// DatePeriod represents a date period with optional start and end parts +type DatePeriod struct { + Start *DateParts `json:"start,omitempty"` + End *DateParts `json:"end,omitempty"` +} + +type DateParts struct { + Month *int `json:"month,omitempty"` + Year *int `json:"year,omitempty"` + Day *int `json:"day,omitempty"` +} + +// CompanyWebsite represents company website with validation hint +type CompanyWebsite struct { + URL string `json:"url,omitempty"` + Domain string `json:"domain,omitempty"` + ValidEmailServer *bool `json:"validEmailServer,omitempty"` +} + +// Recommendation captures received recommendations +type Recommendation struct { + GivenBy *string `json:"givenBy,omitempty"` + GivenByLink *string `json:"givenByLink,omitempty"` + GivenAt *string `json:"givenAt,omitempty"` + Description string `json:"description,omitempty"` +} + +// Patent represents a patent entry +type Patent struct { + Title string `json:"title,omitempty"` + Number *string `json:"number,omitempty"` + IssuedAt string `json:"issuedAt,omitempty"` +} diff --git a/api/types/linkedin/seniorities/seniorities.go b/api/types/linkedin/seniorities/seniorities.go new file mode 100644 index 00000000..4382c51b --- /dev/null +++ b/api/types/linkedin/seniorities/seniorities.go @@ -0,0 +1,61 @@ +package seniorities + +import "github.com/masa-finance/tee-worker/pkg/util" + +// id represents a LinkedIn seniority level identifier +type Id string + +// Seniority level constants +const ( + InTraining Id = "100" + EntryLevel Id = "110" + Senior Id = "120" + Strategic Id = "130" + EntryLevelManager Id = "200" + ExperiencedManager Id = "210" + Director Id = "220" + VicePresident Id = "300" + CXO Id = "310" + Partner Id = "320" +) + +var All = util.NewSet( + InTraining, + EntryLevel, + Senior, + Strategic, + EntryLevelManager, + ExperiencedManager, + Director, + VicePresident, + CXO, + Partner, +) + +type SenioritiesConfig struct { + All util.Set[Id] + InTraining Id + EntryLevel Id + Senior Id + Strategic Id + EntryLevelManager Id + ExperiencedManager Id + Director Id + VicePresident Id + CXO Id + Partner Id +} + +var Seniorities = SenioritiesConfig{ + All: *All, + InTraining: InTraining, + EntryLevel: EntryLevel, + Senior: Senior, + Strategic: Strategic, + EntryLevelManager: EntryLevelManager, + ExperiencedManager: ExperiencedManager, + Director: Director, + VicePresident: VicePresident, + CXO: CXO, + Partner: Partner, +} diff --git a/api/types/llm.go b/api/types/llm.go new file mode 100644 index 00000000..ca990759 --- /dev/null +++ b/api/types/llm.go @@ -0,0 +1,15 @@ +package types + +type LLMProcessorRequest struct { + InputDatasetId string `json:"inputDatasetId"` + LLMProviderApiKey string `json:"llmProviderApiKey"` // encrypted api key by miner + Model string `json:"model"` + MultipleColumns bool `json:"multipleColumns"` + Prompt string `json:"prompt"` // example: summarize the content of this webpage: ${markdown} + Temperature string `json:"temperature"` // the actor expects a string + MaxTokens uint `json:"maxTokens"` +} + +type LLMProcessorResult 
struct { + LLMResponse string `json:"llmresponse"` +} diff --git a/api/types/reddit.go b/api/types/reddit.go new file mode 100644 index 00000000..dec89e4c --- /dev/null +++ b/api/types/reddit.go @@ -0,0 +1,197 @@ +package types + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/masa-finance/tee-worker/pkg/util" +) + +var AllRedditQueryTypes = util.NewSet(CapScrapeUrls, CapSearchPosts, CapSearchUsers, CapSearchCommunities) + +type RedditSortType string + +const ( + RedditSortRelevance RedditSortType = "relevance" + RedditSortHot RedditSortType = "hot" + RedditSortTop RedditSortType = "top" + RedditSortNew RedditSortType = "new" + RedditSortRising RedditSortType = "rising" + RedditSortComments RedditSortType = "comments" +) + +var AllRedditSortTypes = util.NewSet( + RedditSortRelevance, + RedditSortHot, + RedditSortTop, + RedditSortNew, + RedditSortRising, + RedditSortComments, +) + +// RedditStartURL represents a single start URL for the Apify Reddit scraper. +type RedditStartURL struct { + URL string `json:"url"` + Method string `json:"method"` +} + +type RedditItemType string + +const ( + RedditUserItem RedditItemType = "user" + RedditPostItem RedditItemType = "post" + RedditCommentItem RedditItemType = "comment" + RedditCommunityItem RedditItemType = "community" +) + +// RedditUser represents the data structure for a Reddit user from the Apify scraper. +type RedditUser struct { + ID string `json:"id"` + URL string `json:"url"` + Username string `json:"username"` + UserIcon string `json:"userIcon"` + PostKarma int `json:"postKarma"` + CommentKarma int `json:"commentKarma"` + Description string `json:"description"` + Over18 bool `json:"over18"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + DataType string `json:"dataType"` +} + +// RedditPost represents the data structure for a Reddit post from the Apify scraper. +type RedditPost struct { + ID string `json:"id"` + ParsedID string `json:"parsedId"` + URL string `json:"url"` + Username string `json:"username"` + Title string `json:"title"` + CommunityName string `json:"communityName"` + ParsedCommunityName string `json:"parsedCommunityName"` + Body string `json:"body"` + HTML *string `json:"html"` + NumberOfComments int `json:"numberOfComments"` + UpVotes int `json:"upVotes"` + IsVideo bool `json:"isVideo"` + IsAd bool `json:"isAd"` + Over18 bool `json:"over18"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + DataType string `json:"dataType"` +} + +// RedditComment represents the data structure for a Reddit comment from the Apify scraper. +type RedditComment struct { + ID string `json:"id"` + ParsedID string `json:"parsedId"` + URL string `json:"url"` + ParentID string `json:"parentId"` + Username string `json:"username"` + Category string `json:"category"` + CommunityName string `json:"communityName"` + Body string `json:"body"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + UpVotes int `json:"upVotes"` + NumberOfReplies int `json:"numberOfreplies"` + HTML string `json:"html"` + DataType string `json:"dataType"` +} + +// RedditCommunity represents the data structure for a Reddit community from the Apify scraper. 
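
The sort and start-URL types above are the raw inputs to the Apify actor; a sketch of guarding caller-supplied values with the exported sets (the subreddit URL is illustrative):

    package main

    import (
    	"fmt"

    	"github.com/masa-finance/tee-worker/api/types"
    )

    func main() {
    	sort := types.RedditSortType("top") // e.g. taken from an incoming job argument
    	if !types.AllRedditSortTypes.Contains(sort) {
    		panic("unsupported reddit sort: " + string(sort))
    	}

    	// Query types reuse the capability constants from jobs.go.
    	if !types.AllRedditQueryTypes.Contains(types.CapSearchPosts) {
    		panic("unreachable")
    	}

    	start := types.RedditStartURL{URL: "https://www.reddit.com/r/golang/", Method: "GET"}
    	fmt.Println(sort, start.URL)
    }
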
+type RedditCommunity struct { + ID string `json:"id"` + Name string `json:"name"` + Title string `json:"title"` + HeaderImage string `json:"headerImage"` + Description string `json:"description"` + Over18 bool `json:"over18"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + NumberOfMembers int `json:"numberOfMembers"` + URL string `json:"url"` + DataType string `json:"dataType"` +} + +// RedditResponse represents a Reddit API response that can be any of the Reddit item types +type RedditResponse struct { + Type RedditItemType `json:"type"` + User *RedditUser `json:"user,omitempty"` + Post *RedditPost `json:"post,omitempty"` + Comment *RedditComment `json:"comment,omitempty"` + Community *RedditCommunity `json:"community,omitempty"` +} + +// UnmarshalJSON implements custom JSON unmarshaling for RedditResponse +func (r *RedditResponse) UnmarshalJSON(data []byte) error { + // First, unmarshal into a map to get the type + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Get the type field (check both 'type' and 'dataType' for compatibility) + var itemType RedditItemType + if typeData, exists := raw["type"]; exists { + if err := json.Unmarshal(typeData, &itemType); err != nil { + return fmt.Errorf("failed to unmarshal reddit response type: %w", err) + } + } else if typeData, exists := raw["dataType"]; exists { + if err := json.Unmarshal(typeData, &itemType); err != nil { + return fmt.Errorf("failed to unmarshal reddit response dataType: %w", err) + } + } else { + return fmt.Errorf("missing 'type' or 'dataType' field in reddit response") + } + + r.Type = itemType + + // Unmarshal the appropriate struct based on type + switch itemType { + case RedditUserItem: + r.User = &RedditUser{} + if err := json.Unmarshal(data, r.User); err != nil { + return fmt.Errorf("failed to unmarshal reddit user: %w", err) + } + case RedditPostItem: + r.Post = &RedditPost{} + if err := json.Unmarshal(data, r.Post); err != nil { + return fmt.Errorf("failed to unmarshal reddit post: %w", err) + } + case RedditCommentItem: + r.Comment = &RedditComment{} + if err := json.Unmarshal(data, r.Comment); err != nil { + return fmt.Errorf("failed to unmarshal reddit comment: %w", err) + } + case RedditCommunityItem: + r.Community = &RedditCommunity{} + if err := json.Unmarshal(data, r.Community); err != nil { + return fmt.Errorf("failed to unmarshal reddit community: %w", err) + } + default: + return fmt.Errorf("unknown Reddit response type: %s", itemType) + } + + return nil +} + +// MarshalJSON implements the json.Marshaler interface for RedditResponse. +// It unwraps the inner struct (User, Post, Comment, or Community) and marshals it directly. 
+func (r *RedditResponse) MarshalJSON() ([]byte, error) { + switch r.Type { + case RedditUserItem: + return json.Marshal(r.User) + case RedditPostItem: + return json.Marshal(r.Post) + case RedditCommentItem: + return json.Marshal(r.Comment) + case RedditCommunityItem: + return json.Marshal(r.Community) + default: + return nil, fmt.Errorf("unknown Reddit response type: %s", r.Type) + } +} + +// RedditItem is an alias for RedditResponse for backward compatibility +type RedditItem = RedditResponse diff --git a/api/types/reddit/reddit.go b/api/types/reddit/reddit.go deleted file mode 100644 index 6a5f7e7e..00000000 --- a/api/types/reddit/reddit.go +++ /dev/null @@ -1,152 +0,0 @@ -package reddit - -import ( - "encoding/json" - "fmt" - "time" -) - -// TODO: These are duplicated here and in tee-types/types/reddit.go -type ResponseType string - -const ( - UserResponse ResponseType = "user" - PostResponse ResponseType = "post" - CommentResponse ResponseType = "comment" - CommunityResponse ResponseType = "community" -) - -// User represents the data structure for a Reddit user from the Apify scraper. -type User struct { - ID string `json:"id"` - URL string `json:"url"` - Username string `json:"username"` - UserIcon string `json:"userIcon"` - PostKarma int `json:"postKarma"` - CommentKarma int `json:"commentKarma"` - Description string `json:"description"` - Over18 bool `json:"over18"` - CreatedAt time.Time `json:"createdAt"` - ScrapedAt time.Time `json:"scrapedAt"` - DataType string `json:"dataType"` -} - -// Post represents the data structure for a Reddit post from the Apify scraper. -type Post struct { - ID string `json:"id"` - ParsedID string `json:"parsedId"` - URL string `json:"url"` - Username string `json:"username"` - Title string `json:"title"` - CommunityName string `json:"communityName"` - ParsedCommunityName string `json:"parsedCommunityName"` - Body string `json:"body"` - HTML *string `json:"html"` - NumberOfComments int `json:"numberOfComments"` - UpVotes int `json:"upVotes"` - IsVideo bool `json:"isVideo"` - IsAd bool `json:"isAd"` - Over18 bool `json:"over18"` - CreatedAt time.Time `json:"createdAt"` - ScrapedAt time.Time `json:"scrapedAt"` - DataType string `json:"dataType"` -} - -// Comment represents the data structure for a Reddit comment from the Apify scraper. -type Comment struct { - ID string `json:"id"` - ParsedID string `json:"parsedId"` - URL string `json:"url"` - ParentID string `json:"parentId"` - Username string `json:"username"` - Category string `json:"category"` - CommunityName string `json:"communityName"` - Body string `json:"body"` - CreatedAt time.Time `json:"createdAt"` - ScrapedAt time.Time `json:"scrapedAt"` - UpVotes int `json:"upVotes"` - NumberOfReplies int `json:"numberOfreplies"` - HTML string `json:"html"` - DataType string `json:"dataType"` -} - -// Community represents the data structure for a Reddit community from the Apify scraper. 
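To make the pairing of the two methods above concrete: UnmarshalJSON dispatches on the type/dataType discriminator into the matching inner struct, while MarshalJSON re-emits only that inner struct and drops the envelope. A small round-trip sketch with an invented payload:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/masa-finance/tee-worker/api/types"
)

func main() {
	// Apify items carry a dataType discriminator; the unmarshaller also accepts "type".
	payload := []byte(`{"dataType": "post", "id": "p1", "title": "Example Post"}`)

	var resp types.RedditResponse
	if err := json.Unmarshal(payload, &resp); err != nil {
		panic(err)
	}
	fmt.Println(resp.Type, resp.Post.Title) // post Example Post

	// Marshalling unwraps the inner RedditPost, so the wrapper's Type field
	// is not re-emitted in the output JSON.
	out, err := json.Marshal(&resp)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}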
-type Community struct { - ID string `json:"id"` - Name string `json:"name"` - Title string `json:"title"` - HeaderImage string `json:"headerImage"` - Description string `json:"description"` - Over18 bool `json:"over18"` - CreatedAt time.Time `json:"createdAt"` - ScrapedAt time.Time `json:"scrapedAt"` - NumberOfMembers int `json:"numberOfMembers"` - URL string `json:"url"` - DataType string `json:"dataType"` -} - -type TypeSwitch struct { - Type ResponseType `json:"dataType"` -} - -type Response struct { - TypeSwitch *TypeSwitch - User *User - Post *Post - Comment *Comment - Community *Community -} - -func (t *Response) UnmarshalJSON(data []byte) error { - t.TypeSwitch = &TypeSwitch{} - if err := json.Unmarshal(data, &t.TypeSwitch); err != nil { - return fmt.Errorf("failed to unmarshal reddit response type: %w", err) - } - - switch t.TypeSwitch.Type { - case UserResponse: - t.User = &User{} - if err := json.Unmarshal(data, t.User); err != nil { - return fmt.Errorf("failed to unmarshal reddit user: %w", err) - } - case PostResponse: - t.Post = &Post{} - if err := json.Unmarshal(data, t.Post); err != nil { - return fmt.Errorf("failed to unmarshal reddit post: %w", err) - } - case CommentResponse: - t.Comment = &Comment{} - if err := json.Unmarshal(data, t.Comment); err != nil { - return fmt.Errorf("failed to unmarshal reddit comment: %w", err) - } - case CommunityResponse: - t.Community = &Community{} - if err := json.Unmarshal(data, t.Community); err != nil { - return fmt.Errorf("failed to unmarshal reddit community: %w", err) - } - default: - return fmt.Errorf("unknown Reddit response type during unmarshal: %s", t.TypeSwitch.Type) - } - return nil -} - -// MarshalJSON implements the json.Marshaler interface for Response. -// It unwraps the inner struct (User, Post, Comment, or Community) and marshals it directly. -func (t *Response) MarshalJSON() ([]byte, error) { - if t.TypeSwitch == nil { - return []byte("null"), nil - } - - switch t.TypeSwitch.Type { - case UserResponse: - return json.Marshal(t.User) - case PostResponse: - return json.Marshal(t.Post) - case CommentResponse: - return json.Marshal(t.Comment) - case CommunityResponse: - return json.Marshal(t.Community) - default: - return nil, fmt.Errorf("unknown Reddit response type during marshal: %s", t.TypeSwitch.Type) - } -} diff --git a/api/types/reddit/reddit_test.go b/api/types/reddit/reddit_test.go deleted file mode 100644 index d12319e4..00000000 --- a/api/types/reddit/reddit_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package reddit_test - -import ( - "encoding/json" - "time" - - . "github.com/onsi/ginkgo/v2" - . 
"github.com/onsi/gomega" - - "github.com/masa-finance/tee-worker/api/types/reddit" -) - -var _ = Describe("Response", func() { - Context("Unmarshalling JSON", func() { - It("should correctly unmarshal a UserResponse", func() { - jsonData := `{"dataType": "user", "id": "u1", "username": "testuser"}` - var resp reddit.Response - err := json.Unmarshal([]byte(jsonData), &resp) - Expect(err).NotTo(HaveOccurred()) - Expect(resp.User).NotTo(BeNil()) - Expect(resp.User.ID).To(Equal("u1")) - Expect(resp.User.Username).To(Equal("testuser")) - Expect(resp.Post).To(BeNil()) - Expect(resp.Comment).To(BeNil()) - Expect(resp.Community).To(BeNil()) - }) - - It("should correctly unmarshal a PostResponse", func() { - jsonData := `{"dataType": "post", "id": "p1", "title": "Test Post"}` - var resp reddit.Response - err := json.Unmarshal([]byte(jsonData), &resp) - Expect(err).NotTo(HaveOccurred()) - Expect(resp.Post).NotTo(BeNil()) - Expect(resp.Post.ID).To(Equal("p1")) - Expect(resp.Post.Title).To(Equal("Test Post")) - Expect(resp.User).To(BeNil()) - Expect(resp.Comment).To(BeNil()) - Expect(resp.Community).To(BeNil()) - }) - - It("should correctly unmarshal a CommentResponse", func() { - jsonData := `{"dataType": "comment", "id": "c1", "body": "Test Comment"}` - var resp reddit.Response - err := json.Unmarshal([]byte(jsonData), &resp) - Expect(err).NotTo(HaveOccurred()) - Expect(resp.Comment).NotTo(BeNil()) - Expect(resp.Comment.ID).To(Equal("c1")) - Expect(resp.Comment.Body).To(Equal("Test Comment")) - Expect(resp.User).To(BeNil()) - Expect(resp.Post).To(BeNil()) - Expect(resp.Community).To(BeNil()) - }) - - It("should correctly unmarshal a CommunityResponse", func() { - jsonData := `{"dataType": "community", "id": "co1", "name": "Test Community"}` - var resp reddit.Response - err := json.Unmarshal([]byte(jsonData), &resp) - Expect(err).NotTo(HaveOccurred()) - Expect(resp.Community).NotTo(BeNil()) - Expect(resp.Community.ID).To(Equal("co1")) - Expect(resp.Community.Name).To(Equal("Test Community")) - Expect(resp.User).To(BeNil()) - Expect(resp.Post).To(BeNil()) - Expect(resp.Comment).To(BeNil()) - }) - - It("should return an error for an unknown type", func() { - jsonData := `{"dataType": "unknown", "id": "u1"}` - var resp reddit.Response - err := json.Unmarshal([]byte(jsonData), &resp) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("unknown Reddit response type during unmarshal: unknown")) - }) - - It("should return an error for invalid JSON", func() { - jsonData := `{"type": "user", "id": "u1"` - var resp reddit.Response - err := json.Unmarshal([]byte(jsonData), &resp) - Expect(err).To(HaveOccurred()) - }) - }) - - Context("Marshalling JSON", func() { - It("should correctly marshal a UserResponse", func() { - resp := reddit.Response{ - TypeSwitch: &reddit.TypeSwitch{Type: reddit.UserResponse}, - User: &reddit.User{ID: "u1", Username: "testuser", DataType: "user"}, - } - jsonData, err := json.Marshal(&resp) - Expect(err).NotTo(HaveOccurred()) - expectedJSON := `{"id":"u1","url":"","username":"testuser","userIcon":"","postKarma":0,"commentKarma":0,"description":"","over18":false,"createdAt":"0001-01-01T00:00:00Z","scrapedAt":"0001-01-01T00:00:00Z","dataType":"user"}` - Expect(jsonData).To(MatchJSON(expectedJSON)) - }) - - It("should correctly marshal a PostResponse", func() { - resp := reddit.Response{ - TypeSwitch: &reddit.TypeSwitch{Type: reddit.PostResponse}, - Post: &reddit.Post{ID: "p1", Title: "Test Post", DataType: "post"}, - } - jsonData, err := json.Marshal(&resp) - 
Expect(err).NotTo(HaveOccurred()) - expectedJSON := `{"id":"p1","parsedId":"","url":"","username":"","title":"Test Post","communityName":"","parsedCommunityName":"","body":"","html":null,"numberOfComments":0,"upVotes":0,"isVideo":false,"isAd":false,"over18":false,"createdAt":"0001-01-01T00:00:00Z","scrapedAt":"0001-01-01T00:00:00Z","dataType":"post"}` - Expect(jsonData).To(MatchJSON(expectedJSON)) - }) - - It("should correctly marshal a CommentResponse", func() { - now := time.Now().UTC() - resp := reddit.Response{ - TypeSwitch: &reddit.TypeSwitch{Type: reddit.CommentResponse}, - Comment: &reddit.Comment{ID: "c1", Body: "Test Comment", CreatedAt: now, ScrapedAt: now, DataType: "comment"}, - } - jsonData, err := json.Marshal(&resp) - Expect(err).NotTo(HaveOccurred()) - - expectedComment := &reddit.Comment{ID: "c1", Body: "Test Comment", CreatedAt: now, ScrapedAt: now, DataType: "comment"} - expectedJSON, _ := json.Marshal(expectedComment) - Expect(jsonData).To(MatchJSON(expectedJSON)) - }) - - It("should correctly marshal a CommunityResponse", func() { - now := time.Now().UTC() - resp := reddit.Response{ - TypeSwitch: &reddit.TypeSwitch{Type: reddit.CommunityResponse}, - Community: &reddit.Community{ID: "co1", Name: "Test Community", CreatedAt: now, ScrapedAt: now, DataType: "community"}, - } - jsonData, err := json.Marshal(&resp) - Expect(err).NotTo(HaveOccurred()) - - expectedCommunity := &reddit.Community{ID: "co1", Name: "Test Community", CreatedAt: now, ScrapedAt: now, DataType: "community"} - expectedJSON, _ := json.Marshal(expectedCommunity) - Expect(jsonData).To(MatchJSON(expectedJSON)) - }) - - It("should return an error for an unknown type", func() { - resp := reddit.Response{ - TypeSwitch: &reddit.TypeSwitch{Type: "unknown"}, - } - _, err := json.Marshal(&resp) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("unknown Reddit response type during marshal: unknown")) - }) - - It("should marshal to null if TypeSwitch is nil", func() { - resp := reddit.Response{} - jsonData, err := json.Marshal(&resp) - Expect(err).NotTo(HaveOccurred()) - Expect(string(jsonData)).To(Equal("null")) - }) - }) -}) diff --git a/api/types/reddit_test.go b/api/types/reddit_test.go new file mode 100644 index 00000000..d31934a5 --- /dev/null +++ b/api/types/reddit_test.go @@ -0,0 +1,93 @@ +package types_test + +import ( + "encoding/json" + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/api/types" +) + +var _ = Describe("RedditResponse", func() { + Describe("Unmarshalling", func() { + It("should unmarshal a user response", func() { + jsonData := `{"type": "user", "id": "user123", "username": "testuser"}` + var resp types.RedditResponse + err := json.Unmarshal([]byte(jsonData), &resp) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.User).ToNot(BeNil()) + Expect(resp.Post).To(BeNil()) + Expect(resp.User.ID).To(Equal("user123")) + Expect(resp.User.Username).To(Equal("testuser")) + }) + + It("should unmarshal a post response", func() { + jsonData := `{"type": "post", "id": "post123", "title": "Test Post"}` + var resp types.RedditResponse + err := json.Unmarshal([]byte(jsonData), &resp) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Post).ToNot(BeNil()) + Expect(resp.User).To(BeNil()) + Expect(resp.Post.ID).To(Equal("post123")) + Expect(resp.Post.Title).To(Equal("Test Post")) + }) + + It("should return an error for an unknown type", func() { + jsonData := `{"type": "unknown", "id": "123"}` + var resp types.RedditResponse + err := json.Unmarshal([]byte(jsonData), &resp) + Expect(err).To(MatchError("unknown Reddit response type: unknown")) + }) + }) + + Describe("Marshalling", func() { + It("should marshal a user response", func() { + now := time.Now() + resp := types.RedditResponse{ + Type: types.RedditUserItem, + User: &types.RedditUser{ + ID: "user123", + Username: "testuser", + CreatedAt: now, + CommentKarma: 10, + }, + } + + expectedJSON, err := json.Marshal(resp.User) + Expect(err).ToNot(HaveOccurred()) + + actualJSON, err := json.Marshal(&resp) + Expect(err).ToNot(HaveOccurred()) + + Expect(actualJSON).To(MatchJSON(expectedJSON)) + }) + + It("should marshal a post response", func() { + resp := types.RedditResponse{ + Type: types.RedditPostItem, + Post: &types.RedditPost{ + ID: "post123", + Title: "Test Post", + }, + } + + expectedJSON, err := json.Marshal(resp.Post) + Expect(err).ToNot(HaveOccurred()) + + actualJSON, err := json.Marshal(&resp) + Expect(err).ToNot(HaveOccurred()) + + Expect(actualJSON).To(MatchJSON(expectedJSON)) + }) + + It("should return an error for an unknown type", func() { + resp := types.RedditResponse{ + Type: "unknown", + } + _, err := json.Marshal(&resp) + Expect(err).To(HaveOccurred()) + }) + }) +}) diff --git a/api/types/tiktok.go b/api/types/tiktok.go new file mode 100644 index 00000000..ed657b7a --- /dev/null +++ b/api/types/tiktok.go @@ -0,0 +1,174 @@ +// Package types provides shared types between tee-worker and tee-indexer +package types + +// TikTokTranscriptionResult defines the structure of the result data for a TikTok transcription +type TikTokTranscriptionResult struct { + TranscriptionText string `json:"transcription_text"` + DetectedLanguage string `json:"detected_language,omitempty"` + VideoTitle string `json:"video_title,omitempty"` + OriginalURL string `json:"original_url"` + ThumbnailURL string `json:"thumbnail_url,omitempty"` +} + +type TikTokSearchByQueryResult struct { + URL string `json:"url"` + ID string `json:"id"` + Desc string `json:"desc"` + CreateTime string `json:"createTime"` + ScheduleTime int64 `json:"scheduleTime"` + Video TikTokVideo `json:"video"` + Author string `json:"author"` + Music TikTokMusic `json:"music"` + Challenges []any `json:"challenges"` // we don't have examples of this data yet... 
+ Stats TikTokStats `json:"stats"` + IsActivityItem bool `json:"isActivityItem"` + DuetInfo TikTokDuetInfo `json:"duetInfo"` + WarnInfo []any `json:"warnInfo"` // we don't have examples of this data yet... + OriginalItem bool `json:"originalItem"` + OfficialItem bool `json:"officalItem"` + TextExtra []TikTokTextExtra `json:"textExtra"` + Secret bool `json:"secret"` + ForFriend bool `json:"forFriend"` + Digged bool `json:"digged"` + ItemCommentStatus int `json:"itemCommentStatus"` + ShowNotPass bool `json:"showNotPass"` + VL1 bool `json:"vl1"` + TakeDown int `json:"takeDown"` + ItemMute bool `json:"itemMute"` + EffectStickers []any `json:"effectStickers"` // we don't have examples of this data yet... + AuthorStats TikTokAuthorStats `json:"authorStats"` + PrivateItem bool `json:"privateItem"` + DuetEnabled bool `json:"duetEnabled"` + StitchEnabled bool `json:"stitchEnabled"` + StickersOnItem []any `json:"stickersOnItem"` // we don't have examples of this data yet... + IsAd bool `json:"isAd"` + ShareEnabled bool `json:"shareEnabled"` + Comments []any `json:"comments"` // we don't have examples of this data yet... + DuetDisplay int `json:"duetDisplay"` + StitchDisplay int `json:"stitchDisplay"` + IndexEnabled bool `json:"indexEnabled"` + DiversificationLabels []string `json:"diversificationLabels"` + AdAuthorization bool `json:"adAuthorization"` + AdLabelVersion int `json:"adLabelVersion"` + LocationCreated string `json:"locationCreated"` + Nickname string `json:"nickname"` + AuthorID string `json:"authorId"` + AuthorSecID string `json:"authorSecId"` + AvatarThumb string `json:"avatarThumb"` + DownloadSetting int `json:"downloadSetting"` + AuthorPrivate bool `json:"authorPrivate"` +} + +type TikTokSearchByTrending struct { + CountryCode string `json:"country_code"` + Cover string `json:"cover"` + Duration int `json:"duration"` + ID string `json:"id"` + ItemID string `json:"item_id"` + ItemURL string `json:"item_url"` + Region string `json:"region"` + Title string `json:"title"` +} + +type TikTokVideo struct { + ID string `json:"id"` + Height int `json:"height"` + Width int `json:"width"` + Duration int `json:"duration"` + Ratio string `json:"ratio"` + Cover string `json:"cover"` + OriginCover string `json:"origin_cover"` + DynamicCover string `json:"dynamic_cover"` + PlayAddr string `json:"play_addr"` + DownloadAddr string `json:"download_addr"` + ShareCover []string `json:"share_cover"` + ReflowCover string `json:"reflowCover"` + Bitrate int `json:"bitrate"` + EncodedType string `json:"encodedType"` + Format string `json:"format"` + VideoQuality string `json:"videoQuality"` + EncodeUserTag string `json:"encodeUserTag"` + CodecType string `json:"codecType"` + Definition string `json:"definition"` + SubtitleInfos []any `json:"subtitleInfos"` // we don't have examples of this data yet... 
+ ZoomCover TikTokZoomCover `json:"zoomCover"` + VolumeInfo TikTokVolumeInfo `json:"volumeInfo"` + BitrateInfo []TikTokBitrateInfo `json:"bitrateInfo"` +} + +type TikTokZoomCover struct { + Cover240 string `json:"240"` + Cover480 string `json:"480"` + Cover720 string `json:"720"` + Cover960 string `json:"960"` +} + +type TikTokVolumeInfo struct { + Loudness float64 `json:"Loudness"` + Peak float64 `json:"Peak"` +} + +type TikTokBitrateInfo struct { + GearName string `json:"GearName"` + Bitrate int `json:"bitrate"` + QualityType int `json:"QualityType"` + PlayAddr TikTokPlayAddr `json:"PlayAddr"` + CodecType string `json:"CodecType"` +} + +type TikTokPlayAddr struct { + Uri string `json:"Uri"` + UrlList []string `json:"UrlList"` + DataSize string `json:"DataSize"` + UrlKey string `json:"UrlKey"` + FileHash string `json:"FileHash"` + FileCs string `json:"FileCs"` +} + +type TikTokMusic struct { + ID string `json:"id"` + Title string `json:"title"` + PlayURL string `json:"playUrl"` + CoverLarge string `json:"coverLarge"` + CoverMedium string `json:"coverMedium"` + CoverThumb string `json:"coverThumb"` + AuthorName string `json:"authorName"` + Original bool `json:"original"` + Duration int `json:"duration"` + Album string `json:"album"` + ScheduleSearchTime int64 `json:"scheduleSearchTime"` +} + +type TikTokStats struct { + DiggCount int64 `json:"diggCount"` + ShareCount int64 `json:"shareCount"` + CommentCount int64 `json:"commentCount"` + PlayCount int64 `json:"playCount"` +} + +type TikTokDuetInfo struct { + DuetFromID string `json:"duetFromId"` +} + +type TikTokTextExtra struct { + AwemeID string `json:"awemeId"` + Start int `json:"start"` + End int `json:"end"` + HashtagID string `json:"hashtagId"` + HashtagName string `json:"hashtagName"` + Type int `json:"type"` + SubType int `json:"subType"` + UserID string `json:"userId"` + IsCommerce bool `json:"isCommerce"` + UserUniqueID string `json:"userUniqueId"` + SecUID string `json:"secUid"` +} + +type TikTokAuthorStats struct { + FollowerCount int64 `json:"followerCount"` + FollowingCount int64 `json:"followingCount"` + Heart int64 `json:"heart"` + HeartCount int64 `json:"heartCount"` + VideoCount int64 `json:"videoCount"` + DiggCount int64 `json:"diggCount"` +} diff --git a/api/types/twitter.go b/api/types/twitter.go new file mode 100644 index 00000000..e699fbd8 --- /dev/null +++ b/api/types/twitter.go @@ -0,0 +1,187 @@ +// Package types provides shared types between tee-worker and tee-indexer +package types + +import "time" + +type TweetResult struct { + ID int64 `json:"id"` + TweetID string `json:"tweet_id"` + ConversationID string `json:"conversation_id"` + UserID string `json:"user_id"` + Text string `json:"text"` + CreatedAt time.Time `json:"created_at"` + Timestamp int64 `json:"timestamp"` + + ThreadCursor struct { + FocalTweetID string `json:"focal_tweet_id"` + ThreadID string `json:"thread_id"` + Cursor string `json:"cursor"` + CursorType string `json:"cursor_type"` + } + IsQuoted bool `json:"is_quoted"` + IsPin bool `json:"is_pin"` + IsReply bool `json:"is_reply"` + IsRetweet bool `json:"is_retweet"` + IsSelfThread bool `json:"is_self_thread"` + Likes int `json:"likes"` + Hashtags []string `json:"hashtags"` + HTML string `json:"html"` + Replies int `json:"replies"` + Retweets int `json:"retweets"` + URLs []string `json:"urls"` + Username string `json:"username"` + + Photos []Photo `json:"photos"` + + // Video type. 
+ Videos []Video `json:"videos"` + + RetweetedStatusID string `json:"retweeted_status_id"` + Views int `json:"views"` + SensitiveContent bool `json:"sensitive_content"` + + // from twitterx + AuthorID string `json:"author_id"` + PublicMetrics PublicMetrics `json:"public_metrics"` + PossiblySensitive bool `json:"possibly_sensitive"` + Lang string `json:"lang"` + NewestID string `json:"newest_id"` + OldestID string `json:"oldest_id"` + ResultCount int `json:"result_count"` + + Error error `json:"error"` +} + +type PublicMetrics struct { + RetweetCount int `json:"retweet_count"` + ReplyCount int `json:"reply_count"` + LikeCount int `json:"like_count"` + QuoteCount int `json:"quote_count"` + BookmarkCount int `json:"bookmark_count"` + ImpressionCount int `json:"impression_count"` +} +type Photo struct { + ID string `json:"id"` + URL string `json:"url"` +} + +type Video struct { + ID string `json:"id"` + Preview string `json:"preview"` + URL string `json:"url"` + HLSURL string `json:"hls_url"` +} + +type ProfileResultApify struct { + ID int64 `json:"id"` + IDStr string `json:"id_str"` + Name string `json:"name"` + ScreenName string `json:"screen_name"` + Location string `json:"location"` + Description string `json:"description"` + URL *string `json:"url"` + Entities ProfileEntities `json:"entities"` + Protected bool `json:"protected"` + FollowersCount int `json:"followers_count"` + FastFollowersCount int `json:"fast_followers_count"` + NormalFollowersCount int `json:"normal_followers_count"` + FriendsCount int `json:"friends_count"` + ListedCount int `json:"listed_count"` + CreatedAt string `json:"created_at"` + FavouritesCount int `json:"favourites_count"` + UTCOffset *int `json:"utc_offset"` + TimeZone *string `json:"time_zone"` + GeoEnabled bool `json:"geo_enabled"` + Verified bool `json:"verified"` + StatusesCount int `json:"statuses_count"` + MediaCount int `json:"media_count"` + Lang *string `json:"lang"` + ContributorsEnabled bool `json:"contributors_enabled"` + IsTranslator bool `json:"is_translator"` + IsTranslationEnabled bool `json:"is_translation_enabled"` + ProfileBackgroundColor string `json:"profile_background_color"` + ProfileBackgroundImageURL *string `json:"profile_background_image_url"` + ProfileBackgroundImageURLHTTPS *string `json:"profile_background_image_url_https"` + ProfileBackgroundTile bool `json:"profile_background_tile"` + ProfileImageURL string `json:"profile_image_url"` + ProfileImageURLHTTPS string `json:"profile_image_url_https"` + ProfileLinkColor string `json:"profile_link_color"` + ProfileSidebarBorderColor string `json:"profile_sidebar_border_color"` + ProfileSidebarFillColor string `json:"profile_sidebar_fill_color"` + ProfileTextColor string `json:"profile_text_color"` + ProfileUseBackgroundImage bool `json:"profile_use_background_image"` + HasExtendedProfile bool `json:"has_extended_profile"` + DefaultProfile bool `json:"default_profile"` + DefaultProfileImage bool `json:"default_profile_image"` + PinnedTweetIDs []int64 `json:"pinned_tweet_ids"` + PinnedTweetIDsStr []string `json:"pinned_tweet_ids_str"` + HasCustomTimelines bool `json:"has_custom_timelines"` + CanMediaTag bool `json:"can_media_tag"` + FollowedBy bool `json:"followed_by"` + Following bool `json:"following"` + LiveFollowing bool `json:"live_following"` + FollowRequestSent bool `json:"follow_request_sent"` + Notifications bool `json:"notifications"` + Muting bool `json:"muting"` + Blocking bool `json:"blocking"` + BlockedBy bool `json:"blocked_by"` + AdvertiserAccountType string 
`json:"advertiser_account_type"` + AdvertiserAccountServiceLevels []string `json:"advertiser_account_service_levels"` + AnalyticsType string `json:"analytics_type"` + BusinessProfileState string `json:"business_profile_state"` + TranslatorType string `json:"translator_type"` + WithheldInCountries []string `json:"withheld_in_countries"` + RequireSomeConsent bool `json:"require_some_consent"` + Type string `json:"type"` + TargetUsername string `json:"target_username"` + Email *string `json:"email"` +} + +type ProfileEntities struct { + URL *URLEntities `json:"url,omitempty"` + Description *URLEntities `json:"description,omitempty"` +} + +type URLEntities struct { + URLs []URLEntity `json:"urls,omitempty"` +} + +type URLEntity struct { + URL string `json:"url"` + ExpandedURL string `json:"expanded_url"` + DisplayURL string `json:"display_url"` + Indices []int `json:"indices"` +} + +type ProfileResultScraper struct { + Avatar string `json:"avatar"` + Banner string `json:"banner"` + Biography string `json:"biography"` + Birthday string `json:"birthday"` + FollowersCount int `json:"followers_count"` + FollowingCount int `json:"following_count"` + FriendsCount int `json:"friends_count"` + IsPrivate bool `json:"is_private"` + IsVerified bool `json:"is_verified"` + IsBlueVerified bool `json:"is_blue_verified"` + Joined *time.Time `json:"joined"` + LikesCount int `json:"likes_count"` + ListedCount int `json:"listed_count"` + Location string `json:"location"` + Name string `json:"name"` + PinnedTweetIDs []string `json:"pinned_tweet_ids"` + TweetsCount int `json:"tweets_count"` + URL string `json:"url"` + UserID string `json:"user_id"` + Username string `json:"username"` + Website string `json:"website"` + Sensitive bool `json:"sensitive"` + Following bool `json:"following"` + FollowedBy bool `json:"followed_by"` + MediaCount int `json:"media_count"` + FastFollowersCount int `json:"fast_followers_count"` + NormalFollowersCount int `json:"normal_followers_count"` + ProfileImageShape string `json:"profile_image_shape"` + HasGraduatedAccess bool `json:"has_graduated_access"` + CanHighlightTweets bool `json:"can_highlight_tweets"` +} diff --git a/api/types/types.go b/api/types/types.go new file mode 100644 index 00000000..d6b47ab4 --- /dev/null +++ b/api/types/types.go @@ -0,0 +1,7 @@ +package types + +import ( + linkedin "github.com/masa-finance/tee-worker/api/types/linkedin" +) + +var LinkedIn = linkedin.LinkedIn diff --git a/api/types/types_suite_test.go b/api/types/types_suite_test.go new file mode 100644 index 00000000..3356638f --- /dev/null +++ b/api/types/types_suite_test.go @@ -0,0 +1,13 @@ +package types_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +func TestTypes(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Types Suite") +} diff --git a/api/types/web.go b/api/types/web.go new file mode 100644 index 00000000..dda1cea0 --- /dev/null +++ b/api/types/web.go @@ -0,0 +1,55 @@ +package types + +import ( + "time" +) + +// WebStartURL represents a single start URL configuration for web scraping +type WebStartURL struct { + URL string `json:"url"` + Method string `json:"method"` +} + +type WebQueryType string + +const ( + WebScraper WebQueryType = "scraper" +) + +// WebScraperRequest represents the customizable configuration for web scraping operations +type WebScraperRequest struct { + StartUrls []WebStartURL `json:"startUrls"` + MaxCrawlDepth int `json:"maxCrawlDepth"` + MaxCrawlPages int `json:"maxCrawlPages"` + RespectRobotsTxtFile bool `json:"respectRobotsTxtFile"` + SaveMarkdown bool `json:"saveMarkdown"` +} + +// WebCrawlInfo contains information about the crawling process +type WebCrawlInfo struct { + LoadedURL string `json:"loadedUrl"` + LoadedTime time.Time `json:"loadedTime"` + ReferrerURL string `json:"referrerUrl"` + Depth int `json:"depth"` + HTTPStatusCode int `json:"httpStatusCode"` +} + +// WebMetadata contains metadata extracted from the scraped page +type WebMetadata struct { + CanonicalURL string `json:"canonicalUrl"` + Title string `json:"title"` + Description *string `json:"description"` + Author *string `json:"author"` + Keywords *string `json:"keywords"` + LanguageCode *string `json:"languageCode"` +} + +// WebScraperResult represents the complete result from web scraping a single page +type WebScraperResult struct { + URL string `json:"url"` + Crawl WebCrawlInfo `json:"crawl"` + Metadata WebMetadata `json:"metadata"` + Text string `json:"text"` + Markdown string `json:"markdown"` + LLMResponse string `json:"llmresponse,omitempty"` // populated by LLM processor +} diff --git a/go.mod b/go.mod index ecef5a83..40bfcd2b 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,6 @@ require ( github.com/joho/godotenv v1.5.1 github.com/labstack/echo-contrib v0.17.4 github.com/labstack/echo/v4 v4.13.4 - github.com/masa-finance/tee-types v1.2.0 github.com/onsi/ginkgo/v2 v2.26.0 github.com/onsi/gomega v1.38.2 github.com/sirupsen/logrus v1.9.3 diff --git a/go.sum b/go.sum index 410d9d8b..573a0d6f 100644 --- a/go.sum +++ b/go.sum @@ -44,8 +44,6 @@ github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0 github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= github.com/maruel/natural v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg= -github.com/masa-finance/tee-types v1.2.0 h1:RqyDMlDY0XCXAw6XQWZ+4R4az4AC+C5r30cmhtHfhf0= -github.com/masa-finance/tee-types v1.2.0/go.mod h1:sB98t0axFlPi2d0zUPFZSQ84mPGwbr9eRY5yLLE3fSc= github.com/masa-finance/twitter-scraper v1.0.2 h1:him+wvYZHg/7EDdy73z1ceUywDJDRAhPLD2CSEa2Vfk= github.com/masa-finance/twitter-scraper v1.0.2/go.mod h1:38MY3g/h4V7Xl4HbW9lnkL8S3YiFZenBFv86hN57RG8= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 1bb8146b..a946535f 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -10,7 +10,7 @@ import ( . "github.com/onsi/gomega" "github.com/sirupsen/logrus" - teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/api/types" . 
"github.com/masa-finance/tee-worker/internal/api" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/pkg/client" @@ -44,8 +44,8 @@ var _ = Describe("API", func() { return err } - signature, err := c.CreateJobSignature(teetypes.Job{ - Type: teetypes.WebJob, + signature, err := c.CreateJobSignature(types.Job{ + Type: types.WebJob, Arguments: map[string]interface{}{}, }) if err != nil { @@ -72,8 +72,8 @@ var _ = Describe("API", func() { It("should submit a job and get the correct result", func() { // Step 1: Create the job request // we use TikTok transcription here as it's supported by all workers without any unique config - job := teetypes.Job{ - Type: teetypes.TiktokJob, + job := types.Job{ + Type: types.TiktokJob, Arguments: map[string]interface{}{ "type": "transcription", "video_url": "https://www.tiktok.com/@theblockrunner.com/video/7227579907361066282", @@ -107,7 +107,7 @@ var _ = Describe("API", func() { It("bubble up errors", func() { // Step 1: Create the job request - job := teetypes.Job{ + job := types.Job{ Type: "not-existing scraper", Arguments: map[string]interface{}{ "url": "google", diff --git a/internal/api/routes.go b/internal/api/routes.go index 7c631b52..7fe1b667 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -5,9 +5,11 @@ import ( "net/http" "github.com/labstack/echo/v4" + teejob "github.com/masa-finance/tee-worker/api/tee" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/jobserver" "github.com/masa-finance/tee-worker/pkg/tee" + "github.com/sirupsen/logrus" ) @@ -21,7 +23,7 @@ func generate(c echo.Context) error { job.WorkerID = tee.WorkerID // attach worker ID to job - encryptedSignature, err := job.GenerateJobSignature() + encryptedSignature, err := teejob.GenerateJobSignature(job) if err != nil { logrus.Errorf("Error while generating job signature: %s", err) return c.JSON(http.StatusInternalServerError, types.JobError{Error: err.Error()}) @@ -46,7 +48,7 @@ func add(jobServer *jobserver.JobServer) func(c echo.Context) error { return c.JSON(http.StatusBadRequest, types.JobError{Error: err.Error()}) } - job, err := jobRequest.DecryptJob() + job, err := teejob.DecryptJob(&jobRequest) if err != nil { logrus.Errorf("Error while decrypting job %s: %s", jobRequest, err) return c.JSON(http.StatusInternalServerError, types.JobError{Error: fmt.Sprintf("Error while decrypting job: %s", err.Error())}) @@ -84,7 +86,7 @@ func status(jobServer *jobserver.JobServer) func(c echo.Context) error { return c.JSON(http.StatusInternalServerError, types.JobError{Error: res.Error}) } - sealedData, err := res.Seal() + sealedData, err := teejob.SealJobResult(&res) if err != nil { logrus.Errorf("Error while sealing status response for job %s: %s", res.Job.UUID, err) return c.JSON(http.StatusInternalServerError, types.JobError{Error: err.Error()}) @@ -96,7 +98,7 @@ func status(jobServer *jobserver.JobServer) func(c echo.Context) error { } func result(c echo.Context) error { - payload := types.EncryptedRequest{ + payload := teejob.EncryptedRequest{ EncryptedResult: "", EncryptedRequest: "", } diff --git a/internal/apify/actors.go b/internal/apify/actors.go index 5a188f1c..a349e70b 100644 --- a/internal/apify/actors.go +++ b/internal/apify/actors.go @@ -1,6 +1,8 @@ package apify -import teetypes "github.com/masa-finance/tee-types/types" +import ( + "github.com/masa-finance/tee-worker/api/types" +) type ActorId string @@ -29,8 +31,8 @@ type defaultActorInput map[string]any type ActorConfig struct { 
ActorId ActorId DefaultInput defaultActorInput - Capabilities []teetypes.Capability - JobType teetypes.JobType + Capabilities []types.Capability + JobType types.JobType } // Actors is a list of actor configurations for Apify. Omitting LLM for now as it's not a standalone actor / has no dedicated capabilities @@ -38,37 +40,37 @@ var Actors = []ActorConfig{ { ActorId: ActorIds.RedditScraper, DefaultInput: defaultActorInput{}, - Capabilities: teetypes.RedditCaps, - JobType: teetypes.RedditJob, + Capabilities: types.RedditCaps, + JobType: types.RedditJob, }, { ActorId: ActorIds.TikTokSearchScraper, DefaultInput: defaultActorInput{"proxy": map[string]any{"useApifyProxy": true}}, - Capabilities: []teetypes.Capability{teetypes.CapSearchByQuery}, - JobType: teetypes.TiktokJob, + Capabilities: []types.Capability{types.CapSearchByQuery}, + JobType: types.TiktokJob, }, { ActorId: ActorIds.TikTokTrendingScraper, DefaultInput: defaultActorInput{}, - Capabilities: []teetypes.Capability{teetypes.CapSearchByTrending}, - JobType: teetypes.TiktokJob, + Capabilities: []types.Capability{types.CapSearchByTrending}, + JobType: types.TiktokJob, }, { ActorId: ActorIds.TwitterFollowers, DefaultInput: defaultActorInput{"maxFollowers": 200, "maxFollowings": 200}, - Capabilities: teetypes.TwitterApifyCaps, - JobType: teetypes.TwitterApifyJob, + Capabilities: []types.Capability{types.CapGetFollowing, types.CapGetFollowers}, + JobType: types.TwitterJob, }, { ActorId: ActorIds.WebScraper, DefaultInput: defaultActorInput{"startUrls": []map[string]any{{"url": "https://docs.learnbittensor.org"}}}, - Capabilities: teetypes.WebCaps, - JobType: teetypes.WebJob, + Capabilities: types.WebCaps, + JobType: types.WebJob, }, { ActorId: ActorIds.LinkedInSearchProfile, DefaultInput: defaultActorInput{}, - Capabilities: teetypes.LinkedInCaps, - JobType: teetypes.LinkedInJob, + Capabilities: types.LinkedInCaps, + JobType: types.LinkedInJob, }, } diff --git a/internal/capabilities/detector.go b/internal/capabilities/detector.go index 6ccfd4ac..9c7da40d 100644 --- a/internal/capabilities/detector.go +++ b/internal/capabilities/detector.go @@ -6,29 +6,29 @@ import ( "maps" - util "github.com/masa-finance/tee-types/pkg/util" - teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/apify" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs/twitter" "github.com/masa-finance/tee-worker/pkg/client" + util "github.com/masa-finance/tee-worker/pkg/util" "github.com/sirupsen/logrus" ) // JobServerInterface defines the methods we need from JobServer to avoid circular dependencies type JobServerInterface interface { - GetWorkerCapabilities() teetypes.WorkerCapabilities + GetWorkerCapabilities() types.WorkerCapabilities } // DetectCapabilities automatically detects available capabilities based on configuration // Always performs real capability detection by probing APIs and actors to ensure accurate reporting -func DetectCapabilities(jc config.JobConfiguration, jobServer JobServerInterface) teetypes.WorkerCapabilities { +func DetectCapabilities(jc config.JobConfiguration, jobServer JobServerInterface) types.WorkerCapabilities { // Always perform real capability detection to ensure accurate reporting // This guarantees miners report only capabilities they actually have access to - capabilities := make(teetypes.WorkerCapabilities) + capabilities := make(types.WorkerCapabilities) // Start with always available 
capabilities - maps.Copy(capabilities, teetypes.AlwaysAvailableCapabilities) + maps.Copy(capabilities, types.AlwaysAvailableCapabilities) // Check what Twitter authentication methods are available accounts := jc.GetStringSlice("twitter_accounts", nil) @@ -42,22 +42,52 @@ func DetectCapabilities(jc config.JobConfiguration, jobServer JobServerInterface hasApifyKey := hasValidApifyKey(apifyApiKey) hasLLMKey := geminiApiKey.IsValid() || claudeApiKey.IsValid() - // Add Twitter-specific capabilities based on available authentication + // Add Twitter capabilities based on available authentication + var twitterCaps []types.Capability + + // Add credential-based capabilities if we have accounts if hasAccounts { - capabilities[teetypes.TwitterCredentialJob] = teetypes.TwitterCredentialCaps + twitterCaps = append(twitterCaps, + types.CapSearchByQuery, + types.CapSearchByProfile, + types.CapGetById, + types.CapGetReplies, + types.CapGetRetweeters, + types.CapGetMedia, + types.CapGetProfileById, + types.CapGetTrends, + types.CapGetSpace, + types.CapGetProfile, + types.CapGetTweets, + ) } + // Add API-based capabilities if we have API keys if hasApiKeys { - // Start with basic API capabilities - apiCaps := make([]teetypes.Capability, len(teetypes.TwitterAPICaps)) - copy(apiCaps, teetypes.TwitterAPICaps) - - // Check for elevated API keys and add searchbyfullarchive capability + // Add basic API capabilities for any valid API key + twitterCaps = append(twitterCaps, + types.CapSearchByQuery, + types.CapSearchByProfile, + types.CapGetById, + types.CapGetReplies, + types.CapGetRetweeters, + types.CapGetMedia, + types.CapGetProfileById, + types.CapGetTrends, + types.CapGetSpace, + types.CapGetProfile, + types.CapGetTweets, + ) + + // Check for elevated API capabilities if hasElevatedApiKey(apiKeys) { - apiCaps = append(apiCaps, teetypes.CapSearchByFullArchive) + twitterCaps = append(twitterCaps, types.CapSearchByFullArchive) } + } - capabilities[teetypes.TwitterApiJob] = apiCaps + // Only add capabilities if we have any supported capabilities + if len(twitterCaps) > 0 { + capabilities[types.TwitterJob] = twitterCaps } if hasApifyKey { @@ -67,18 +97,18 @@ func DetectCapabilities(jc config.JobConfiguration, jobServer JobServerInterface logrus.Errorf("Failed to create Apify client for access probes: %v", err) } else { // Aggregate capabilities per job from accessible actors - jobToSet := map[teetypes.JobType]*util.Set[teetypes.Capability]{} + jobToSet := map[types.JobType]*util.Set[types.Capability]{} for _, actor := range apify.Actors { // Web requires a valid Gemini API key - if actor.JobType == teetypes.WebJob && !hasLLMKey { + if actor.JobType == types.WebJob && !hasLLMKey { logrus.Debug("Skipping Web actor due to missing Gemini key") continue } if ok, _ := c.ProbeActorAccess(actor.ActorId, actor.DefaultInput); ok { if _, exists := jobToSet[actor.JobType]; !exists { - jobToSet[actor.JobType] = util.NewSet[teetypes.Capability]() + jobToSet[actor.JobType] = util.NewSet[types.Capability]() } jobToSet[actor.JobType].Add(actor.Capabilities...) 
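The probe loop above aggregates capabilities per job type through util.Set so that several accessible actors mapping to the same job do not yield duplicate entries. A self-contained sketch of that aggregation pattern, using a plain map-based set as a stand-in rather than assuming util.Set's full API:

package main

import "fmt"

// set is a minimal stand-in for the repository's util.Set, just enough
// to demonstrate the aggregation pattern.
type set[T comparable] map[T]struct{}

func (s set[T]) add(items ...T) {
	for _, it := range items {
		s[it] = struct{}{}
	}
}

type jobType string
type capability string

type actorConfig struct {
	job  jobType
	caps []capability
}

func main() {
	// Two accessible actors map to the same job type with overlapping capabilities.
	actors := []actorConfig{
		{job: "tiktok", caps: []capability{"searchbyquery"}},
		{job: "tiktok", caps: []capability{"searchbyquery", "searchbytrending"}},
	}

	jobToSet := map[jobType]set[capability]{}
	for _, a := range actors {
		if _, ok := jobToSet[a.job]; !ok {
			jobToSet[a.job] = set[capability]{}
		}
		jobToSet[a.job].add(a.caps...)
	}

	for job, caps := range jobToSet {
		fmt.Println(job, len(caps)) // tiktok 2 (duplicates collapsed)
	}
}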
} else { @@ -94,32 +124,6 @@ func DetectCapabilities(jc config.JobConfiguration, jobServer JobServerInterface } } - // Add general TwitterJob capability if any Twitter auth is available - // TODO: this will get cleaned up with unique twitter capabilities - if hasAccounts || hasApiKeys || hasApifyKey { - var twitterJobCaps []teetypes.Capability - // Use the most comprehensive capabilities available - if hasAccounts { - twitterJobCaps = teetypes.TwitterCredentialCaps - } else { - // Use API capabilities if we only have keys - twitterJobCaps = make([]teetypes.Capability, len(teetypes.TwitterAPICaps)) - copy(twitterJobCaps, teetypes.TwitterAPICaps) - - // Check for elevated API keys and add searchbyfullarchive capability - if hasElevatedApiKey(apiKeys) { - twitterJobCaps = append(twitterJobCaps, teetypes.CapSearchByFullArchive) - } - } - - // Add Apify capabilities if available - if hasApifyKey { - twitterJobCaps = append(twitterJobCaps, teetypes.TwitterApifyCaps...) - } - - capabilities[teetypes.TwitterJob] = twitterJobCaps - } - return capabilities } diff --git a/internal/capabilities/detector_test.go b/internal/capabilities/detector_test.go index 7c263adc..c1d87858 100644 --- a/internal/capabilities/detector_test.go +++ b/internal/capabilities/detector_test.go @@ -7,23 +7,23 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/api/types" . "github.com/masa-finance/tee-worker/internal/capabilities" "github.com/masa-finance/tee-worker/internal/config" ) // MockJobServer implements JobServerInterface for testing type MockJobServer struct { - capabilities teetypes.WorkerCapabilities + capabilities types.WorkerCapabilities } -func (m *MockJobServer) GetWorkerCapabilities() teetypes.WorkerCapabilities { +func (m *MockJobServer) GetWorkerCapabilities() types.WorkerCapabilities { return m.capabilities } var _ = Describe("DetectCapabilities", func() { DescribeTable("capability detection scenarios", - func(jc config.JobConfiguration, jobServer JobServerInterface, expected teetypes.WorkerCapabilities) { + func(jc config.JobConfiguration, jobServer JobServerInterface, expected types.WorkerCapabilities) { got := DetectCapabilities(jc, jobServer) // Extract job type keys and sort for consistent comparison @@ -47,24 +47,24 @@ var _ = Describe("DetectCapabilities", func() { Entry("With JobServer - performs real detection (JobServer ignored)", config.JobConfiguration{}, &MockJobServer{ - capabilities: teetypes.WorkerCapabilities{ - teetypes.WebJob: {teetypes.CapScraper}, - teetypes.TelemetryJob: {teetypes.CapTelemetry}, - teetypes.TiktokJob: {teetypes.CapTranscription}, - teetypes.TwitterJob: {teetypes.CapSearchByQuery, teetypes.CapGetById, teetypes.CapGetProfileById}, + capabilities: types.WorkerCapabilities{ + types.WebJob: {types.CapScraper}, + types.TelemetryJob: {types.CapTelemetry}, + types.TiktokJob: {types.CapTranscription}, + types.TwitterJob: {types.CapSearchByQuery, types.CapGetById, types.CapGetProfileById}, }, }, - teetypes.WorkerCapabilities{ - teetypes.TelemetryJob: {teetypes.CapTelemetry}, - teetypes.TiktokJob: {teetypes.CapTranscription}, + types.WorkerCapabilities{ + types.TelemetryJob: {types.CapTelemetry}, + types.TiktokJob: {types.CapTranscription}, }, ), Entry("Without JobServer - basic capabilities only", config.JobConfiguration{}, nil, - teetypes.WorkerCapabilities{ - teetypes.TelemetryJob: {teetypes.CapTelemetry}, - teetypes.TiktokJob: {teetypes.CapTranscription}, + 
types.WorkerCapabilities{ + types.TelemetryJob: {types.CapTelemetry}, + types.TiktokJob: {types.CapTranscription}, }, ), Entry("With Twitter accounts - adds credential capabilities", @@ -72,36 +72,10 @@ var _ = Describe("DetectCapabilities", func() { "twitter_accounts": []string{"account1", "account2"}, }, nil, - teetypes.WorkerCapabilities{ - teetypes.TelemetryJob: {teetypes.CapTelemetry}, - teetypes.TiktokJob: {teetypes.CapTranscription}, - teetypes.TwitterCredentialJob: teetypes.TwitterCredentialCaps, - teetypes.TwitterJob: teetypes.TwitterCredentialCaps, - }, - ), - Entry("With Twitter API keys - adds API capabilities", - config.JobConfiguration{ - "twitter_api_keys": []string{"key1", "key2"}, - }, - nil, - teetypes.WorkerCapabilities{ - teetypes.TelemetryJob: {teetypes.CapTelemetry}, - teetypes.TiktokJob: {teetypes.CapTranscription}, - teetypes.TwitterApiJob: teetypes.TwitterAPICaps, - teetypes.TwitterJob: teetypes.TwitterAPICaps, - }, - ), - Entry("With mock elevated Twitter API keys - only basic capabilities detected", - config.JobConfiguration{ - "twitter_api_keys": []string{"Bearer abcd1234-ELEVATED"}, - }, - nil, - teetypes.WorkerCapabilities{ - teetypes.TelemetryJob: {teetypes.CapTelemetry}, - teetypes.TiktokJob: {teetypes.CapTranscription}, - // Note: Mock elevated keys will be detected as basic since we can't make real API calls in tests - teetypes.TwitterApiJob: teetypes.TwitterAPICaps, - teetypes.TwitterJob: teetypes.TwitterAPICaps, + types.WorkerCapabilities{ + types.TelemetryJob: {types.CapTelemetry}, + types.TiktokJob: {types.CapTranscription}, + types.TwitterJob: types.TwitterCaps, }, ), ) @@ -133,13 +107,13 @@ var _ = Describe("DetectCapabilities", func() { config.JobConfiguration{ "twitter_accounts": []string{"user1:pass1"}, }, - []string{"telemetry", "tiktok", "twitter", "twitter-credential"}, + []string{"telemetry", "tiktok", "twitter"}, ), Entry("With Twitter API keys", config.JobConfiguration{ "twitter_api_keys": []string{"key1"}, }, - []string{"telemetry", "tiktok", "twitter", "twitter-api"}, + []string{"telemetry", "tiktok", "twitter"}, ), ) }) @@ -158,24 +132,24 @@ var _ = Describe("DetectCapabilities", func() { caps := DetectCapabilities(jc, nil) // TikTok should gain search capabilities with valid key - tiktokCaps, ok := caps[teetypes.TiktokJob] + tiktokCaps, ok := caps[types.TiktokJob] Expect(ok).To(BeTrue(), "expected tiktok capabilities to be present") - Expect(tiktokCaps).To(ContainElement(teetypes.CapSearchByQuery), "expected tiktok to include CapSearchByQuery capability") - Expect(tiktokCaps).To(ContainElement(teetypes.CapSearchByTrending), "expected tiktok to include CapSearchByTrending capability") + Expect(tiktokCaps).To(ContainElement(types.CapSearchByQuery), "expected tiktok to include CapSearchByQuery capability") + Expect(tiktokCaps).To(ContainElement(types.CapSearchByTrending), "expected tiktok to include CapSearchByTrending capability") // Twitter-Apify job should be present with follower/following capabilities - twitterApifyCaps, ok := caps[teetypes.TwitterApifyJob] + twitterApifyCaps, ok := caps[types.TwitterJob] Expect(ok).To(BeTrue(), "expected twitter-apify capabilities to be present") - Expect(twitterApifyCaps).To(ContainElement(teetypes.CapGetFollowers), "expected twitter-apify to include CapGetFollowers capability") - Expect(twitterApifyCaps).To(ContainElement(teetypes.CapGetFollowing), "expected twitter-apify to include CapGetFollowing capability") + Expect(twitterApifyCaps).To(ContainElement(types.CapGetFollowers), "expected 
twitter-apify to include CapGetFollowers capability") + Expect(twitterApifyCaps).To(ContainElement(types.CapGetFollowing), "expected twitter-apify to include CapGetFollowing capability") // Reddit should be present (only if rented!) - redditCaps, hasReddit := caps[teetypes.RedditJob] + redditCaps, hasReddit := caps[types.RedditJob] Expect(hasReddit).To(BeTrue(), "expected reddit capabilities to be present") - Expect(redditCaps).To(ContainElement(teetypes.CapScrapeUrls), "expected reddit to include CapScrapeUrls capability") - Expect(redditCaps).To(ContainElement(teetypes.CapSearchPosts), "expected reddit to include CapSearchPosts capability") - Expect(redditCaps).To(ContainElement(teetypes.CapSearchUsers), "expected reddit to include CapSearchUsers capability") - Expect(redditCaps).To(ContainElement(teetypes.CapSearchCommunities), "expected reddit to include CapSearchCommunities capability") + Expect(redditCaps).To(ContainElement(types.CapScrapeUrls), "expected reddit to include CapScrapeUrls capability") + Expect(redditCaps).To(ContainElement(types.CapSearchPosts), "expected reddit to include CapSearchPosts capability") + Expect(redditCaps).To(ContainElement(types.CapSearchUsers), "expected reddit to include CapSearchUsers capability") + Expect(redditCaps).To(ContainElement(types.CapSearchCommunities), "expected reddit to include CapSearchCommunities capability") }) It("should add enhanced capabilities when valid Apify API key is provided alongside a Gemini API key", func() { apifyKey := os.Getenv("APIFY_API_KEY") @@ -195,15 +169,15 @@ var _ = Describe("DetectCapabilities", func() { caps := DetectCapabilities(jc, nil) // Web should be present - webCaps, hasWeb := caps[teetypes.WebJob] + webCaps, hasWeb := caps[types.WebJob] Expect(hasWeb).To(BeTrue(), "expected web capabilities to be present") - Expect(webCaps).To(ContainElement(teetypes.CapScraper), "expected web to include CapScraper capability") + Expect(webCaps).To(ContainElement(types.CapScraper), "expected web to include CapScraper capability") }) }) }) // Helper function to check if a job type exists in capabilities -func hasJobType(capabilities teetypes.WorkerCapabilities, jobName string) bool { - _, exists := capabilities[teetypes.JobType(jobName)] +func hasJobType(capabilities types.WorkerCapabilities, jobName string) bool { + _, exists := capabilities[types.JobType(jobName)] return exists } diff --git a/internal/config/config.go b/internal/config/config.go index 9bd4d233..c5e8938a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,8 +11,9 @@ import ( "time" "github.com/joho/godotenv" - teeargs "github.com/masa-finance/tee-types/args" "github.com/sirupsen/logrus" + + "github.com/masa-finance/tee-worker/api/args/llm/process" ) var ( @@ -346,9 +347,9 @@ type LlmConfig struct { // GetModelAndKey returns the first available model and API key based on which keys are valid func (lc LlmConfig) GetModelAndKey() (model string, key string, err error) { if lc.ClaudeApiKey.IsValid() { - return teeargs.LLMDefaultClaudeModel, string(lc.ClaudeApiKey), nil + return process.DefaultClaudeModel, string(lc.ClaudeApiKey), nil } else if lc.GeminiApiKey.IsValid() { - return teeargs.LLMDefaultGeminiModel, string(lc.GeminiApiKey), nil + return process.DefaultGeminiModel, string(lc.GeminiApiKey), nil } return "", "", errors.New("no valid llm api key found") } diff --git a/internal/jobs/linkedin.go b/internal/jobs/linkedin.go index 933b1d2a..8a111eb5 100644 --- a/internal/jobs/linkedin.go +++ b/internal/jobs/linkedin.go @@ -7,21 
+7,20 @@ import ( "github.com/sirupsen/logrus" + "github.com/masa-finance/tee-worker/api/args" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs/linkedinapify" "github.com/masa-finance/tee-worker/internal/jobs/stats" "github.com/masa-finance/tee-worker/pkg/client" - teeargs "github.com/masa-finance/tee-types/args" - profileArgs "github.com/masa-finance/tee-types/args/linkedin/profile" - teetypes "github.com/masa-finance/tee-types/types" - profileTypes "github.com/masa-finance/tee-types/types/linkedin/profile" + pArgs "github.com/masa-finance/tee-worker/api/args/linkedin/profile" + pTypes "github.com/masa-finance/tee-worker/api/types/linkedin/profile" ) // LinkedInApifyClient defines the interface for the LinkedIn Apify client to allow mocking in tests type LinkedInApifyClient interface { - SearchProfiles(workerID string, args *profileArgs.Arguments, cursor client.Cursor) ([]*profileTypes.Profile, string, client.Cursor, error) + SearchProfiles(workerID string, args *pArgs.Arguments, cursor client.Cursor) ([]*pTypes.Profile, string, client.Cursor, error) ValidateApiKey() error } @@ -34,7 +33,7 @@ var NewLinkedInApifyClient = func(apiKey string, statsCollector *stats.StatsColl type LinkedInScraper struct { configuration config.JobConfiguration statsCollector *stats.StatsCollector - capabilities []teetypes.Capability + capabilities []types.Capability } func NewLinkedInScraper(jc config.JobConfiguration, statsCollector *stats.StatsCollector) *LinkedInScraper { @@ -42,7 +41,7 @@ func NewLinkedInScraper(jc config.JobConfiguration, statsCollector *stats.StatsC return &LinkedInScraper{ configuration: jc, statsCollector: statsCollector, - capabilities: teetypes.LinkedInCaps, + capabilities: types.LinkedInCaps, } } @@ -52,17 +51,17 @@ func (ls *LinkedInScraper) ExecuteJob(j types.Job) (types.JobResult, error) { // Require Apify key for LinkedIn scraping apifyApiKey := ls.configuration.GetString("apify_api_key", "") if apifyApiKey == "" { - msg := errors.New("Apify API key is required for LinkedIn job") + msg := errors.New("apify API key is required for LinkedIn job") return types.JobResult{Error: msg.Error()}, msg } - jobArgs, err := teeargs.UnmarshalJobArguments(teetypes.JobType(j.Type), map[string]any(j.Arguments)) + jobArgs, err := args.UnmarshalJobArguments(types.JobType(j.Type), map[string]any(j.Arguments)) if err != nil { msg := fmt.Errorf("failed to unmarshal job arguments: %w", err) return types.JobResult{Error: msg.Error()}, msg } - linkedinArgs, ok := jobArgs.(*profileArgs.Arguments) + linkedinArgs, ok := jobArgs.(*pArgs.Arguments) if !ok { return types.JobResult{Error: "invalid argument type for LinkedIn job"}, errors.New("invalid argument type") } @@ -84,7 +83,7 @@ func (ls *LinkedInScraper) ExecuteJob(j types.Job) (types.JobResult, error) { data, err := json.Marshal(profiles) if err != nil { - return types.JobResult{Error: fmt.Sprintf("error marshalling LinkedIn response")}, fmt.Errorf("error marshalling LinkedIn response: %w", err) + return types.JobResult{Error: "error marshalling LinkedIn response"}, fmt.Errorf("error marshalling LinkedIn response: %w", err) } return types.JobResult{ @@ -93,16 +92,3 @@ func (ls *LinkedInScraper) ExecuteJob(j types.Job) (types.JobResult, error) { NextCursor: cursor.String(), }, nil } - -// GetStructuredCapabilities returns the structured capabilities supported by the LinkedIn scraper -// based on the available credentials and API keys -func (ls 
*LinkedInScraper) GetStructuredCapabilities() teetypes.WorkerCapabilities { - capabilities := make(teetypes.WorkerCapabilities) - - apifyApiKey := ls.configuration.GetString("apify_api_key", "") - if apifyApiKey != "" { - capabilities[teetypes.LinkedInJob] = teetypes.LinkedInCaps - } - - return capabilities -} diff --git a/internal/jobs/linkedin_test.go b/internal/jobs/linkedin_test.go index dc5763e0..f7136b79 100644 --- a/internal/jobs/linkedin_test.go +++ b/internal/jobs/linkedin_test.go @@ -17,9 +17,8 @@ import ( "github.com/masa-finance/tee-worker/internal/jobs/stats" "github.com/masa-finance/tee-worker/pkg/client" - profileArgs "github.com/masa-finance/tee-types/args/linkedin/profile" - teetypes "github.com/masa-finance/tee-types/types" - profileTypes "github.com/masa-finance/tee-types/types/linkedin/profile" + profileArgs "github.com/masa-finance/tee-worker/api/args/linkedin/profile" + profileTypes "github.com/masa-finance/tee-worker/api/types/linkedin/profile" ) // MockLinkedInApifyClient is a mock implementation of the LinkedInApifyClient. @@ -68,7 +67,7 @@ var _ = Describe("LinkedInScraper", func() { job = types.Job{ UUID: "test-uuid", - Type: teetypes.LinkedInJob, + Type: types.LinkedInJob, } }) @@ -77,31 +76,24 @@ var _ = Describe("LinkedInScraper", func() { }) Context("ExecuteJob", func() { - It("should return an error for invalid arguments", func() { - job.Arguments = map[string]any{"invalid": "args"} - result, err := scraper.ExecuteJob(job) - Expect(err).To(HaveOccurred()) - Expect(result.Error).To(ContainSubstring("failed to unmarshal job arguments")) - }) - It("should return an error when Apify API key is missing", func() { cfg := config.JobConfiguration{} scraper = jobs.NewLinkedInScraper(cfg, statsCollector) job.Arguments = map[string]any{ - "type": teetypes.CapSearchByProfile, + "type": types.CapSearchByProfile, "searchQuery": "software engineer", "maxItems": 10, } result, err := scraper.ExecuteJob(job) Expect(err).To(HaveOccurred()) - Expect(result.Error).To(ContainSubstring("Apify API key is required for LinkedIn job")) + Expect(result.Error).To(ContainSubstring("apify API key is required for LinkedIn job")) }) It("should call SearchProfiles and return data and next cursor", func() { job.Arguments = map[string]any{ - "type": teetypes.CapSearchByProfile, + "type": types.CapSearchByProfile, "searchQuery": "software engineer", "maxItems": 10, } @@ -131,7 +123,7 @@ var _ = Describe("LinkedInScraper", func() { Expect(workerID).To(Equal("test-worker")) Expect(args.Query).To(Equal("software engineer")) Expect(args.MaxItems).To(Equal(uint(10))) - Expect(args.QueryType).To(Equal(teetypes.CapSearchByProfile)) + Expect(args.Type).To(Equal(types.CapSearchByProfile)) return expectedProfiles, "dataset-123", client.Cursor("next-cursor"), nil } @@ -152,7 +144,7 @@ var _ = Describe("LinkedInScraper", func() { It("should handle errors from the LinkedIn client", func() { job.Arguments = map[string]any{ - "type": teetypes.CapSearchByProfile, + "type": types.CapSearchByProfile, "searchQuery": "software engineer", "maxItems": 10, } @@ -173,7 +165,7 @@ var _ = Describe("LinkedInScraper", func() { return nil, errors.New("client creation failed") } job.Arguments = map[string]any{ - "type": teetypes.CapSearchByProfile, + "type": types.CapSearchByProfile, "searchQuery": "software engineer", "maxItems": 10, } @@ -185,7 +177,7 @@ var _ = Describe("LinkedInScraper", func() { It("should return an error when dataset ID is missing", func() { job.Arguments = map[string]any{ - "type": 
teetypes.CapSearchByProfile, + "type": types.CapSearchByProfile, "searchQuery": "software engineer", "maxItems": 10, } @@ -201,7 +193,7 @@ var _ = Describe("LinkedInScraper", func() { It("should handle JSON marshalling errors", func() { job.Arguments = map[string]any{ - "type": teetypes.CapSearchByProfile, + "type": types.CapSearchByProfile, "searchQuery": "software engineer", "maxItems": 10, } @@ -225,7 +217,7 @@ var _ = Describe("LinkedInScraper", func() { It("should handle empty profile results", func() { job.Arguments = map[string]any{ - "type": teetypes.CapSearchByProfile, + "type": types.CapSearchByProfile, "searchQuery": "nonexistent", "maxItems": 10, } @@ -245,37 +237,6 @@ var _ = Describe("LinkedInScraper", func() { }) }) - Context("GetStructuredCapabilities", func() { - It("should return LinkedIn capabilities when Apify API key is present", func() { - cfg := config.JobConfiguration{ - "apify_api_key": "test-key", - } - scraper = jobs.NewLinkedInScraper(cfg, statsCollector) - - capabilities := scraper.GetStructuredCapabilities() - Expect(capabilities).To(HaveKey(teetypes.LinkedInJob)) - Expect(capabilities[teetypes.LinkedInJob]).To(ContainElement(teetypes.CapSearchByProfile)) - }) - - It("should return empty capabilities when Apify API key is missing", func() { - cfg := config.JobConfiguration{} - scraper = jobs.NewLinkedInScraper(cfg, statsCollector) - - capabilities := scraper.GetStructuredCapabilities() - Expect(capabilities).NotTo(HaveKey(teetypes.LinkedInJob)) - }) - - It("should return empty capabilities when Apify API key is empty", func() { - cfg := config.JobConfiguration{ - "apify_api_key": "", - } - scraper = jobs.NewLinkedInScraper(cfg, statsCollector) - - capabilities := scraper.GetStructuredCapabilities() - Expect(capabilities).NotTo(HaveKey(teetypes.LinkedInJob)) - }) - }) - // Integration tests that use the real client Context("Integration tests", func() { var apifyKey string @@ -301,12 +262,12 @@ var _ = Describe("LinkedInScraper", func() { integrationScraper := jobs.NewLinkedInScraper(cfg, integrationStatsCollector) jobArgs := profileArgs.Arguments{ - QueryType: teetypes.CapSearchByProfile, - Query: "software engineer", - MaxItems: 10, + Type: types.CapSearchByProfile, + Query: "software engineer", + MaxItems: 10, } - // Marshal jobArgs to map[string]any so it can be used as JobArguments + // Marshal jobArgs to map[string]any so it can be used as JobArgument var jobArgsMap map[string]any jobArgsBytes, err := json.Marshal(jobArgs) Expect(err).NotTo(HaveOccurred()) @@ -315,7 +276,7 @@ var _ = Describe("LinkedInScraper", func() { job := types.Job{ UUID: "integration-test-uuid", - Type: teetypes.LinkedInJob, + Type: types.LinkedInJob, WorkerID: "test-worker", Arguments: jobArgsMap, Timeout: 60 * time.Second, @@ -338,22 +299,7 @@ var _ = Describe("LinkedInScraper", func() { fmt.Println(string(prettyJSON)) }) - It("should expose capabilities only when APIFY_API_KEY is present", func() { - cfg := config.JobConfiguration{ - "apify_api_key": apifyKey, - } - integrationStatsCollector := stats.StartCollector(128, cfg) - integrationScraper := jobs.NewLinkedInScraper(cfg, integrationStatsCollector) - - caps := integrationScraper.GetStructuredCapabilities() - if apifyKey != "" { - Expect(caps[teetypes.LinkedInJob]).NotTo(BeEmpty()) - Expect(caps[teetypes.LinkedInJob]).To(ContainElement(teetypes.CapSearchByProfile)) - } else { - // Expect no capabilities when key is missing - _, ok := caps[teetypes.LinkedInJob] - Expect(ok).To(BeFalse()) - } - }) + // Note: Capability detection 
is now centralized in capabilities/detector.go + // Individual scraper capability tests have been removed }) }) diff --git a/internal/jobs/linkedinapify/client.go b/internal/jobs/linkedinapify/client.go index 0124a374..adb55d3f 100644 --- a/internal/jobs/linkedinapify/client.go +++ b/internal/jobs/linkedinapify/client.go @@ -4,8 +4,8 @@ import ( "encoding/json" "fmt" - profileArgs "github.com/masa-finance/tee-types/args/linkedin/profile" - profileTypes "github.com/masa-finance/tee-types/types/linkedin/profile" + profileArgs "github.com/masa-finance/tee-worker/api/args/linkedin/profile" + profileTypes "github.com/masa-finance/tee-worker/api/types/linkedin/profile" "github.com/masa-finance/tee-worker/internal/apify" "github.com/masa-finance/tee-worker/internal/jobs/stats" "github.com/masa-finance/tee-worker/pkg/client" diff --git a/internal/jobs/linkedinapify/client_test.go b/internal/jobs/linkedinapify/client_test.go index b17b73fd..9bf2681d 100644 --- a/internal/jobs/linkedinapify/client_test.go +++ b/internal/jobs/linkedinapify/client_test.go @@ -13,9 +13,9 @@ import ( "github.com/masa-finance/tee-worker/internal/jobs/stats" "github.com/masa-finance/tee-worker/pkg/client" - profileArgs "github.com/masa-finance/tee-types/args/linkedin/profile" - "github.com/masa-finance/tee-types/types" - "github.com/masa-finance/tee-types/types/linkedin/profile" + profileArgs "github.com/masa-finance/tee-worker/api/args/linkedin/profile" + "github.com/masa-finance/tee-worker/api/types" + "github.com/masa-finance/tee-worker/api/types/linkedin/profile" ) // MockApifyClient is a mock implementation of the ApifyClient. @@ -70,10 +70,9 @@ var _ = Describe("LinkedInApifyClient", func() { Describe("SearchProfiles", func() { It("should construct the correct actor input", func() { - args := profileArgs.Arguments{ - Query: "software engineer", - MaxItems: 10, - } + args := profileArgs.NewArguments() + args.Query = "software engineer" + args.MaxItems = 10 mockClient.RunActorAndGetResponseFunc = func(actorID apify.ActorId, input any, cursor client.Cursor, limit uint) (*client.DatasetResponse, client.Cursor, error) { Expect(actorID).To(Equal(apify.ActorIds.LinkedInSearchProfile)) @@ -97,10 +96,9 @@ var _ = Describe("LinkedInApifyClient", func() { return nil, "", expectedErr } - args := profileArgs.Arguments{ - Query: "test query", - MaxItems: 5, - } + args := profileArgs.NewArguments() + args.Query = "test query" + args.MaxItems = 5 _, _, _, err := linkedinClient.SearchProfiles("test-worker", &args, client.EmptyCursor) Expect(err).To(MatchError(expectedErr)) }) @@ -116,10 +114,9 @@ var _ = Describe("LinkedInApifyClient", func() { return dataset, "next", nil } - args := profileArgs.Arguments{ - Query: "test query", - MaxItems: 1, - } + args := profileArgs.NewArguments() + args.Query = "test query" + args.MaxItems = 1 results, _, _, err := linkedinClient.SearchProfiles("test-worker", &args, client.EmptyCursor) Expect(err).NotTo(HaveOccurred()) Expect(results).To(BeEmpty()) // The invalid item should be skipped @@ -147,10 +144,9 @@ var _ = Describe("LinkedInApifyClient", func() { return dataset, "next", nil } - args := profileArgs.Arguments{ - Query: "test query", - MaxItems: 2, - } + args := profileArgs.NewArguments() + args.Query = "test query" + args.MaxItems = 2 results, _, _, err := linkedinClient.SearchProfiles("test-worker", &args, client.EmptyCursor) Expect(err).NotTo(HaveOccurred()) Expect(results).To(HaveLen(2)) @@ -206,12 +202,11 @@ var _ = Describe("LinkedInApifyClient", func() { realClient, err := 
linkedinapify.NewClient(apifyKey, statsCollector) Expect(err).NotTo(HaveOccurred()) - args := profileArgs.Arguments{ - QueryType: types.CapSearchByProfile, - Query: "software engineer", - MaxItems: 1, - ScraperMode: profile.ScraperModeShort, - } + args := profileArgs.NewArguments() + args.Type = types.CapSearchByProfile + args.Query = "software engineer" + args.MaxItems = 1 + args.ScraperMode = profile.ScraperModeShort results, datasetId, cursor, err := realClient.SearchProfiles("test-worker", &args, client.EmptyCursor) Expect(err).NotTo(HaveOccurred()) diff --git a/internal/jobs/llmapify/client.go b/internal/jobs/llmapify/client.go index 48cb9f9a..d7d7309c 100644 --- a/internal/jobs/llmapify/client.go +++ b/internal/jobs/llmapify/client.go @@ -5,8 +5,8 @@ import ( "errors" "fmt" - teeargs "github.com/masa-finance/tee-types/args" - teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/api/args/llm" + "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/apify" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs/stats" @@ -54,7 +54,7 @@ func (c *ApifyClient) ValidateApiKey() error { return c.client.ValidateApiKey() } -func (c *ApifyClient) Process(workerID string, args teeargs.LLMProcessorArguments, cursor client.Cursor) ([]*teetypes.LLMProcessorResult, client.Cursor, error) { +func (c *ApifyClient) Process(workerID string, args llm.Process, cursor client.Cursor) ([]*types.LLMProcessorResult, client.Cursor, error) { if c.statsCollector != nil { c.statsCollector.Add(workerID, stats.LLMQueries, 1) } @@ -64,7 +64,7 @@ func (c *ApifyClient) Process(workerID string, args teeargs.LLMProcessorArgument return nil, client.EmptyCursor, err } - input, err := args.ToLLMProcessorRequest(model, key) + input, err := args.ToProcessorRequest(model, key) if err != nil { return nil, client.EmptyCursor, err } @@ -78,10 +78,10 @@ func (c *ApifyClient) Process(workerID string, args teeargs.LLMProcessorArgument return nil, client.EmptyCursor, err } - response := make([]*teetypes.LLMProcessorResult, 0, len(dataset.Data.Items)) + response := make([]*types.LLMProcessorResult, 0, len(dataset.Data.Items)) for i, item := range dataset.Data.Items { - var resp teetypes.LLMProcessorResult + var resp types.LLMProcessorResult if err := json.Unmarshal(item, &resp); err != nil { logrus.Warnf("Failed to unmarshal llm result at index %d: %v", i, err) continue diff --git a/internal/jobs/llmapify/client_test.go b/internal/jobs/llmapify/client_test.go index 417e88c4..69497c79 100644 --- a/internal/jobs/llmapify/client_test.go +++ b/internal/jobs/llmapify/client_test.go @@ -10,13 +10,12 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/masa-finance/tee-worker/api/args/llm/process" + "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/apify" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs/llmapify" "github.com/masa-finance/tee-worker/pkg/client" - - teeargs "github.com/masa-finance/tee-types/args" - teetypes "github.com/masa-finance/tee-types/types" ) // MockApifyClient is a mock implementation of the ApifyClient. 
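Every hunk below drives the renamed llmapify surface the same way; as a minimal sketch of the pattern (assuming an llmClient built with llmapify.NewClient as in the tests, with placeholder dataset and worker IDs):

	llmArgs := process.NewArguments() // api/args/llm/process
	llmArgs.DatasetId = "example-dataset-id" // placeholder
	llmArgs.Prompt = "summarize the content of this webpage ${markdown}"
	// The tests marshal and unmarshal llmArgs once so the package defaults apply.
	results, cursor, err := llmClient.Process("test-worker", llmArgs, client.EmptyCursor)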
@@ -68,15 +67,14 @@ var _ = Describe("LLMApifyClient", func() { Describe("Process", func() { It("should construct the correct actor input", func() { - args := teeargs.LLMProcessorArguments{ - DatasetId: "test-dataset-id", - Prompt: "test-prompt", - } + llmArgs := process.NewArguments() + llmArgs.DatasetId = "test-dataset-id" + llmArgs.Prompt = "test-prompt" // Marshal and unmarshal to apply defaults - jsonData, err := json.Marshal(args) + jsonData, err := json.Marshal(llmArgs) Expect(err).ToNot(HaveOccurred()) - err = json.Unmarshal(jsonData, &args) + err = json.Unmarshal(jsonData, &llmArgs) Expect(err).ToNot(HaveOccurred()) mockClient.RunActorAndGetResponseFunc = func(actorID apify.ActorId, input any, cursor client.Cursor, limit uint) (*client.DatasetResponse, client.Cursor, error) { @@ -84,20 +82,20 @@ var _ = Describe("LLMApifyClient", func() { Expect(limit).To(Equal(uint(1))) // Verify the input is correctly converted to LLMProcessorRequest - request, ok := input.(teetypes.LLMProcessorRequest) + request, ok := input.(types.LLMProcessorRequest) Expect(ok).To(BeTrue()) Expect(request.InputDatasetId).To(Equal("test-dataset-id")) Expect(request.Prompt).To(Equal("test-prompt")) - Expect(request.LLMProviderApiKey).To(Equal("test-claude-llm-key")) // should be set from constructor - Expect(request.Model).To(Equal(teeargs.LLMDefaultClaudeModel)) // default model - Expect(request.MultipleColumns).To(Equal(teeargs.LLMDefaultMultipleColumns)) // default value - Expect(request.MaxTokens).To(Equal(teeargs.LLMDefaultMaxTokens)) // default value - Expect(request.Temperature).To(Equal(strconv.FormatFloat(teeargs.LLMDefaultTemperature, 'f', -1, 64))) // default value + Expect(request.LLMProviderApiKey).To(Equal("test-claude-llm-key")) // should be set from constructor + Expect(request.Model).To(Equal(process.DefaultClaudeModel)) // default model + Expect(request.MultipleColumns).To(Equal(process.DefaultMultipleColumns)) // default value + Expect(request.MaxTokens).To(Equal(process.DefaultMaxTokens)) // default value + Expect(request.Temperature).To(Equal(strconv.FormatFloat(process.DefaultTemperature, 'f', -1, 64))) // default value return &client.DatasetResponse{Data: client.ApifyDatasetData{Items: []json.RawMessage{}}}, "next", nil } - _, _, processErr := llmClient.Process("test-worker", args, client.EmptyCursor) + _, _, processErr := llmClient.Process("test-worker", llmArgs, client.EmptyCursor) Expect(processErr).NotTo(HaveOccurred()) }) @@ -107,11 +105,10 @@ var _ = Describe("LLMApifyClient", func() { return nil, "", expectedErr } - args := teeargs.LLMProcessorArguments{ - DatasetId: "test-dataset-id", - Prompt: "test-prompt", - } - _, _, err := llmClient.Process("test-worker", args, client.EmptyCursor) + llmArgs := process.NewArguments() + llmArgs.DatasetId = "test-dataset-id" + llmArgs.Prompt = "test-prompt" + _, _, err := llmClient.Process("test-worker", llmArgs, client.EmptyCursor) Expect(err).To(MatchError(expectedErr)) }) @@ -126,11 +123,10 @@ var _ = Describe("LLMApifyClient", func() { return dataset, "next", nil } - args := teeargs.LLMProcessorArguments{ - DatasetId: "test-dataset-id", - Prompt: "test-prompt", - } - results, _, err := llmClient.Process("test-worker", args, client.EmptyCursor) + llmArgs := process.NewArguments() + llmArgs.DatasetId = "test-dataset-id" + llmArgs.Prompt = "test-prompt" + results, _, err := llmClient.Process("test-worker", llmArgs, client.EmptyCursor) Expect(err).NotTo(HaveOccurred()) Expect(results).To(BeEmpty()) // The invalid item should be skipped }) @@ 
-148,11 +144,10 @@ var _ = Describe("LLMApifyClient", func() { return dataset, "next", nil } - args := teeargs.LLMProcessorArguments{ - DatasetId: "test-dataset-id", - Prompt: "test-prompt", - } - results, cursor, err := llmClient.Process("test-worker", args, client.EmptyCursor) + llmArgs := process.NewArguments() + llmArgs.DatasetId = "test-dataset-id" + llmArgs.Prompt = "test-prompt" + results, cursor, err := llmClient.Process("test-worker", llmArgs, client.EmptyCursor) Expect(err).NotTo(HaveOccurred()) Expect(cursor).To(Equal(client.Cursor("next"))) Expect(results).To(HaveLen(1)) @@ -175,11 +170,10 @@ var _ = Describe("LLMApifyClient", func() { return dataset, "next", nil } - args := teeargs.LLMProcessorArguments{ - DatasetId: "test-dataset-id", - Prompt: "test-prompt", - } - results, _, err := llmClient.Process("test-worker", args, client.EmptyCursor) + llmArgs := process.NewArguments() + llmArgs.DatasetId = "test-dataset-id" + llmArgs.Prompt = "test-prompt" + results, _, err := llmClient.Process("test-worker", llmArgs, client.EmptyCursor) Expect(err).NotTo(HaveOccurred()) Expect(results).To(HaveLen(2)) Expect(results[0].LLMResponse).To(Equal("First summary.")) @@ -187,15 +181,14 @@ var _ = Describe("LLMApifyClient", func() { }) It("should use custom values when provided", func() { - args := teeargs.LLMProcessorArguments{ - DatasetId: "test-dataset-id", - Prompt: "test-prompt", - MaxTokens: 500, - Temperature: 0.5, - } + llmArgs := process.NewArguments() + llmArgs.DatasetId = "test-dataset-id" + llmArgs.Prompt = "test-prompt" + llmArgs.MaxTokens = 500 + llmArgs.Temperature = 0.5 mockClient.RunActorAndGetResponseFunc = func(actorID apify.ActorId, input any, cursor client.Cursor, limit uint) (*client.DatasetResponse, client.Cursor, error) { - request, ok := input.(teetypes.LLMProcessorRequest) + request, ok := input.(types.LLMProcessorRequest) Expect(ok).To(BeTrue()) Expect(request.MaxTokens).To(Equal(uint(500))) Expect(request.Temperature).To(Equal("0.5")) @@ -204,7 +197,7 @@ var _ = Describe("LLMApifyClient", func() { return &client.DatasetResponse{Data: client.ApifyDatasetData{Items: []json.RawMessage{}}}, "next", nil } - _, _, err := llmClient.Process("test-worker", args, client.EmptyCursor) + _, _, err := llmClient.Process("test-worker", llmArgs, client.EmptyCursor) Expect(err).NotTo(HaveOccurred()) }) }) @@ -258,17 +251,16 @@ var _ = Describe("LLMApifyClient", func() { realClient, err := llmapify.NewClient(apifyKey, config.LlmConfig{GeminiApiKey: config.LlmApiKey(geminiKey)}, nil) Expect(err).NotTo(HaveOccurred()) - args := teeargs.LLMProcessorArguments{ - DatasetId: "V6tyuuZIgfiETl1cl", - Prompt: "summarize the content of this webpage ${markdown}", - } + llmArgs := process.NewArguments() + llmArgs.DatasetId = "V6tyuuZIgfiETl1cl" + llmArgs.Prompt = "summarize the content of this webpage ${markdown}" // Marshal and unmarshal to apply defaults - jsonData, err := json.Marshal(args) + jsonData, err := json.Marshal(llmArgs) Expect(err).ToNot(HaveOccurred()) - err = json.Unmarshal(jsonData, &args) + err = json.Unmarshal(jsonData, &llmArgs) Expect(err).ToNot(HaveOccurred()) - results, cursor, err := realClient.Process("test-worker", args, client.EmptyCursor) + results, cursor, err := realClient.Process("test-worker", llmArgs, client.EmptyCursor) Expect(err).NotTo(HaveOccurred()) Expect(results).NotTo(BeEmpty()) Expect(results[0]).NotTo(BeNil()) diff --git a/internal/jobs/reddit.go b/internal/jobs/reddit.go index f0bcb513..4404288c 100644 --- a/internal/jobs/reddit.go +++ 
b/internal/jobs/reddit.go @@ -9,24 +9,22 @@ import ( "github.com/sirupsen/logrus" + "github.com/masa-finance/tee-worker/api/args" + "github.com/masa-finance/tee-worker/api/args/reddit/search" "github.com/masa-finance/tee-worker/api/types" - "github.com/masa-finance/tee-worker/api/types/reddit" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs/redditapify" "github.com/masa-finance/tee-worker/internal/jobs/stats" "github.com/masa-finance/tee-worker/pkg/client" - - teeargs "github.com/masa-finance/tee-types/args" - teetypes "github.com/masa-finance/tee-types/types" ) // RedditApifyClient defines the interface for the Reddit Apify client. // This allows for mocking in tests. type RedditApifyClient interface { - ScrapeUrls(workerID string, urls []teetypes.RedditStartURL, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) - SearchPosts(workerID string, queries []string, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) - SearchCommunities(workerID string, queries []string, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) - SearchUsers(workerID string, queries []string, skipPosts bool, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) + ScrapeUrls(workerID string, urls []types.RedditStartURL, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) + SearchPosts(workerID string, queries []string, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) + SearchCommunities(workerID string, queries []string, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) + SearchUsers(workerID string, queries []string, skipPosts bool, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) } // NewRedditApifyClient is a function variable that can be replaced in tests. 
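Because NewRedditApifyClient is a package-level function variable, the unit tests replace it wholesale; a minimal sketch of that override (assuming the var returns the RedditApifyClient interface above; the mock type is defined in reddit_test.go):

	jobs.NewRedditApifyClient = func(apiKey string, statsCollector *stats.StatsCollector) (jobs.RedditApifyClient, error) {
		return &MockRedditApifyClient{}, nil
	}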
@@ -38,7 +36,7 @@ var NewRedditApifyClient = func(apiKey string, statsCollector *stats.StatsCollec type RedditScraper struct { configuration config.RedditConfig statsCollector *stats.StatsCollector - capabilities []teetypes.Capability + capabilities []types.Capability } func NewRedditScraper(jc config.JobConfiguration, statsCollector *stats.StatsCollector) *RedditScraper { @@ -47,21 +45,21 @@ func NewRedditScraper(jc config.JobConfiguration, statsCollector *stats.StatsCol return &RedditScraper{ configuration: config, statsCollector: statsCollector, - capabilities: teetypes.RedditCaps, + capabilities: types.RedditCaps, } } func (r *RedditScraper) ExecuteJob(j types.Job) (types.JobResult, error) { logrus.WithField("job_uuid", j.UUID).Info("Starting ExecuteJob for Reddit scrape") - jobArgs, err := teeargs.UnmarshalJobArguments(teetypes.JobType(j.Type), map[string]any(j.Arguments)) + jobArgs, err := args.UnmarshalJobArguments(types.JobType(j.Type), map[string]any(j.Arguments)) if err != nil { msg := fmt.Errorf("failed to unmarshal job arguments: %w", err) return types.JobResult{Error: msg.Error()}, msg } // Type assert to Reddit arguments - redditArgs, ok := jobArgs.(*teeargs.RedditArguments) + redditArgs, ok := jobArgs.(*search.Arguments) if !ok { return types.JobResult{Error: "invalid argument type for Reddit job"}, errors.New("invalid argument type") } @@ -75,11 +73,11 @@ func (r *RedditScraper) ExecuteJob(j types.Job) (types.JobResult, error) { commonArgs := redditapify.CommonArgs{} commonArgs.CopyFromArgs(redditArgs) - switch redditArgs.QueryType { - case teetypes.RedditScrapeUrls: - urls := make([]teetypes.RedditStartURL, 0, len(redditArgs.URLs)) + switch redditArgs.Type { + case types.CapScrapeUrls: + urls := make([]types.RedditStartURL, 0, len(redditArgs.URLs)) for _, u := range redditArgs.URLs { - urls = append(urls, teetypes.RedditStartURL{ + urls = append(urls, types.RedditStartURL{ URL: u, Method: "GET", }) @@ -88,31 +86,31 @@ func (r *RedditScraper) ExecuteJob(j types.Job) (types.JobResult, error) { resp, cursor, err := redditClient.ScrapeUrls(j.WorkerID, urls, redditArgs.After, commonArgs, client.Cursor(redditArgs.NextCursor), redditArgs.MaxResults) return processRedditResponse(j, resp, cursor, err) - case teetypes.RedditSearchUsers: + case types.CapSearchUsers: resp, cursor, err := redditClient.SearchUsers(j.WorkerID, redditArgs.Queries, redditArgs.SkipPosts, commonArgs, client.Cursor(redditArgs.NextCursor), redditArgs.MaxResults) return processRedditResponse(j, resp, cursor, err) - case teetypes.RedditSearchPosts: + case types.CapSearchPosts: resp, cursor, err := redditClient.SearchPosts(j.WorkerID, redditArgs.Queries, redditArgs.After, commonArgs, client.Cursor(redditArgs.NextCursor), redditArgs.MaxResults) return processRedditResponse(j, resp, cursor, err) - case teetypes.RedditSearchCommunities: + case types.CapSearchCommunities: resp, cursor, err := redditClient.SearchCommunities(j.WorkerID, redditArgs.Queries, commonArgs, client.Cursor(redditArgs.NextCursor), redditArgs.MaxResults) return processRedditResponse(j, resp, cursor, err) default: - return types.JobResult{Error: "invalid type for Reddit job"}, fmt.Errorf("invalid type for Reddit job: %s", redditArgs.QueryType) + return types.JobResult{Error: "invalid type for Reddit job"}, fmt.Errorf("invalid type for Reddit job: %s", redditArgs.Type) } } -func processRedditResponse(j types.Job, resp []*reddit.Response, cursor client.Cursor, err error) (types.JobResult, error) { +func processRedditResponse(j types.Job, resp 
[]*types.RedditResponse, cursor client.Cursor, err error) (types.JobResult, error) { if err != nil { return types.JobResult{Error: fmt.Sprintf("error while scraping Reddit: %s", err.Error())}, fmt.Errorf("error scraping Reddit: %w", err) } data, err := json.Marshal(resp) if err != nil { - return types.JobResult{Error: fmt.Sprintf("error marshalling Reddit response")}, fmt.Errorf("error marshalling Reddit response: %w", err) + return types.JobResult{Error: "error marshalling Reddit response"}, fmt.Errorf("error marshalling Reddit response: %w", err) } return types.JobResult{ Data: data, @@ -120,17 +118,3 @@ func processRedditResponse(j types.Job, resp []*reddit.Response, cursor client.C NextCursor: cursor.String(), }, nil } - -// GetStructuredCapabilities returns the structured capabilities supported by this Twitter scraper -// based on the available credentials and API keys -func (rs *RedditScraper) GetStructuredCapabilities() teetypes.WorkerCapabilities { - capabilities := make(teetypes.WorkerCapabilities) - - // Add Apify-specific capabilities based on available API key - // TODO: We should verify whether each of the actors is actually available through this API key - if rs.configuration.ApifyApiKey != "" { - capabilities[teetypes.RedditJob] = teetypes.RedditCaps - } - - return capabilities -} diff --git a/internal/jobs/reddit_test.go b/internal/jobs/reddit_test.go index 36ab11ea..344dbf8d 100644 --- a/internal/jobs/reddit_test.go +++ b/internal/jobs/reddit_test.go @@ -10,25 +10,22 @@ import ( "github.com/sirupsen/logrus" "github.com/masa-finance/tee-worker/api/types" - "github.com/masa-finance/tee-worker/api/types/reddit" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs" "github.com/masa-finance/tee-worker/internal/jobs/redditapify" "github.com/masa-finance/tee-worker/internal/jobs/stats" "github.com/masa-finance/tee-worker/pkg/client" - - teetypes "github.com/masa-finance/tee-types/types" ) // MockRedditApifyClient is a mock implementation of the RedditApifyClient. 
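// Each spec stubs only the behavior it needs; a minimal sketch using the
// SearchPostsFunc field declared just below:
//
//	mockClient := &MockRedditApifyClient{}
//	mockClient.SearchPostsFunc = func(queries []string, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) {
//		return []*types.RedditResponse{}, "", nil
//	}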
type MockRedditApifyClient struct { - ScrapeUrlsFunc func(urls []teetypes.RedditStartURL, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) - SearchPostsFunc func(queries []string, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) - SearchCommunitiesFunc func(queries []string, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) - SearchUsersFunc func(queries []string, skipPosts bool, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) + ScrapeUrlsFunc func(urls []types.RedditStartURL, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) + SearchPostsFunc func(queries []string, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) + SearchCommunitiesFunc func(queries []string, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) + SearchUsersFunc func(queries []string, skipPosts bool, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) } -func (m *MockRedditApifyClient) ScrapeUrls(_ string, urls []teetypes.RedditStartURL, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { +func (m *MockRedditApifyClient) ScrapeUrls(_ string, urls []types.RedditStartURL, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { if m != nil && m.ScrapeUrlsFunc != nil { res, cursor, err := m.ScrapeUrlsFunc(urls, after, args, cursor, maxResults) for i, r := range res { @@ -39,21 +36,21 @@ func (m *MockRedditApifyClient) ScrapeUrls(_ string, urls []teetypes.RedditStart return nil, "", nil } -func (m *MockRedditApifyClient) SearchPosts(_ string, queries []string, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { +func (m *MockRedditApifyClient) SearchPosts(_ string, queries []string, after time.Time, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { if m != nil && m.SearchPostsFunc != nil { return m.SearchPostsFunc(queries, after, args, cursor, maxResults) } return nil, "", nil } -func (m *MockRedditApifyClient) SearchCommunities(_ string, queries []string, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { +func (m *MockRedditApifyClient) SearchCommunities(_ string, queries []string, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { if m != nil && m.SearchCommunitiesFunc != nil { return m.SearchCommunitiesFunc(queries, args, cursor, maxResults) } return nil, "", nil } -func (m *MockRedditApifyClient) SearchUsers(_ string, queries []string, skipPosts bool, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { +func (m *MockRedditApifyClient) SearchUsers(_ string, queries []string, skipPosts bool, args redditapify.CommonArgs, cursor client.Cursor, maxResults uint) 
([]*types.RedditResponse, client.Cursor, error) { if m != nil && m.SearchUsersFunc != nil { return m.SearchUsersFunc(queries, skipPosts, args, cursor, maxResults) } @@ -83,7 +80,7 @@ var _ = Describe("RedditScraper", func() { job = types.Job{ UUID: "test-uuid", - Type: teetypes.RedditJob, + Type: types.RedditJob, } }) @@ -100,20 +97,20 @@ var _ = Describe("RedditScraper", func() { "https://www.reddit.com/r/HHGTTG/comments/1jynlrz/the_entire_series_after_restaurant_at_the_end_of/", } job.Arguments = map[string]any{ - "type": teetypes.RedditScrapeUrls, + "type": types.CapScrapeUrls, "urls": testUrls, } - mockClient.ScrapeUrlsFunc = func(urls []teetypes.RedditStartURL, after time.Time, cArgs redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { + mockClient.ScrapeUrlsFunc = func(urls []types.RedditStartURL, after time.Time, cArgs redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { Expect(urls).To(HaveLen(1)) Expect(urls[0].URL).To(Equal(testUrls[0])) - return []*reddit.Response{{TypeSwitch: &reddit.TypeSwitch{Type: reddit.UserResponse}, User: &reddit.User{ID: "user1", DataType: string(reddit.UserResponse)}}}, "next", nil + return []*types.RedditResponse{{Type: types.RedditUserItem, User: &types.RedditUser{ID: "user1", DataType: string(types.RedditUserItem)}}}, "next", nil } result, err := scraper.ExecuteJob(job) Expect(err).NotTo(HaveOccurred()) Expect(result.NextCursor).To(Equal("next")) - var resp []*reddit.Response + var resp []*types.RedditResponse err = json.Unmarshal(result.Data, &resp) Expect(err).NotTo(HaveOccurred()) Expect(resp).To(HaveLen(1)) @@ -124,19 +121,19 @@ var _ = Describe("RedditScraper", func() { It("should call SearchUsers for the correct QueryType", func() { job.Arguments = map[string]any{ - "type": teetypes.RedditSearchUsers, + "type": types.CapSearchUsers, "queries": []string{"user-query"}, } - mockClient.SearchUsersFunc = func(queries []string, skipPosts bool, cArgs redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { + mockClient.SearchUsersFunc = func(queries []string, skipPosts bool, cArgs redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { Expect(queries).To(Equal([]string{"user-query"})) - return []*reddit.Response{{TypeSwitch: &reddit.TypeSwitch{Type: reddit.UserResponse}, User: &reddit.User{ID: "user2", DataType: string(reddit.UserResponse)}}}, "next-user", nil + return []*types.RedditResponse{{Type: types.RedditUserItem, User: &types.RedditUser{ID: "user2", DataType: string(types.RedditUserItem)}}}, "next-user", nil } result, err := scraper.ExecuteJob(job) Expect(err).NotTo(HaveOccurred()) Expect(result.NextCursor).To(Equal("next-user")) - var resp []*reddit.Response + var resp []*types.RedditResponse err = json.Unmarshal(result.Data, &resp) Expect(err).NotTo(HaveOccurred()) Expect(resp).To(HaveLen(1)) @@ -147,19 +144,19 @@ var _ = Describe("RedditScraper", func() { It("should call SearchPosts for the correct QueryType", func() { job.Arguments = map[string]any{ - "type": teetypes.RedditSearchPosts, + "type": types.CapSearchPosts, "queries": []string{"post-query"}, } - mockClient.SearchPostsFunc = func(queries []string, after time.Time, cArgs redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { + mockClient.SearchPostsFunc = func(queries []string, after time.Time, cArgs 
redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { Expect(queries).To(Equal([]string{"post-query"})) - return []*reddit.Response{{TypeSwitch: &reddit.TypeSwitch{Type: reddit.PostResponse}, Post: &reddit.Post{ID: "post1", DataType: string(reddit.PostResponse)}}}, "next-post", nil + return []*types.RedditResponse{{Type: types.RedditPostItem, Post: &types.RedditPost{ID: "post1", DataType: string(types.RedditPostItem)}}}, "next-post", nil } result, err := scraper.ExecuteJob(job) Expect(err).NotTo(HaveOccurred()) Expect(result.NextCursor).To(Equal("next-post")) - var resp []*reddit.Response + var resp []*types.RedditResponse err = json.Unmarshal(result.Data, &resp) Expect(err).NotTo(HaveOccurred()) Expect(resp).To(HaveLen(1)) @@ -170,19 +167,19 @@ var _ = Describe("RedditScraper", func() { It("should call SearchCommunities for the correct QueryType", func() { job.Arguments = map[string]any{ - "type": teetypes.RedditSearchCommunities, + "type": types.CapSearchCommunities, "queries": []string{"community-query"}, } - mockClient.SearchCommunitiesFunc = func(queries []string, cArgs redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { + mockClient.SearchCommunitiesFunc = func(queries []string, cArgs redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { Expect(queries).To(Equal([]string{"community-query"})) - return []*reddit.Response{{TypeSwitch: &reddit.TypeSwitch{Type: reddit.CommunityResponse}, Community: &reddit.Community{ID: "comm1", DataType: string(reddit.CommunityResponse)}}}, "next-comm", nil + return []*types.RedditResponse{{Type: types.RedditCommunityItem, Community: &types.RedditCommunity{ID: "comm1", DataType: string(types.RedditCommunityItem)}}}, "next-comm", nil } result, err := scraper.ExecuteJob(job) Expect(err).NotTo(HaveOccurred()) Expect(result.NextCursor).To(Equal("next-comm")) - var resp []*reddit.Response + var resp []*types.RedditResponse err = json.Unmarshal(result.Data, &resp) Expect(err).NotTo(HaveOccurred()) Expect(resp).To(HaveLen(1)) @@ -204,12 +201,12 @@ var _ = Describe("RedditScraper", func() { It("should handle errors from the reddit client", func() { job.Arguments = map[string]any{ - "type": teetypes.RedditSearchPosts, + "type": types.CapSearchPosts, "queries": []string{"post-query"}, } expectedErr := errors.New("client error") - mockClient.SearchPostsFunc = func(queries []string, after time.Time, cArgs redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { + mockClient.SearchPostsFunc = func(queries []string, after time.Time, cArgs redditapify.CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { return nil, "", expectedErr } @@ -224,7 +221,7 @@ var _ = Describe("RedditScraper", func() { return nil, errors.New("client creation failed") } job.Arguments = map[string]any{ - "type": teetypes.RedditSearchPosts, + "type": types.CapSearchPosts, "queries": []string{"post-query"}, } diff --git a/internal/jobs/redditapify/client.go b/internal/jobs/redditapify/client.go index e90e7e75..79f23439 100644 --- a/internal/jobs/redditapify/client.go +++ b/internal/jobs/redditapify/client.go @@ -7,18 +7,16 @@ import ( "github.com/sirupsen/logrus" - "github.com/masa-finance/tee-worker/api/types/reddit" + "github.com/masa-finance/tee-worker/api/args/reddit/search" + "github.com/masa-finance/tee-worker/api/types" 
"github.com/masa-finance/tee-worker/internal/apify" "github.com/masa-finance/tee-worker/internal/jobs/stats" "github.com/masa-finance/tee-worker/pkg/client" - - teeargs "github.com/masa-finance/tee-types/args" - teetypes "github.com/masa-finance/tee-types/types" ) // CommonArgs holds the parameters that all Reddit searches support, in a single struct type CommonArgs struct { - Sort teetypes.RedditSortType + Sort types.RedditSortType IncludeNSFW bool MaxItems uint MaxPosts uint @@ -27,7 +25,7 @@ type CommonArgs struct { MaxUsers uint } -func (ca *CommonArgs) CopyFromArgs(a *teeargs.RedditArguments) { +func (ca *CommonArgs) CopyFromArgs(a *search.Arguments) { ca.Sort = a.Sort ca.IncludeNSFW = a.IncludeNSFW ca.MaxItems = a.MaxItems @@ -52,23 +50,23 @@ func (args *CommonArgs) ToActorRequest() RedditActorRequest { // RedditActorRequest represents the query parameters for the Apify Reddit Scraper actor. // Based on the input schema of https://apify.com/trudax/reddit-scraper type RedditActorRequest struct { - Type teetypes.RedditQueryType `json:"type,omitempty"` - Searches []string `json:"searches,omitempty"` - StartUrls []teetypes.RedditStartURL `json:"startUrls,omitempty"` - Sort teetypes.RedditSortType `json:"sort,omitempty"` - PostDateLimit *time.Time `json:"postDateLimit,omitempty"` - IncludeNSFW bool `json:"includeNSFW"` - MaxItems uint `json:"maxItems,omitempty"` // Total number of items to scrape - MaxPostCount uint `json:"maxPostCount,omitempty"` // Max number of posts per page - MaxComments uint `json:"maxComments,omitempty"` // Max number of comments per page - MaxCommunitiesCount uint `json:"maxCommunitiesCount,omitempty"` // Max number of communities per page - MaxUserCount uint `json:"maxUserCount,omitempty"` // Max number of users per page - SearchComments bool `json:"searchComments"` - SearchCommunities bool `json:"searchCommunities"` - SearchPosts bool `json:"searchPosts"` - SearchUsers bool `json:"searchUsers"` - SkipUserPosts bool `json:"skipUserPosts"` - SkipComments bool `json:"skipComments"` + Type types.Capability `json:"type,omitempty"` + Searches []string `json:"searches,omitempty"` + StartUrls []types.RedditStartURL `json:"startUrls,omitempty"` + Sort types.RedditSortType `json:"sort,omitempty"` + PostDateLimit *time.Time `json:"postDateLimit,omitempty"` + IncludeNSFW bool `json:"includeNSFW"` + MaxItems uint `json:"maxItems,omitempty"` // Total number of items to scrape + MaxPostCount uint `json:"maxPostCount,omitempty"` // Max number of posts per page + MaxComments uint `json:"maxComments,omitempty"` // Max number of comments per page + MaxCommunitiesCount uint `json:"maxCommunitiesCount,omitempty"` // Max number of communities per page + MaxUserCount uint `json:"maxUserCount,omitempty"` // Max number of users per page + SearchComments bool `json:"searchComments"` + SearchCommunities bool `json:"searchCommunities"` + SearchPosts bool `json:"searchPosts"` + SearchUsers bool `json:"searchUsers"` + SkipUserPosts bool `json:"skipUserPosts"` + SkipComments bool `json:"skipComments"` } // RedditApifyClient wraps the generic Apify client for Reddit-specific operations @@ -102,7 +100,7 @@ func (c *RedditApifyClient) ValidateApiKey() error { } // ScrapeUrls scrapes Reddit URLs -func (c *RedditApifyClient) ScrapeUrls(workerID string, urls []teetypes.RedditStartURL, after time.Time, args CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { +func (c *RedditApifyClient) ScrapeUrls(workerID string, urls []types.RedditStartURL, after 
time.Time, args CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { input := args.ToActorRequest() input.StartUrls = urls input.Searches = nil @@ -119,7 +117,7 @@ } // SearchPosts searches Reddit posts -func (c *RedditApifyClient) SearchPosts(workerID string, queries []string, after time.Time, args CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { +func (c *RedditApifyClient) SearchPosts(workerID string, queries []string, after time.Time, args CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { input := args.ToActorRequest() input.Searches = queries input.StartUrls = nil @@ -136,7 +134,7 @@ } // SearchCommunities searches Reddit communities -func (c *RedditApifyClient) SearchCommunities(workerID string, queries []string, args CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { +func (c *RedditApifyClient) SearchCommunities(workerID string, queries []string, args CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { input := args.ToActorRequest() input.Searches = queries input.StartUrls = nil @@ -147,7 +145,7 @@ } // SearchUsers searches Reddit users -func (c *RedditApifyClient) SearchUsers(workerID string, queries []string, skipPosts bool, args CommonArgs, cursor client.Cursor, maxResults uint) ([]*reddit.Response, client.Cursor, error) { +func (c *RedditApifyClient) SearchUsers(workerID string, queries []string, skipPosts bool, args CommonArgs, cursor client.Cursor, maxResults uint) ([]*types.RedditResponse, client.Cursor, error) { input := args.ToActorRequest() input.Searches = queries input.StartUrls = nil @@ -159,7 +157,7 @@ } -// getProfiles runs the actor and retrieves profiles from the dataset -func (c *RedditApifyClient) queryReddit(workerID string, input RedditActorRequest, cursor client.Cursor, limit uint) ([]*reddit.Response, client.Cursor, error) { +// queryReddit runs the actor and retrieves Reddit responses from the dataset +func (c *RedditApifyClient) queryReddit(workerID string, input RedditActorRequest, cursor client.Cursor, limit uint) ([]*types.RedditResponse, client.Cursor, error) { if c.statsCollector != nil { c.statsCollector.Add(workerID, stats.RedditQueries, 1) } @@ -172,9 +170,9 @@ return nil, client.EmptyCursor, err } - response := make([]*reddit.Response, 0, len(dataset.Data.Items)) + response := make([]*types.RedditResponse, 0, len(dataset.Data.Items)) for i, item := range dataset.Data.Items { - var resp reddit.Response + var resp types.RedditResponse if err := json.Unmarshal(item, &resp); err != nil { - logrus.Warnf("Failed to unmarshal profile at index %d: %v", i, err) + logrus.Warnf("Failed to unmarshal Reddit response at index %d: %v", i, err) continue diff --git a/internal/jobs/redditapify/client_test.go b/internal/jobs/redditapify/client_test.go index 1712157e..3a3d59ac 100644 --- a/internal/jobs/redditapify/client_test.go +++ b/internal/jobs/redditapify/client_test.go @@ -8,12 +8,11 @@ import ( . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" + "github.com/masa-finance/tee-worker/api/args/reddit/search" + "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/apify" "github.com/masa-finance/tee-worker/internal/jobs/redditapify" "github.com/masa-finance/tee-worker/pkg/client" - - teeargs "github.com/masa-finance/tee-types/args" - teetypes "github.com/masa-finance/tee-types/types" ) // MockApifyClient is a mock implementation of the ApifyClient. @@ -63,7 +62,7 @@ var _ = Describe("RedditApifyClient", func() { Describe("ScrapeUrls", func() { It("should construct the correct actor input", func() { - urls := []teetypes.RedditStartURL{{URL: "http://reddit.com/r/golang"}} + urls := []types.RedditStartURL{{URL: "http://reddit.com/r/golang"}} after := time.Now() args := redditapify.CommonArgs{MaxPosts: 10} @@ -98,7 +97,7 @@ var _ = Describe("RedditApifyClient", func() { Expect(req.Searches).To(Equal(queries)) Expect(req.StartUrls).To(BeNil()) Expect(*req.PostDateLimit).To(BeTemporally("~", after, time.Second)) - Expect(req.Type).To(Equal(teetypes.RedditQueryType("posts"))) + Expect(req.Type).To(Equal(types.CapSearchPosts)) Expect(req.SearchPosts).To(BeTrue()) Expect(req.SkipComments).To(BeFalse()) Expect(req.MaxComments).To(Equal(uint(5))) @@ -120,7 +119,7 @@ var _ = Describe("RedditApifyClient", func() { req := input.(redditapify.RedditActorRequest) Expect(req.Searches).To(Equal(queries)) Expect(req.StartUrls).To(BeNil()) - Expect(req.Type).To(Equal(teetypes.RedditQueryType("communities"))) + Expect(req.Type).To(Equal(types.CapSearchCommunities)) Expect(req.SearchCommunities).To(BeTrue()) return &client.DatasetResponse{Data: client.ApifyDatasetData{Items: []json.RawMessage{}}}, "next", nil } @@ -140,7 +139,7 @@ var _ = Describe("RedditApifyClient", func() { req := input.(redditapify.RedditActorRequest) Expect(req.Searches).To(Equal(queries)) Expect(req.StartUrls).To(BeNil()) - Expect(req.Type).To(Equal(teetypes.RedditQueryType("users"))) + Expect(req.Type).To(Equal(types.CapSearchUsers)) Expect(req.SearchUsers).To(BeTrue()) Expect(req.SkipUserPosts).To(BeTrue()) return &client.DatasetResponse{Data: client.ApifyDatasetData{Items: []json.RawMessage{}}}, "next", nil @@ -201,8 +200,8 @@ var _ = Describe("RedditApifyClient", func() { Describe("CommonArgs", func() { It("should copy from RedditArguments correctly", func() { - redditArgs := &teeargs.RedditArguments{ - Sort: teetypes.RedditSortTop, + redditArgs := &search.Arguments{ + Sort: types.RedditSortTop, IncludeNSFW: true, MaxItems: 1, MaxPosts: 2, @@ -213,7 +212,7 @@ var _ = Describe("RedditApifyClient", func() { commonArgs := redditapify.CommonArgs{} commonArgs.CopyFromArgs(redditArgs) - Expect(commonArgs.Sort).To(Equal(teetypes.RedditSortTop)) + Expect(commonArgs.Sort).To(Equal(types.RedditSortTop)) Expect(commonArgs.IncludeNSFW).To(BeTrue()) Expect(commonArgs.MaxItems).To(Equal(uint(1))) Expect(commonArgs.MaxPosts).To(Equal(uint(2))) @@ -224,7 +223,7 @@ var _ = Describe("RedditApifyClient", func() { It("should convert to RedditActorRequest correctly", func() { commonArgs := redditapify.CommonArgs{ - Sort: teetypes.RedditSortNew, + Sort: types.RedditSortNew, IncludeNSFW: true, MaxItems: 10, MaxPosts: 20, @@ -234,7 +233,7 @@ var _ = Describe("RedditApifyClient", func() { } actorReq := commonArgs.ToActorRequest() - Expect(actorReq.Sort).To(Equal(teetypes.RedditSortNew)) + Expect(actorReq.Sort).To(Equal(types.RedditSortNew)) Expect(actorReq.IncludeNSFW).To(BeTrue()) Expect(actorReq.MaxItems).To(Equal(uint(10))) 
Expect(actorReq.MaxPostCount).To(Equal(uint(20))) diff --git a/internal/jobs/stats/stats.go b/internal/jobs/stats/stats.go index dd449de9..85f1bd5d 100644 --- a/internal/jobs/stats/stats.go +++ b/internal/jobs/stats/stats.go @@ -5,7 +5,7 @@ import ( "sync" "time" - teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/capabilities" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/versioning" @@ -14,7 +14,7 @@ import ( // WorkerCapabilitiesProvider abstracts capability retrieval to avoid import cycles type WorkerCapabilitiesProvider interface { - GetWorkerCapabilities() teetypes.WorkerCapabilities + GetWorkerCapabilities() types.WorkerCapabilities } // These are the types of statistics that we can add. The value is the JSON key that will be used for serialization. @@ -66,7 +66,7 @@ type Stats struct { CurrentTimeUnix int64 `json:"current_time"` WorkerID string `json:"worker_id"` Stats map[string]map[StatType]uint `json:"stats"` - ReportedCapabilities teetypes.WorkerCapabilities `json:"reported_capabilities"` + ReportedCapabilities types.WorkerCapabilities `json:"reported_capabilities"` WorkerVersion string `json:"worker_version"` ApplicationVersion string `json:"application_version"` sync.Mutex diff --git a/internal/jobs/telemetry.go b/internal/jobs/telemetry.go index 837e1886..db92afa6 100644 --- a/internal/jobs/telemetry.go +++ b/internal/jobs/telemetry.go @@ -1,7 +1,6 @@ package jobs import ( - teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs/stats" @@ -16,13 +15,6 @@ func NewTelemetryJob(jc config.JobConfiguration, c *stats.StatsCollector) Teleme return TelemetryJob{collector: c} } -// GetStructuredCapabilities returns the structured capabilities supported by the telemetry job -func (t TelemetryJob) GetStructuredCapabilities() teetypes.WorkerCapabilities { - return teetypes.WorkerCapabilities{ - teetypes.TelemetryJob: teetypes.AlwaysAvailableTelemetryCaps, - } -} - func (t TelemetryJob) ExecuteJob(j types.Job) (types.JobResult, error) { logrus.Debug("Executing telemetry job") diff --git a/internal/jobs/telemetry_test.go b/internal/jobs/telemetry_test.go index 7c2b4732..2a8960a9 100644 --- a/internal/jobs/telemetry_test.go +++ b/internal/jobs/telemetry_test.go @@ -8,7 +8,6 @@ import ( . "github.com/onsi/gomega" "github.com/sirupsen/logrus" - teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/config" . 
"github.com/masa-finance/tee-worker/internal/jobs" @@ -41,7 +40,7 @@ var _ = Describe("Telemetry Job", func() { // Execute the telemetry job job := types.Job{ - Type: teetypes.TelemetryJob, + Type: types.TelemetryJob, WorkerID: "telemetry-test", } @@ -87,7 +86,7 @@ var _ = Describe("Telemetry Job", func() { telemetryJobNoStats := NewTelemetryJob(config.JobConfiguration{}, nil) job := types.Job{ - Type: teetypes.TelemetryJob, + Type: types.TelemetryJob, WorkerID: "telemetry-test-no-stats", } @@ -100,14 +99,7 @@ var _ = Describe("Telemetry Job", func() { logrus.WithField("error", result.Error).Info("Telemetry job handled missing stats collector correctly") }) - It("should return structured capabilities", func() { - capabilities := telemetryJob.GetStructuredCapabilities() - - Expect(capabilities).NotTo(BeEmpty()) - Expect(capabilities).To(HaveLen(1)) - Expect(capabilities[teetypes.TelemetryJob]).To(ContainElement(teetypes.CapTelemetry)) - - logrus.WithField("capabilities", capabilities).Info("Telemetry job capabilities verified") - }) + // Note: Capability detection is now centralized in capabilities/detector.go + // Individual scraper capability tests have been removed }) }) diff --git a/internal/jobs/tiktok.go b/internal/jobs/tiktok.go index f1e327e6..a3c8b223 100644 --- a/internal/jobs/tiktok.go +++ b/internal/jobs/tiktok.go @@ -10,8 +10,10 @@ import ( "strings" "time" - teeargs "github.com/masa-finance/tee-types/args" - teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/api/args" + "github.com/masa-finance/tee-worker/api/args/tiktok/query" + "github.com/masa-finance/tee-worker/api/args/tiktok/transcription" + "github.com/masa-finance/tee-worker/api/args/tiktok/trending" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs/stats" @@ -41,18 +43,6 @@ type TikTokTranscriber struct { httpClient *http.Client } -// GetStructuredCapabilities returns the structured capabilities supported by the TikTok transcriber -func (t *TikTokTranscriber) GetStructuredCapabilities() teetypes.WorkerCapabilities { - caps := make([]teetypes.Capability, 0, len(teetypes.AlwaysAvailableTiktokCaps)+len(teetypes.TiktokSearchCaps)) - caps = append(caps, teetypes.AlwaysAvailableTiktokCaps...) - if t.configuration.ApifyApiKey != "" { - caps = append(caps, teetypes.TiktokSearchCaps...) - } - return teetypes.WorkerCapabilities{ - teetypes.TiktokJob: caps, - } -} - // NewTikTokTranscriber creates and initializes a new TikTokTranscriber. // It sets default values for the API configuration. 
func NewTikTokTranscriber(jc config.JobConfiguration, statsCollector *stats.StatsCollector) *TikTokTranscriber { @@ -105,17 +95,17 @@ func (ttt *TikTokTranscriber) ExecuteJob(j types.Job) (types.JobResult, error) { logrus.WithField("job_uuid", j.UUID).Info("Starting ExecuteJob for TikTok job") // Use the centralized type-safe unmarshaller - jobArgs, err := teeargs.UnmarshalJobArguments(teetypes.JobType(j.Type), map[string]any(j.Arguments)) + jobArgs, err := args.UnmarshalJobArguments(types.JobType(j.Type), map[string]any(j.Arguments)) if err != nil { return types.JobResult{Error: "Failed to unmarshal job arguments"}, fmt.Errorf("unmarshal job arguments: %w", err) } // Branch by argument type (transcription vs search) - if transcriptionArgs, ok := jobArgs.(*teeargs.TikTokTranscriptionArguments); ok { + if transcriptionArgs, ok := jobArgs.(*transcription.Arguments); ok { return ttt.executeTranscription(j, transcriptionArgs) - } else if searchByQueryArgs, ok := jobArgs.(*teeargs.TikTokSearchByQueryArguments); ok { + } else if searchByQueryArgs, ok := jobArgs.(*query.Arguments); ok { return ttt.executeSearchByQuery(j, searchByQueryArgs) - } else if searchByTrendingArgs, ok := jobArgs.(*teeargs.TikTokSearchByTrendingArguments); ok { + } else if searchByTrendingArgs, ok := jobArgs.(*trending.Arguments); ok { return ttt.executeSearchByTrending(j, searchByTrendingArgs) } else { return types.JobResult{Error: "invalid argument type for TikTok job"}, fmt.Errorf("invalid argument type") @@ -123,7 +113,7 @@ func (ttt *TikTokTranscriber) ExecuteJob(j types.Job) (types.JobResult, error) { } // executeTranscription calls the external transcription service and returns a normalized result -func (ttt *TikTokTranscriber) executeTranscription(j types.Job, a *teeargs.TikTokTranscriptionArguments) (types.JobResult, error) { +func (ttt *TikTokTranscriber) executeTranscription(j types.Job, a *transcription.Arguments) (types.JobResult, error) { logrus.WithField("job_uuid", j.UUID).Info("Starting ExecuteJob for TikTok transcription") if ttt.configuration.TranscriptionEndpoint == "" { @@ -132,13 +122,13 @@ func (ttt *TikTokTranscriber) executeTranscription(j types.Job, a *teeargs.TikTo } // Use the centralized type-safe unmarshaller - jobArgs, err := teeargs.UnmarshalJobArguments(teetypes.JobType(j.Type), map[string]any(j.Arguments)) + jobArgs, err := args.UnmarshalJobArguments(types.JobType(j.Type), map[string]any(j.Arguments)) if err != nil { return types.JobResult{Error: "Failed to unmarshal job arguments"}, fmt.Errorf("unmarshal job arguments: %w", err) } // Type assert to TikTok arguments - tiktokArgs, ok := jobArgs.(*teeargs.TikTokTranscriptionArguments) + tiktokArgs, ok := jobArgs.(*transcription.Arguments) if !ok { return types.JobResult{Error: "invalid argument type for TikTok job"}, fmt.Errorf("invalid argument type") } @@ -215,7 +205,7 @@ func (ttt *TikTokTranscriber) executeTranscription(j types.Job, a *teeargs.TikTo // Sub-Step 3.2: Extract Transcription and Metadata if len(parsedAPIResponse.Transcripts) == 0 { - errMsg := "No transcripts found in API response" + errMsg := "no transcripts found in API response" logrus.WithField("job_uuid", j.UUID).Warn(errMsg) ttt.stats.Add(j.WorkerID, stats.TikTokTranscriptionErrors, 1) // Or a different stat for "no_transcript_found" return types.JobResult{Error: errMsg}, errors.New(errMsg) @@ -270,7 +260,7 @@ func (ttt *TikTokTranscriber) executeTranscription(j types.Job, a *teeargs.TikTo } // Process Result & Return - resultData := teetypes.TikTokTranscriptionResult{ 
+ resultData := types.TikTokTranscriptionResult{ TranscriptionText: plainTextTranscription, DetectedLanguage: languageCode, VideoTitle: parsedAPIResponse.VideoTitle, @@ -294,7 +284,7 @@ func (ttt *TikTokTranscriber) executeTranscription(j types.Job, a *teeargs.TikTo } // executeSearchByQuery runs the epctex/tiktok-search-scraper actor and returns results -func (ttt *TikTokTranscriber) executeSearchByQuery(j types.Job, a *teeargs.TikTokSearchByQueryArguments) (types.JobResult, error) { +func (ttt *TikTokTranscriber) executeSearchByQuery(j types.Job, a *query.Arguments) (types.JobResult, error) { c, err := tiktokapify.NewTikTokApifyClient(ttt.configuration.ApifyApiKey) if err != nil { ttt.stats.Add(j.WorkerID, stats.TikTokAuthErrors, 1) @@ -325,7 +315,7 @@ func (ttt *TikTokTranscriber) executeSearchByQuery(j types.Job, a *teeargs.TikTo } // executeSearchByTrending runs the lexis-solutions/tiktok-trending-videos-scraper actor and returns results -func (ttt *TikTokTranscriber) executeSearchByTrending(j types.Job, a *teeargs.TikTokSearchByTrendingArguments) (types.JobResult, error) { +func (ttt *TikTokTranscriber) executeSearchByTrending(j types.Job, a *trending.Arguments) (types.JobResult, error) { c, err := tiktokapify.NewTikTokApifyClient(ttt.configuration.ApifyApiKey) if err != nil { ttt.stats.Add(j.WorkerID, stats.TikTokAuthErrors, 1) diff --git a/internal/jobs/tiktok_test.go b/internal/jobs/tiktok_test.go index cb8973c0..5c9b5f44 100644 --- a/internal/jobs/tiktok_test.go +++ b/internal/jobs/tiktok_test.go @@ -10,7 +10,6 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/config" . "github.com/masa-finance/tee-worker/internal/jobs" @@ -45,14 +44,14 @@ var _ = Describe("TikTok", func() { Context("when a valid TikTok URL is provided", func() { It("should successfully transcribe the video and record success stats", func(ctx SpecContext) { videoURL := "https://www.tiktok.com/@theblockrunner.com/video/7227579907361066282" jobArguments := map[string]interface{}{ - "type": teetypes.CapTranscription, + "type": types.CapTranscription, "video_url": videoURL, } job := types.Job{ - Type: teetypes.TiktokJob, + Type: types.TiktokJob, Arguments: jobArguments, WorkerID: "tiktok-test-worker-happy", UUID: "test-uuid-happy", } @@ -69,7 +68,7 @@ var _ = Describe("TikTok", func() { Expect(res.Data).NotTo(BeNil()) Expect(res.Data).NotTo(BeEmpty()) - var transcriptionResult teetypes.TikTokTranscriptionResult + var transcriptionResult types.TikTokTranscriptionResult err = json.Unmarshal(res.Data, &transcriptionResult) Expect(err).NotTo(HaveOccurred(), "Failed to unmarshal result data") @@ -115,14 +114,14 @@ var _ = Describe("TikTok", func() { Context("when arguments are invalid", func() { It("should return an error if VideoURL is empty and not record error stats", func() { jobArguments := map[string]interface{}{ - "type": teetypes.CapTranscription, + "type": types.CapTranscription, "video_url": "", // Empty URL } job := types.Job{ - Type: teetypes.TiktokJob, + Type: types.TiktokJob, Arguments: jobArguments, WorkerID: "tiktok-test-worker-invalid", UUID: "test-uuid-invalid", } @@ -174,9 +173,9 @@ var _ = Describe("TikTok", func() { t := NewTikTokTranscriber(jobConfig, statsCollector) j := types.Job{ - Type:
teetypes.TiktokJob, + Type: types.TiktokJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "search": []string{"crypto", "ai"}, "max_items": 5, "end_page": 1, @@ -190,7 +189,7 @@ var _ = Describe("TikTok", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var items []*teetypes.TikTokSearchByQueryResult + var items []*types.TikTokSearchByQueryResult err = json.Unmarshal(res.Data, &items) Expect(err).NotTo(HaveOccurred()) Expect(items).NotTo(BeEmpty()) @@ -235,9 +234,9 @@ var _ = Describe("TikTok", func() { t := NewTikTokTranscriber(jobConfig, statsCollector) j := types.Job{ - Type: teetypes.TiktokJob, + Type: types.TiktokJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByTrending, + "type": types.CapSearchByTrending, "country_code": "US", "sort_by": "repost", "max_items": 5, @@ -251,7 +250,7 @@ var _ = Describe("TikTok", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var items []*teetypes.TikTokSearchByTrending + var items []*types.TikTokSearchByTrending err = json.Unmarshal(res.Data, &items) Expect(err).NotTo(HaveOccurred()) Expect(items).NotTo(BeEmpty()) @@ -290,9 +289,9 @@ var _ = Describe("TikTok", func() { t := NewTikTokTranscriber(jobConfig, statsCollector) j := types.Job{ - Type: teetypes.TiktokJob, + Type: types.TiktokJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "search": []string{"tiktok"}, "max_items": 1, "end_page": 1, diff --git a/internal/jobs/tiktokapify/client.go b/internal/jobs/tiktokapify/client.go index 6ad22f24..48c9b077 100644 --- a/internal/jobs/tiktokapify/client.go +++ b/internal/jobs/tiktokapify/client.go @@ -4,8 +4,9 @@ import ( "encoding/json" "fmt" - teeargs "github.com/masa-finance/tee-types/args" - teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/api/args/tiktok/query" + "github.com/masa-finance/tee-worker/api/args/tiktok/trending" + "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/apify" "github.com/masa-finance/tee-worker/pkg/client" ) @@ -43,7 +44,7 @@ func (c *TikTokApifyClient) ValidateApiKey() error { } // SearchByQuery runs the search actor and returns typed results -func (c *TikTokApifyClient) SearchByQuery(input teeargs.TikTokSearchByQueryArguments, cursor client.Cursor, limit uint) ([]*teetypes.TikTokSearchByQueryResult, client.Cursor, error) { +func (c *TikTokApifyClient) SearchByQuery(input query.Arguments, cursor client.Cursor, limit uint) ([]*types.TikTokSearchByQueryResult, client.Cursor, error) { // Map snake_case fields to Apify actor's expected camelCase input startUrls := input.StartUrls if startUrls == nil { @@ -79,9 +80,9 @@ func (c *TikTokApifyClient) SearchByQuery(input teeargs.TikTokSearchByQueryArgum return nil, "", fmt.Errorf("apify run (search): %w", err) } - var results []*teetypes.TikTokSearchByQueryResult + var results []*types.TikTokSearchByQueryResult for _, raw := range dataset.Data.Items { - var item teetypes.TikTokSearchByQueryResult + var item types.TikTokSearchByQueryResult if err := json.Unmarshal(raw, &item); err != nil { // Skip any items whose structure doesn't match continue @@ -92,7 +93,7 @@ func (c *TikTokApifyClient) SearchByQuery(input teeargs.TikTokSearchByQueryArgum } // SearchByTrending runs the trending actor and returns typed results -func (c *TikTokApifyClient) SearchByTrending(input teeargs.TikTokSearchByTrendingArguments, cursor 
client.Cursor, limit uint) ([]*teetypes.TikTokSearchByTrending, client.Cursor, error) { +func (c *TikTokApifyClient) SearchByTrending(input trending.Arguments, cursor client.Cursor, limit uint) ([]*types.TikTokSearchByTrending, client.Cursor, error) { request := TikTokSearchByTrendingRequest{ CountryCode: input.CountryCode, SortBy: input.SortBy, @@ -115,9 +116,9 @@ func (c *TikTokApifyClient) SearchByTrending(input teeargs.TikTokSearchByTrendin return nil, "", fmt.Errorf("apify run (trending): %w", err) } - var results []*teetypes.TikTokSearchByTrending + var results []*types.TikTokSearchByTrending for _, raw := range dataset.Data.Items { - var item teetypes.TikTokSearchByTrending + var item types.TikTokSearchByTrending if err := json.Unmarshal(raw, &item); err != nil { continue } diff --git a/internal/jobs/twitter.go b/internal/jobs/twitter.go index a7cecb2e..a1fd4222 100644 --- a/internal/jobs/twitter.go +++ b/internal/jobs/twitter.go @@ -8,12 +8,11 @@ import ( "strings" "time" - teeargs "github.com/masa-finance/tee-types/args" - teetypes "github.com/masa-finance/tee-types/types" - "github.com/masa-finance/tee-worker/internal/jobs/twitterx" "github.com/masa-finance/tee-worker/pkg/client" + "github.com/masa-finance/tee-worker/api/args" + twitterargs "github.com/masa-finance/tee-worker/api/args/twitter" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs/stats" @@ -24,7 +23,7 @@ import ( "github.com/sirupsen/logrus" ) -func (ts *TwitterScraper) convertTwitterScraperTweetToTweetResult(tweet twitterscraper.Tweet) *teetypes.TweetResult { +func (ts *TwitterScraper) convertTwitterScraperTweetToTweetResult(tweet twitterscraper.Tweet) *types.TweetResult { id, err := strconv.ParseInt(tweet.ID, 10, 64) if err != nil { logrus.Warnf("failed to convert tweet ID to int64: %s", tweet.ID) @@ -34,7 +33,7 @@ func (ts *TwitterScraper) convertTwitterScraperTweetToTweetResult(tweet twitters createdAt := time.Unix(tweet.Timestamp, 0).UTC() logrus.Debug("Converting Tweet ID: ", id) // Changed to Debug - return &teetypes.TweetResult{ + return &types.TweetResult{ ID: id, TweetID: tweet.ID, ConversationID: tweet.ConversationID, @@ -54,20 +53,20 @@ func (ts *TwitterScraper) convertTwitterScraperTweetToTweetResult(tweet twitters Retweets: tweet.Retweets, URLs: tweet.URLs, Username: tweet.Username, - Photos: func() []teetypes.Photo { - var photos []teetypes.Photo + Photos: func() []types.Photo { + var photos []types.Photo for _, photo := range tweet.Photos { - photos = append(photos, teetypes.Photo{ + photos = append(photos, types.Photo{ ID: photo.ID, URL: photo.URL, }) } return photos }(), - Videos: func() []teetypes.Video { - var videos []teetypes.Video + Videos: func() []types.Video { + var videos []types.Video for _, video := range tweet.Videos { - videos = append(videos, teetypes.Video{ + videos = append(videos, types.Video{ ID: video.ID, Preview: video.Preview, URL: video.URL, @@ -146,7 +145,6 @@ func (ts *TwitterScraper) getApiScraper(j types.Job) (*twitterx.TwitterXScraper, // getApifyScraper returns an Apify client func (ts *TwitterScraper) getApifyScraper(j types.Job) (*twitterapify.TwitterApifyClient, error) { - // TODO: We should verify whether each of the actors is actually available through this API key if ts.configuration.ApifyApiKey == "" { ts.statsCollector.Add(j.WorkerID, stats.TwitterAuthErrors, 1) return nil, fmt.Errorf("no Apify API key available") @@ -185,79 +183,24 @@ func filterMap[T any, R 
any](slice []T, f func(T) (R, bool)) []R { return result } -func (ts *TwitterScraper) ScrapeFollowersForProfile(j types.Job, baseDir string, username string, count int) ([]*twitterscraper.Profile, error) { - scraper, account, err := ts.getCredentialScraper(j, baseDir) - if err != nil { - return nil, err - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - followingResponse, errString, _ := scraper.FetchFollowers(username, count, "") - if errString != "" { - fetchErr := fmt.Errorf("error fetching followers: %s", errString) - if ts.handleError(j, fetchErr, account) { - return nil, fetchErr - } - logrus.Errorf("[-] Error fetching followers: %s", errString) - return nil, fetchErr - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterProfiles, uint(len(followingResponse))) - return followingResponse, nil -} - -func (ts *TwitterScraper) ScrapeTweetsProfile(j types.Job, baseDir string, username string) (twitterscraper.Profile, error) { - logrus.Infof("[ScrapeTweetsProfile] Starting profile scraping for username: %s", username) +func (ts *TwitterScraper) SearchByProfile(j types.Job, baseDir string, username string) (twitterscraper.Profile, error) { scraper, account, err := ts.getCredentialScraper(j, baseDir) if err != nil { - logrus.Errorf("[ScrapeTweetsProfile] Failed to get credential scraper: %v", err) + logrus.Errorf("failed to get credential scraper: %v", err) return twitterscraper.Profile{}, err } - - logrus.Infof("[ScrapeTweetsProfile] About to increment TwitterScrapes stat for WorkerID: %s", j.WorkerID) ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - logrus.Infof("[ScrapeTweetsProfile] TwitterScrapes incremented, now calling scraper.GetProfile") - profile, err := scraper.GetProfile(username) if err != nil { - logrus.Errorf("[ScrapeTweetsProfile] scraper.GetProfile failed for username %s: %v", username, err) + logrus.Errorf("scraper.GetProfile failed for username %s: %v", username, err) _ = ts.handleError(j, err, account) return twitterscraper.Profile{}, err } - - logrus.Infof("[ScrapeTweetsProfile] Profile retrieved successfully for username: %s, profile: %+v", username, profile) - logrus.Infof("[ScrapeTweetsProfile] About to increment TwitterProfiles stat for WorkerID: %s", j.WorkerID) ts.statsCollector.Add(j.WorkerID, stats.TwitterProfiles, 1) - logrus.Infof("[ScrapeTweetsProfile] TwitterProfiles incremented successfully") - return profile, nil } -func (ts *TwitterScraper) ScrapeTweetsByFullArchiveSearchQuery(j types.Job, baseDir string, query string, count int) ([]*teetypes.TweetResult, error) { - return ts.queryTweets(j, twitterx.TweetsAll, baseDir, query, count) -} - -func (ts *TwitterScraper) ScrapeTweetsByRecentSearchQuery(j types.Job, baseDir string, query string, count int) ([]*teetypes.TweetResult, error) { - return ts.queryTweets(j, twitterx.TweetsSearchRecent, baseDir, query, count) -} - -func (ts *TwitterScraper) queryTweets(j types.Job, baseQueryEndpoint string, baseDir string, query string, count int) ([]*teetypes.TweetResult, error) { - // Try credentials first, fallback to API for CapSearchByQuery - scraper, account, err := ts.getCredentialScraper(j, baseDir) - if err == nil { - return ts.scrapeTweetsWithCredentials(j, query, count, scraper, account) - } - - // Fallback to API - twitterXScraper, apiKey, apiErr := ts.getApiScraper(j) - if apiErr != nil { - ts.statsCollector.Add(j.WorkerID, stats.TwitterAuthErrors, 1) - return nil, fmt.Errorf("no Twitter accounts or API keys available") - } - return ts.scrapeTweets(j, baseQueryEndpoint, query, 
count, twitterXScraper, apiKey) -} - -func (ts *TwitterScraper) queryTweetsWithCredentials(j types.Job, baseDir string, query string, count int) ([]*teetypes.TweetResult, error) { +func (ts *TwitterScraper) SearchByQuery(j types.Job, baseDir string, query string, count int) ([]*types.TweetResult, error) { scraper, account, err := ts.getCredentialScraper(j, baseDir) if err != nil { return nil, err @@ -265,17 +208,17 @@ func (ts *TwitterScraper) queryTweetsWithCredentials(j types.Job, baseDir string return ts.scrapeTweetsWithCredentials(j, query, count, scraper, account) } -func (ts *TwitterScraper) queryTweetsWithApiKey(j types.Job, baseQueryEndpoint string, query string, count int) ([]*teetypes.TweetResult, error) { +func (ts *TwitterScraper) SearchByFullArchive(j types.Job, baseQueryEndpoint string, query string, count int) ([]*types.TweetResult, error) { twitterXScraper, apiKey, err := ts.getApiScraper(j) if err != nil { return nil, err } - return ts.scrapeTweets(j, baseQueryEndpoint, query, count, twitterXScraper, apiKey) + return ts.scrapeTweetsWithAPI(j, baseQueryEndpoint, query, count, twitterXScraper, apiKey) } -func (ts *TwitterScraper) scrapeTweetsWithCredentials(j types.Job, query string, count int, scraper *twitter.Scraper, account *twitter.TwitterAccount) ([]*teetypes.TweetResult, error) { +func (ts *TwitterScraper) scrapeTweetsWithCredentials(j types.Job, query string, count int, scraper *twitter.Scraper, account *twitter.TwitterAccount) ([]*types.TweetResult, error) { ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - tweets := make([]*teetypes.TweetResult, 0, count) + tweets := make([]*types.TweetResult, 0, count) ctx, cancel := context.WithTimeout(context.Background(), j.Timeout) defer cancel() @@ -295,15 +238,14 @@ func (ts *TwitterScraper) scrapeTweetsWithCredentials(j types.Job, query string, return tweets, nil } -// scrapeTweets uses an existing scraper instance -func (ts *TwitterScraper) scrapeTweets(j types.Job, baseQueryEndpoint string, query string, count int, twitterXScraper *twitterx.TwitterXScraper, apiKey *twitter.TwitterApiKey) ([]*teetypes.TweetResult, error) { +func (ts *TwitterScraper) scrapeTweetsWithAPI(j types.Job, baseQueryEndpoint string, query string, count int, twitterXScraper *twitterx.TwitterXScraper, apiKey *twitter.TwitterApiKey) ([]*types.TweetResult, error) { ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) if baseQueryEndpoint == twitterx.TweetsAll && apiKey.Type == twitter.TwitterApiKeyTypeBase { return nil, fmt.Errorf("this API key is a base/Basic key and does not have access to full archive search. 
Please use an elevated/Pro API key") } - tweets := make([]*teetypes.TweetResult, 0, count) + tweets := make([]*types.TweetResult, 0, count) cursor := "" deadline := time.Now().Add(j.Timeout) @@ -339,7 +281,7 @@ func (ts *TwitterScraper) scrapeTweets(j types.Job, baseQueryEndpoint string, qu return nil, fmt.Errorf("failed to parse tweet ID '%s' from twitterx: %w", tX.ID, convErr) } - newTweet := &teetypes.TweetResult{ + newTweet := &types.TweetResult{ ID: tweetIDInt, TweetID: tX.ID, AuthorID: tX.AuthorID, @@ -357,7 +299,7 @@ func (ts *TwitterScraper) scrapeTweets(j types.Job, baseQueryEndpoint string, qu //} //if tX.PublicMetrics != nil { - newTweet.PublicMetrics = teetypes.PublicMetrics{ + newTweet.PublicMetrics = types.PublicMetrics{ RetweetCount: tX.PublicMetrics.RetweetCount, ReplyCount: tX.PublicMetrics.ReplyCount, LikeCount: tX.PublicMetrics.LikeCount, @@ -365,7 +307,7 @@ func (ts *TwitterScraper) scrapeTweets(j types.Job, baseQueryEndpoint string, qu BookmarkCount: tX.PublicMetrics.BookmarkCount, } //} - // if tX.PossiblySensitive is available in twitterx.TweetData and teetypes.TweetResult has PossiblySensitive: + // if tX.PossiblySensitive is available in twitterx.TweetData and types.TweetResult has PossiblySensitive: // newTweet.PossiblySensitive = tX.PossiblySensitive // Also, fields like IsQuoted, Photos, Videos etc. would need to be populated if tX provides them. // Currently, this mapping is simpler than convertTwitterScraperTweetToTweetResult. @@ -393,29 +335,7 @@ EndLoop: return tweets, nil } -func (ts *TwitterScraper) ScrapeTweetByID(j types.Job, baseDir string, tweetID string) (*teetypes.TweetResult, error) { - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - - scraper, account, err := ts.getCredentialScraper(j, baseDir) - if err != nil { - return nil, err - } - - tweet, err := scraper.GetTweet(tweetID) - if err != nil { - _ = ts.handleError(j, err, account) - return nil, err - } - if tweet == nil { - return nil, fmt.Errorf("tweet not found or error occurred, but error was nil") - } - - tweetResult := ts.convertTwitterScraperTweetToTweetResult(*tweet) - ts.statsCollector.Add(j.WorkerID, stats.TwitterTweets, 1) - return tweetResult, nil -} - -func (ts *TwitterScraper) GetTweet(j types.Job, baseDir, tweetID string) (*teetypes.TweetResult, error) { +func (ts *TwitterScraper) GetTweet(j types.Job, baseDir, tweetID string) (*types.TweetResult, error) { scraper, account, err := ts.getCredentialScraper(j, baseDir) if err != nil { return nil, err @@ -435,14 +355,14 @@ func (ts *TwitterScraper) GetTweet(j types.Job, baseDir, tweetID string) (*teety return tweetResult, nil } -func (ts *TwitterScraper) GetTweetReplies(j types.Job, baseDir, tweetID string, cursor string) ([]*teetypes.TweetResult, error) { +func (ts *TwitterScraper) GetTweetReplies(j types.Job, baseDir, tweetID string, cursor string) ([]*types.TweetResult, error) { scraper, account, err := ts.getCredentialScraper(j, baseDir) if err != nil { return nil, err } ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - var replies []*teetypes.TweetResult + var replies []*types.TweetResult scrapedTweets, threadEntries, err := scraper.GetTweetReplies(tweetID, cursor) if err != nil { @@ -453,7 +373,7 @@ func (ts *TwitterScraper) GetTweetReplies(j types.Job, baseDir, tweetID string, for i, scrapedTweet := range scrapedTweets { newTweetResult := ts.convertTwitterScraperTweetToTweetResult(*scrapedTweet) if i < len(threadEntries) { - // Assuming teetypes.TweetResult has a ThreadCursor field (struct, not pointer) + 
// Assuming types.TweetResult has a ThreadCursor field (struct, not pointer) newTweetResult.ThreadCursor.Cursor = threadEntries[i].Cursor newTweetResult.ThreadCursor.CursorType = threadEntries[i].CursorType newTweetResult.ThreadCursor.FocalTweetID = threadEntries[i].FocalTweetID @@ -484,14 +404,14 @@ func (ts *TwitterScraper) GetTweetRetweeters(j types.Job, baseDir, tweetID strin return retweeters, nil } -func (ts *TwitterScraper) GetUserTweets(j types.Job, baseDir, username string, count int, cursor string) ([]*teetypes.TweetResult, string, error) { +func (ts *TwitterScraper) GetUserTweets(j types.Job, baseDir, username string, count int, cursor string) ([]*types.TweetResult, string, error) { scraper, account, err := ts.getCredentialScraper(j, baseDir) if err != nil { return nil, "", err } ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - var tweets []*teetypes.TweetResult + var tweets []*types.TweetResult var nextCursor string if cursor != "" { @@ -524,14 +444,14 @@ func (ts *TwitterScraper) GetUserTweets(j types.Job, baseDir, username string, c return tweets, nextCursor, nil } -func (ts *TwitterScraper) GetUserMedia(j types.Job, baseDir, username string, count int, cursor string) ([]*teetypes.TweetResult, string, error) { +func (ts *TwitterScraper) GetUserMedia(j types.Job, baseDir, username string, count int, cursor string) ([]*types.TweetResult, string, error) { scraper, account, err := ts.getCredentialScraper(j, baseDir) if err != nil { return nil, "", err } ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - var media []*teetypes.TweetResult + var media []*types.TweetResult var nextCursor string ctx, cancel := context.WithTimeout(context.Background(), j.Timeout) defer cancel() @@ -585,137 +505,6 @@ func (ts *TwitterScraper) GetUserMedia(j types.Job, baseDir, username string, co return media, nextCursor, nil } -func (ts *TwitterScraper) GetHomeTweets(j types.Job, baseDir string, count int, cursor string) ([]*teetypes.TweetResult, string, error) { - scraper, account, err := ts.getCredentialScraper(j, baseDir) - if err != nil { - return nil, "", err - } - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - - var tweets []*teetypes.TweetResult - var nextCursor string - - if cursor != "" { - fetchedTweets, fetchCursor, fetchErr := scraper.FetchHomeTweets(count, cursor) - if fetchErr != nil { - _ = ts.handleError(j, fetchErr, account) - return nil, "", fetchErr - } - for _, tweet := range fetchedTweets { - newTweetResult := ts.convertTwitterScraperTweetToTweetResult(*tweet) - tweets = append(tweets, newTweetResult) - } - nextCursor = fetchCursor - } else { - ctx, cancel := context.WithTimeout(context.Background(), j.Timeout) - defer cancel() - for tweetScraped := range scraper.GetHomeTweets(ctx, count) { - if tweetScraped.Error != nil { - _ = ts.handleError(j, tweetScraped.Error, account) - return nil, "", tweetScraped.Error - } - newTweetResult := ts.convertTwitterScraperTweetToTweetResult(tweetScraped.Tweet) - tweets = append(tweets, newTweetResult) - if len(tweets) >= count && count > 0 { - break - } - } - if len(tweets) > 0 { - nextCursor = strconv.FormatInt(tweets[len(tweets)-1].ID, 10) - } - } - ts.statsCollector.Add(j.WorkerID, stats.TwitterTweets, uint(len(tweets))) - return tweets, nextCursor, nil -} - -func (ts *TwitterScraper) GetForYouTweets(j types.Job, baseDir string, count int, cursor string) ([]*teetypes.TweetResult, string, error) { - scraper, account, err := ts.getCredentialScraper(j, baseDir) - if err != nil { - return nil, "", err - } - 
ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - - var tweets []*teetypes.TweetResult - var nextCursor string - - if cursor != "" { - fetchedTweets, fetchCursor, fetchErr := scraper.FetchForYouTweets(count, cursor) - if fetchErr != nil { - _ = ts.handleError(j, fetchErr, account) - return nil, "", fetchErr - } - for _, tweet := range fetchedTweets { - newTweetResult := ts.convertTwitterScraperTweetToTweetResult(*tweet) - tweets = append(tweets, newTweetResult) - } - nextCursor = fetchCursor - } else { - ctx, cancel := context.WithTimeout(context.Background(), j.Timeout) - defer cancel() - for tweetScraped := range scraper.GetForYouTweets(ctx, count) { - if tweetScraped.Error != nil { - _ = ts.handleError(j, tweetScraped.Error, account) - return nil, "", tweetScraped.Error - } - newTweetResult := ts.convertTwitterScraperTweetToTweetResult(tweetScraped.Tweet) - tweets = append(tweets, newTweetResult) - if len(tweets) >= count && count > 0 { - break - } - } - if len(tweets) > 0 { - nextCursor = strconv.FormatInt(tweets[len(tweets)-1].ID, 10) - } - } - ts.statsCollector.Add(j.WorkerID, stats.TwitterTweets, uint(len(tweets))) - return tweets, nextCursor, nil -} - -func (ts *TwitterScraper) GetBookmarks(j types.Job, baseDir string, count int, cursor string) ([]*teetypes.TweetResult, string, error) { - scraper, account, err := ts.getCredentialScraper(j, baseDir) - if err != nil { - return nil, "", err - } - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - var bookmarks []*teetypes.TweetResult - - ctx, cancel := context.WithTimeout(context.Background(), j.Timeout) - defer cancel() - cursorInt := 0 - if cursor != "" { - var parseErr error - cursorInt, parseErr = strconv.Atoi(cursor) - if parseErr != nil { - logrus.Warnf("Invalid cursor value for GetBookmarks '%s', using default 0: %v", cursor, parseErr) - cursorInt = 0 // Ensure it's reset if parse fails - } - } - for tweetScraped := range scraper.GetBookmarks(ctx, cursorInt) { - if tweetScraped.Error != nil { - _ = ts.handleError(j, tweetScraped.Error, account) - return nil, "", tweetScraped.Error - } - newTweetResult := ts.convertTwitterScraperTweetToTweetResult(tweetScraped.Tweet) - bookmarks = append(bookmarks, newTweetResult) - if len(bookmarks) >= count && count > 0 { - break - } - } - - var nextCursor string - if len(bookmarks) > 0 { - // The twitterscraper GetBookmarks cursor is an offset. - // The next cursor should be the current offset + number of items fetched in this batch. 
- nextCursor = strconv.Itoa(cursorInt + len(bookmarks)) - } else if cursor != "" { - // If no bookmarks were fetched but a cursor was provided, retain it or signal no change - nextCursor = cursor - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterTweets, uint(len(bookmarks))) - return bookmarks, nextCursor, nil -} - func (ts *TwitterScraper) GetProfileByID(j types.Job, baseDir, userID string) (*twitterscraper.Profile, error) { scraper, account, err := ts.getCredentialScraper(j, baseDir) if err != nil { @@ -732,98 +521,6 @@ func (ts *TwitterScraper) GetProfileByID(j types.Job, baseDir, userID string) (* return &profile, nil } -// GetProfileByIDWithApiKey fetches user profile using Twitter API key -func (ts *TwitterScraper) GetProfileByIDWithApiKey(j types.Job, userID string, apiKey *twitter.TwitterApiKey) (*twitterx.TwitterXProfileResponse, error) { - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - - apiClient := client.NewTwitterXClient(apiKey.Key) - twitterXScraper := twitterx.NewTwitterXScraper(apiClient) - - profile, err := twitterXScraper.GetProfileByID(userID) - if err != nil { - if ts.handleError(j, err, nil) { - return nil, err - } - return nil, err - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterProfiles, 1) - return profile, nil -} - -// GetTweetByIDWithApiKey fetches a tweet using Twitter API key -func (ts *TwitterScraper) GetTweetByIDWithApiKey(j types.Job, tweetID string, apiKey *twitter.TwitterApiKey) (*teetypes.TweetResult, error) { - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - - apiClient := client.NewTwitterXClient(apiKey.Key) - twitterXScraper := twitterx.NewTwitterXScraper(apiClient) - - tweetData, err := twitterXScraper.GetTweetByID(tweetID) - if err != nil { - if ts.handleError(j, err, nil) { - return nil, err - } - return nil, err - } - - // Convert TwitterXTweetData to TweetResult - tweetIDInt, convErr := strconv.ParseInt(tweetData.ID, 10, 64) - if convErr != nil { - logrus.Errorf("Failed to convert tweet ID '%s' to int64: %v", tweetData.ID, convErr) - return nil, fmt.Errorf("failed to parse tweet ID '%s': %w", tweetData.ID, convErr) - } - - // Parse the created_at time string - createdAt, timeErr := time.Parse(time.RFC3339, tweetData.CreatedAt) - if timeErr != nil { - logrus.Warnf("Failed to parse created_at time '%s': %v", tweetData.CreatedAt, timeErr) - createdAt = time.Now() // fallback to current time - } - - tweetResult := &teetypes.TweetResult{ - ID: tweetIDInt, - TweetID: tweetData.ID, - AuthorID: tweetData.AuthorID, - Text: tweetData.Text, - ConversationID: tweetData.ConversationID, - UserID: tweetData.AuthorID, - CreatedAt: createdAt, - Username: tweetData.Username, - Lang: tweetData.Lang, - PublicMetrics: teetypes.PublicMetrics{ - RetweetCount: tweetData.PublicMetrics.RetweetCount, - ReplyCount: tweetData.PublicMetrics.ReplyCount, - LikeCount: tweetData.PublicMetrics.LikeCount, - QuoteCount: tweetData.PublicMetrics.QuoteCount, - BookmarkCount: tweetData.PublicMetrics.BookmarkCount, - }, - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterTweets, 1) - return tweetResult, nil -} - -func (ts *TwitterScraper) SearchProfile(j types.Job, query string, count int) ([]*twitterscraper.ProfileResult, error) { - scraper, _, err := ts.getCredentialScraper(j, ts.configuration.DataDir) - if err != nil { - return nil, err - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - var profiles []*twitterscraper.ProfileResult - ctx, cancel := context.WithTimeout(context.Background(), j.Timeout) - defer cancel() - - for 
profile := range scraper.SearchProfiles(ctx, query, count) { - profiles = append(profiles, profile) - if len(profiles) >= count && count > 0 { - break - } - } - ts.statsCollector.Add(j.WorkerID, stats.TwitterProfiles, uint(len(profiles))) - return profiles, nil -} - func (ts *TwitterScraper) GetTrends(j types.Job, baseDir string) ([]string, error) { scraper, account, err := ts.getCredentialScraper(j, baseDir) if err != nil { @@ -840,40 +537,7 @@ func (ts *TwitterScraper) GetTrends(j types.Job, baseDir string) ([]string, erro return trends, nil } -func (ts *TwitterScraper) GetFollowers(j types.Job, baseDir, user string, count int) ([]*twitterscraper.Profile, error) { - scraper, account, err := ts.getCredentialScraper(j, baseDir) - if err != nil { - return nil, err - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - followers, _, fetchErr := scraper.FetchFollowers(user, count, "") - if fetchErr != nil { - _ = ts.handleError(j, fetchErr, account) - return nil, fetchErr - } - ts.statsCollector.Add(j.WorkerID, stats.TwitterProfiles, uint(len(followers))) - return followers, nil -} - -func (ts *TwitterScraper) GetFollowing(j types.Job, baseDir, username string, count int) ([]*twitterscraper.Profile, error) { - scraper, account, err := ts.getCredentialScraper(j, baseDir) - if err != nil { - return nil, err - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - following, _, fetchErr := scraper.FetchFollowing(username, count, "") - if fetchErr != nil { - _ = ts.handleError(j, fetchErr, account) // Assuming FetchFollowing returns error, not errString - return nil, fetchErr - } - ts.statsCollector.Add(j.WorkerID, stats.TwitterProfiles, uint(len(following))) - return following, nil -} - -// getFollowersApify retrieves followers using Apify -func (ts *TwitterScraper) getFollowersApify(j types.Job, username string, maxResults uint, cursor client.Cursor) ([]*teetypes.ProfileResultApify, client.Cursor, error) { +func (ts *TwitterScraper) getFollowersApify(j types.Job, username string, maxResults uint, cursor client.Cursor) ([]*types.ProfileResultApify, client.Cursor, error) { apifyScraper, err := ts.getApifyScraper(j) if err != nil { return nil, "", err @@ -890,8 +554,7 @@ func (ts *TwitterScraper) getFollowersApify(j types.Job, username string, maxRes return followers, nextCursor, nil } -// getFollowingApify retrieves following using Apify -func (ts *TwitterScraper) getFollowingApify(j types.Job, username string, maxResults uint, cursor client.Cursor) ([]*teetypes.ProfileResultApify, client.Cursor, error) { +func (ts *TwitterScraper) getFollowingApify(j types.Job, username string, maxResults uint, cursor client.Cursor) ([]*types.ProfileResultApify, client.Cursor, error) { apifyScraper, err := ts.getApifyScraper(j) if err != nil { return nil, "", err @@ -924,50 +587,11 @@ func (ts *TwitterScraper) GetSpace(j types.Job, baseDir, spaceID string) (*twitt return space, nil } -func (ts *TwitterScraper) FetchHomeTweets(j types.Job, baseDir string, count int, cursor string) ([]*twitterscraper.Tweet, string, error) { - scraper, account, err := ts.getCredentialScraper(j, baseDir) - if err != nil { - return nil, "", err - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - tweets, nextCursor, fetchErr := scraper.FetchHomeTweets(count, cursor) - if fetchErr != nil { - _ = ts.handleError(j, fetchErr, account) - return nil, "", fetchErr - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterTweets, uint(len(tweets))) - return tweets, nextCursor, nil -} - -func (ts 
*TwitterScraper) FetchForYouTweets(j types.Job, baseDir string, count int, cursor string) ([]*twitterscraper.Tweet, string, error) { - scraper, account, err := ts.getCredentialScraper(j, baseDir) - if err != nil { - return nil, "", err - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterScrapes, 1) - tweets, nextCursor, fetchErr := scraper.FetchForYouTweets(count, cursor) - if fetchErr != nil { - _ = ts.handleError(j, fetchErr, account) - return nil, "", fetchErr - } - - ts.statsCollector.Add(j.WorkerID, stats.TwitterTweets, uint(len(tweets))) - return tweets, nextCursor, nil -} - -// TwitterScraperConfig is now defined in api/types to avoid duplication and circular imports - -// twitterScraperRuntimeConfig holds the runtime configuration without JSON tags to prevent credential serialization -// Unified config: use types.TwitterScraperConfig directly - type TwitterScraper struct { configuration config.TwitterScraperConfig accountManager *twitter.TwitterAccountManager statsCollector *stats.StatsCollector - capabilities map[teetypes.Capability]bool + capabilities map[types.Capability]bool } func NewTwitterScraper(jc config.JobConfiguration, c *stats.StatsCollector) *TwitterScraper { @@ -985,275 +609,85 @@ func NewTwitterScraper(jc config.JobConfiguration, c *stats.StatsCollector) *Twi configuration: config, accountManager: accountManager, statsCollector: c, - capabilities: map[teetypes.Capability]bool{ - teetypes.CapSearchByQuery: true, - teetypes.CapSearchByFullArchive: true, - teetypes.CapSearchByProfile: true, - teetypes.CapGetById: true, - teetypes.CapGetReplies: true, - teetypes.CapGetRetweeters: true, - teetypes.CapGetTweets: true, - teetypes.CapGetMedia: true, - teetypes.CapGetHomeTweets: true, - teetypes.CapGetForYouTweets: true, - teetypes.CapGetProfileById: true, - teetypes.CapGetTrends: true, - teetypes.CapGetFollowing: true, - teetypes.CapGetFollowers: true, - teetypes.CapGetSpace: true, + capabilities: map[types.Capability]bool{ + // Credential-based capabilities + types.CapSearchByQuery: true, + types.CapSearchByProfile: true, + types.CapGetById: true, + types.CapGetReplies: true, + types.CapGetTweets: true, + types.CapGetMedia: true, + types.CapGetProfileById: true, + types.CapGetTrends: true, + types.CapGetSpace: true, + types.CapGetProfile: true, + + // API-based capabilities + types.CapSearchByFullArchive: true, + + // Apify-based capabilities + types.CapGetFollowing: true, + types.CapGetFollowers: true, }, } } -// GetStructuredCapabilities returns the structured capabilities supported by this Twitter scraper -// based on the available credentials and API keys -func (ts *TwitterScraper) GetStructuredCapabilities() teetypes.WorkerCapabilities { - capabilities := make(teetypes.WorkerCapabilities) - - // Check if we have Twitter accounts for credential-based scraping - if len(ts.configuration.Accounts) > 0 { - var credCaps []teetypes.Capability - for capability, enabled := range ts.capabilities { - if enabled { - credCaps = append(credCaps, capability) - } - } - if len(credCaps) > 0 { - capabilities[teetypes.TwitterCredentialJob] = credCaps - } - } - - // Check if we have API keys for API-based scraping - if len(ts.configuration.ApiKeys) > 0 { - apiCaps := make([]teetypes.Capability, len(teetypes.TwitterAPICaps)) - copy(apiCaps, teetypes.TwitterAPICaps) - - // Check for elevated API capabilities - if ts.accountManager != nil { - for _, apiKey := range ts.accountManager.GetApiKeys() { - if apiKey.Type == twitter.TwitterApiKeyTypeElevated { - apiCaps = append(apiCaps, 
teetypes.CapSearchByFullArchive) - break - } - } - } - - capabilities[teetypes.TwitterApiJob] = apiCaps - } - - // Add Apify-specific capabilities based on available API key - // TODO: We should verify whether each of the actors is actually available through this API key - if ts.configuration.ApifyApiKey != "" { - capabilities[teetypes.TwitterApifyJob] = teetypes.TwitterApifyCaps - } - - // Add general twitter scraper capability (uses best available method) - if len(ts.configuration.Accounts) > 0 || len(ts.configuration.ApiKeys) > 0 { - var generalCaps []teetypes.Capability - if len(ts.configuration.Accounts) > 0 { - // Use all capabilities if we have accounts - for capability, enabled := range ts.capabilities { - if enabled { - generalCaps = append(generalCaps, capability) - } - } - } else { - // Use API capabilities if we only have keys - generalCaps = make([]teetypes.Capability, len(teetypes.TwitterAPICaps)) - copy(generalCaps, teetypes.TwitterAPICaps) - // Check for elevated capabilities - if ts.accountManager != nil { - for _, apiKey := range ts.accountManager.GetApiKeys() { - if apiKey.Type == twitter.TwitterApiKeyTypeElevated { - generalCaps = append(generalCaps, teetypes.CapSearchByFullArchive) - break - } - } - } - } - - capabilities[teetypes.TwitterJob] = generalCaps - } - - return capabilities -} - -type TwitterScrapeStrategy interface { - Execute(j types.Job, ts *TwitterScraper, jobArgs *teeargs.TwitterSearchArguments) (types.JobResult, error) -} - -func getScrapeStrategy(jobType teetypes.JobType) TwitterScrapeStrategy { - switch jobType { - case teetypes.TwitterCredentialJob: - return &CredentialScrapeStrategy{} - case teetypes.TwitterApiJob: - return &ApiKeyScrapeStrategy{} - case teetypes.TwitterApifyJob: - return &ApifyScrapeStrategy{} - default: - return &DefaultScrapeStrategy{} - } -} - -type CredentialScrapeStrategy struct{} - -func (s *CredentialScrapeStrategy) Execute(j types.Job, ts *TwitterScraper, jobArgs *teeargs.TwitterSearchArguments) (types.JobResult, error) { - capability := jobArgs.GetCapability() - switch capability { - case teetypes.CapSearchByQuery: - tweets, err := ts.queryTweetsWithCredentials(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults) - return processResponse(tweets, "", err) - case teetypes.CapSearchByFullArchive: - logrus.Warn("Full archive search with credential-only implementation may have limited results") - tweets, err := ts.queryTweetsWithCredentials(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults) - return processResponse(tweets, "", err) - default: - return defaultStrategyFallback(j, ts, jobArgs) - } -} - -type ApiKeyScrapeStrategy struct{} - -func (s *ApiKeyScrapeStrategy) Execute(j types.Job, ts *TwitterScraper, jobArgs *teeargs.TwitterSearchArguments) (types.JobResult, error) { +// executeCapability routes the job to the appropriate method based on capability +func (ts *TwitterScraper) executeCapability(j types.Job, jobArgs *twitterargs.Search) (types.JobResult, error) { capability := jobArgs.GetCapability() - switch capability { - case teetypes.CapSearchByQuery: - tweets, err := ts.queryTweetsWithApiKey(j, twitterx.TweetsSearchRecent, jobArgs.Query, jobArgs.MaxResults) - return processResponse(tweets, "", err) - case teetypes.CapSearchByFullArchive: - tweets, err := ts.queryTweetsWithApiKey(j, twitterx.TweetsAll, jobArgs.Query, jobArgs.MaxResults) - return processResponse(tweets, "", err) - case teetypes.CapGetProfileById: - _, apiKey, err := ts.getApiScraper(j) - if err != nil { - return 
types.JobResult{Error: err.Error()}, err - } - profile, err := ts.GetProfileByIDWithApiKey(j, jobArgs.Query, apiKey) - return processResponse(profile, "", err) - case teetypes.CapGetById: - _, apiKey, err := ts.getApiScraper(j) - if err != nil { - return types.JobResult{Error: err.Error()}, err - } - tweet, err := ts.GetTweetByIDWithApiKey(j, jobArgs.Query, apiKey) - return processResponse(tweet, "", err) - default: - return defaultStrategyFallback(j, ts, jobArgs) - } -} - -type ApifyScrapeStrategy struct{} -func (s *ApifyScrapeStrategy) Execute(j types.Job, ts *TwitterScraper, jobArgs *teeargs.TwitterSearchArguments) (types.JobResult, error) { - capability := teetypes.Capability(jobArgs.QueryType) switch capability { - case teetypes.CapGetFollowers: + // Apify-based capabilities + case types.CapGetFollowers: followers, nextCursor, err := ts.getFollowersApify(j, jobArgs.Query, uint(jobArgs.MaxResults), client.Cursor(jobArgs.NextCursor)) return processResponse(followers, nextCursor.String(), err) - case teetypes.CapGetFollowing: + case types.CapGetFollowing: following, nextCursor, err := ts.getFollowingApify(j, jobArgs.Query, uint(jobArgs.MaxResults), client.Cursor(jobArgs.NextCursor)) return processResponse(following, nextCursor.String(), err) - default: - return types.JobResult{Error: fmt.Sprintf("unsupported capability %s for Apify job", capability)}, fmt.Errorf("unsupported capability %s for Apify job", capability) - } -} -type DefaultScrapeStrategy struct{} - -// FIXED: Now using validated QueryType from centralized unmarshaller (addresses the TODO comment) -func (s *DefaultScrapeStrategy) Execute(j types.Job, ts *TwitterScraper, jobArgs *teeargs.TwitterSearchArguments) (types.JobResult, error) { - capability := teetypes.Capability(jobArgs.QueryType) - switch capability { - case teetypes.CapGetFollowers, teetypes.CapGetFollowing: - // Priority: Apify > Credentials for general TwitterJob - // TODO: We should verify whether each of the actors is actually available through this API key - if ts.configuration.ApifyApiKey != "" { - // Use Apify strategy - apifyStrategy := &ApifyScrapeStrategy{} - return apifyStrategy.Execute(j, ts, jobArgs) - } - // Fall back to credential-based strategy - credentialStrategy := &CredentialScrapeStrategy{} - return credentialStrategy.Execute(j, ts, jobArgs) - case teetypes.CapSearchByQuery: - // Priority: Credentials > API for searchbyquery - if len(ts.configuration.Accounts) > 0 { - credentialStrategy := &CredentialScrapeStrategy{} - return credentialStrategy.Execute(j, ts, jobArgs) - } - // Fall back to API strategy - tweets, err := ts.queryTweets(j, twitterx.TweetsSearchRecent, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults) - return processResponse(tweets, "", err) - case teetypes.CapSearchByFullArchive: - tweets, err := ts.queryTweets(j, twitterx.TweetsAll, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults) + // API-based capabilities + case types.CapSearchByFullArchive: + tweets, err := ts.SearchByFullArchive(j, twitterx.TweetsAll, jobArgs.Query, jobArgs.MaxResults) return processResponse(tweets, "", err) - default: - return defaultStrategyFallback(j, ts, jobArgs) - } -} - -func retryWithCursor[T any]( - j types.Job, - baseDir string, - count int, - cursor string, - fn func(j types.Job, baseDir string, currentCount int, currentCursor string) ([]*T, string, error), -) (types.JobResult, error) { - records := make([]*T, 0, count) - deadline := time.Now().Add(j.Timeout) - currentCursor := cursor // Use 'currentCursor' to manage 
pagination state within the loop - - for (len(records) < count || count == 0) && time.Now().Before(deadline) { // Allow count == 0 to fetch all available up to timeout - numToFetch := count - len(records) - if count == 0 { // If count is 0, fetch a reasonable batch size, e.g. 100, or let fn decide - numToFetch = 100 // Or another default batch size if fn doesn't handle count=0 well for batching - } - if numToFetch <= 0 && count > 0 { - break - } - - results, nextInternalCursor, err := fn(j, baseDir, numToFetch, currentCursor) - if err != nil { - if len(records) > 0 { - logrus.Warnf("Error during paginated fetch, returning partial results. Error: %v", err) - return processResponse(records, currentCursor, nil) - } - return processResponse(nil, "", err) - } - if len(results) > 0 { - records = append(records, results...) - } + // Credential-based capabilities + case types.CapSearchByQuery: + tweets, err := ts.SearchByQuery(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults) + return processResponse(tweets, "", err) + case types.CapSearchByProfile: + profile, err := ts.SearchByProfile(j, ts.configuration.DataDir, jobArgs.Query) + return processResponse(profile, "", err) + case types.CapGetById: + tweet, err := ts.GetTweet(j, ts.configuration.DataDir, jobArgs.Query) + return processResponse(tweet, "", err) + case types.CapGetReplies: + replies, err := ts.GetTweetReplies(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.NextCursor) + return processResponse(replies, jobArgs.NextCursor, err) + case types.CapGetRetweeters: + retweeters, err := ts.GetTweetRetweeters(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults, jobArgs.NextCursor) + return processResponse(retweeters, jobArgs.NextCursor, err) + case types.CapGetMedia: + media, nextCursor, err := ts.GetUserMedia(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults, jobArgs.NextCursor) + return processResponse(media, nextCursor, err) + case types.CapGetProfileById: + profile, err := ts.GetProfileByID(j, ts.configuration.DataDir, jobArgs.Query) + return processResponse(profile, "", err) + case types.CapGetTrends: + trends, err := ts.GetTrends(j, ts.configuration.DataDir) + return processResponse(trends, "", err) + case types.CapGetSpace: + space, err := ts.GetSpace(j, ts.configuration.DataDir, jobArgs.Query) + return processResponse(space, "", err) + case types.CapGetProfile: + profile, err := ts.SearchByProfile(j, ts.configuration.DataDir, jobArgs.Query) + return processResponse(profile, "", err) + case types.CapGetTweets: + tweets, nextCursor, err := ts.GetUserTweets(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults, jobArgs.NextCursor) + return processResponse(tweets, nextCursor, err) - if nextInternalCursor == "" || nextInternalCursor == currentCursor { // No more pages or cursor stuck - currentCursor = nextInternalCursor // Update to the last known cursor - break - } - currentCursor = nextInternalCursor - if count > 0 && len(records) >= count { // Check if desired count is reached - break - } + default: + return types.JobResult{Error: fmt.Sprintf("unsupported capability: %s", capability)}, fmt.Errorf("unsupported capability: %s", capability) } - return processResponse(records, currentCursor, nil) -} - -func retryWithCursorAndQuery[T any]( - j types.Job, - baseDir string, - query string, - count int, - cursor string, - fn func(j types.Job, baseDir string, currentQuery string, currentCount int, currentCursor string) ([]*T, string, error), -) (types.JobResult, error) { - return retryWithCursor( - 
j, - baseDir, - count, - cursor, - func(jInner types.Job, baseDirInner string, currentCountInner int, currentCursorInner string) ([]*T, string, error) { - return fn(jInner, baseDirInner, query, currentCountInner, currentCursorInner) - }, - ) } func processResponse(response any, nextCursor string, err error) (types.JobResult, error) { @@ -1269,71 +703,19 @@ func processResponse(response any, nextCursor string, err error) (types.JobResul return types.JobResult{Data: dat, NextCursor: nextCursor}, nil } -func defaultStrategyFallback(j types.Job, ts *TwitterScraper, jobArgs *teeargs.TwitterSearchArguments) (types.JobResult, error) { - capability := jobArgs.GetCapability() - switch capability { - case teetypes.CapSearchByProfile: - profile, err := ts.ScrapeTweetsProfile(j, ts.configuration.DataDir, jobArgs.Query) - return processResponse(profile, "", err) - case teetypes.CapGetById: - tweet, err := ts.GetTweet(j, ts.configuration.DataDir, jobArgs.Query) - return processResponse(tweet, "", err) - case teetypes.CapGetReplies: - // GetTweetReplies takes a cursor for a specific part of a thread, not general pagination of all replies. - // The retryWithCursor logic might not directly apply unless GetTweetReplies is adapted for broader pagination. - replies, err := ts.GetTweetReplies(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.NextCursor) - return processResponse(replies, jobArgs.NextCursor, err) // Pass original NextCursor as it's specific - case teetypes.CapGetRetweeters: - // Similar to GetTweetReplies, cursor is for a specific page. - retweeters, err := ts.GetTweetRetweeters(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults, jobArgs.NextCursor) - // GetTweetRetweeters in twitterscraper returns (profiles, nextCursorStr, error) - // The current ts.GetTweetRetweeters doesn't return the next cursor. This should be updated if pagination is needed here. - // For now, assuming it fetches one batch or handles its own pagination internally up to MaxResults. 
- return processResponse(retweeters, "", err) // Assuming no next cursor from this specific call structure - case teetypes.CapGetTweets: - return retryWithCursorAndQuery(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults, jobArgs.NextCursor, ts.GetUserTweets) - case teetypes.CapGetMedia: - return retryWithCursorAndQuery(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults, jobArgs.NextCursor, ts.GetUserMedia) - case teetypes.CapGetHomeTweets: - return retryWithCursor(j, ts.configuration.DataDir, jobArgs.MaxResults, jobArgs.NextCursor, ts.GetHomeTweets) - case teetypes.CapGetForYouTweets: - return retryWithCursor(j, ts.configuration.DataDir, jobArgs.MaxResults, jobArgs.NextCursor, ts.GetForYouTweets) - case teetypes.CapGetProfileById: - profile, err := ts.GetProfileByID(j, ts.configuration.DataDir, jobArgs.Query) - return processResponse(profile, "", err) - case teetypes.CapGetTrends: - trends, err := ts.GetTrends(j, ts.configuration.DataDir) - return processResponse(trends, "", err) - case teetypes.CapGetFollowing: - following, err := ts.GetFollowing(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults) - return processResponse(following, "", err) - case teetypes.CapGetFollowers: - followers, err := ts.GetFollowers(j, ts.configuration.DataDir, jobArgs.Query, jobArgs.MaxResults) - return processResponse(followers, "", err) - case teetypes.CapGetSpace: - space, err := ts.GetSpace(j, ts.configuration.DataDir, jobArgs.Query) - return processResponse(space, "", err) - } - return types.JobResult{Error: "invalid search type in defaultStrategyFallback: " + string(jobArgs.QueryType)}, fmt.Errorf("invalid search type: %s", jobArgs.QueryType) -} - -// ExecuteJob runs a job using the appropriate scrape strategy based on the job type. +// ExecuteJob runs a Twitter job using capability-based routing. // It first unmarshals the job arguments using the centralized type-safe unmarshaller. -// Then it runs the appropriate scrape strategy's Execute method, passing in the job, TwitterScraper, and job arguments. -// If the result is empty, it returns an error. -// If the result is not empty, it unmarshals the result into a slice of TweetResult and returns the result. -// If the unmarshaling fails, it returns an error. -// If the unmarshaled result is empty, it returns an error. +// Then it routes to the appropriate method based on the capability. func (ts *TwitterScraper) ExecuteJob(j types.Job) (types.JobResult, error) { - // Use the centralized unmarshaller from tee-types - this addresses the TODO comment! 
- jobArgs, err := teeargs.UnmarshalJobArguments(teetypes.JobType(j.Type), map[string]any(j.Arguments)) + // Use the centralized type-safe unmarshaller from api/args + jobArgs, err := args.UnmarshalJobArguments(types.JobType(j.Type), map[string]any(j.Arguments)) if err != nil { logrus.Errorf("Error while unmarshalling job arguments for job ID %s, type %s: %v", j.UUID, j.Type, err) return types.JobResult{Error: "error unmarshalling job arguments"}, err } // Type assert to Twitter arguments - args, ok := jobArgs.(*teeargs.TwitterSearchArguments) + args, ok := jobArgs.(*twitterargs.Search) if !ok { logrus.Errorf("Expected Twitter arguments for job ID %s, type %s", j.UUID, j.Type) return types.JobResult{Error: "invalid argument type for Twitter job"}, fmt.Errorf("invalid argument type") } @@ -1342,29 +724,29 @@ func (ts *TwitterScraper) ExecuteJob(j types.Job) (types.JobResult, error) { // Log the capability for debugging logrus.Debugf("Executing Twitter job ID %s with capability: %s", j.UUID, args.GetCapability()) - strategy := getScrapeStrategy(j.Type) - - jobResult, err := strategy.Execute(j, ts, args) + // Route based on capability + jobResult, err := ts.executeCapability(j, args) if err != nil { logrus.Errorf("Error executing job ID %s, type %s: %v", j.UUID, j.Type, err) return types.JobResult{Error: "error executing job"}, err } // Check if raw data is empty - if jobResult.Data == nil || len(jobResult.Data) == 0 { + if len(jobResult.Data) == 0 { logrus.Errorf("Job result data is empty for job ID %s, type %s", j.UUID, j.Type) return types.JobResult{Error: "job result data is empty"}, fmt.Errorf("job result data is empty") } + // Validate the result based on operation type switch { case args.IsSingleTweetOperation(): - var result *teetypes.TweetResult + var result *types.TweetResult if err := jobResult.Unmarshal(&result); err != nil { logrus.Errorf("Error while unmarshalling single tweet result for job ID %s, type %s: %v", j.UUID, j.Type, err) return types.JobResult{Error: "error unmarshalling single tweet result for final validation"}, err } case args.IsMultipleTweetOperation(): - var results []*teetypes.TweetResult + var results []*types.TweetResult if err := jobResult.Unmarshal(&results); err != nil { logrus.Errorf("Error while unmarshalling multiple tweet result for job ID %s, type %s: %v", j.UUID, j.Type, err) return types.JobResult{Error: "error unmarshalling multiple tweet result for final validation"}, err diff --git a/internal/jobs/twitter_test.go b/internal/jobs/twitter_test.go index 0858ff45..b5fb7307 100644 --- a/internal/jobs/twitter_test.go +++ b/internal/jobs/twitter_test.go @@ -7,8 +7,6 @@ import ( "strings" "time" - teetypes "github.com/masa-finance/tee-types/types" - . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "github.com/sirupsen/logrus" @@ -106,9 +104,9 @@ var _ = Describe("Twitter Scraper", func() { "data_dir": tempDir, }, statsCollector) res, err := scraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "query": "NASA", "max_results": 1, }, @@ -116,7 +114,7 @@ var _ = Describe("Twitter Scraper", func() { }) Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var results []*teetypes.TweetResult + var results []*types.TweetResult err = res.Unmarshal(&results) Expect(err).NotTo(HaveOccurred()) Expect(results).ToNot(BeEmpty()) @@ -131,9 +129,9 @@ var _ = Describe("Twitter Scraper", func() { "data_dir": tempDir, }, statsCollector) res, err := scraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterApiJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "query": "NASA", "max_results": 1, }, @@ -141,7 +139,7 @@ var _ = Describe("Twitter Scraper", func() { }) Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var results []*teetypes.TweetResult + var results []*types.TweetResult err = res.Unmarshal(&results) Expect(err).NotTo(HaveOccurred()) Expect(results).ToNot(BeEmpty()) @@ -157,9 +155,9 @@ var _ = Describe("Twitter Scraper", func() { }, statsCollector) // Try to run credential-only job with only API key res, err := scraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "query": "NASA", "max_results": 1, }, @@ -179,9 +177,9 @@ var _ = Describe("Twitter Scraper", func() { "data_dir": tempDir, }, statsCollector) res, err := scraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "query": "nasa", "max_results": 10, }, @@ -189,7 +187,7 @@ var _ = Describe("Twitter Scraper", func() { }) Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var results []*teetypes.TweetResult + var results []*types.TweetResult err = res.Unmarshal(&results) Expect(err).NotTo(HaveOccurred()) Expect(results).ToNot(BeEmpty()) @@ -200,9 +198,9 @@ var _ = Describe("Twitter Scraper", func() { "data_dir": tempDir, }, statsCollector) res, err := scraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterApiJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "query": "NASA", "max_results": 1, }, @@ -221,9 +219,9 @@ var _ = Describe("Twitter Scraper", func() { "data_dir": tempDir, }, statsCollector) res, err := scraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterApiJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByFullArchive, + "type": types.CapSearchByFullArchive, "query": "NASA", "max_results": 1, }, @@ -235,7 +233,7 @@ var _ = Describe("Twitter Scraper", func() { } Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var results []*teetypes.TweetResult + var results []*types.TweetResult err = res.Unmarshal(&results) Expect(err).NotTo(HaveOccurred()) Expect(results).ToNot(BeEmpty()) @@ -246,9 +244,9 @@ var _ = Describe("Twitter Scraper", func() { Context("General Twitter Scraper Tests", func() { It("should scrape 
tweets with a search query", func() { j := types.Job{ - Type: teetypes.TwitterJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "query": "nasa", "max_results": 10, }, @@ -258,7 +256,7 @@ var _ = Describe("Twitter Scraper", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var results []*teetypes.TweetResult + var results []*types.TweetResult err = res.Unmarshal(&results) Expect(err).NotTo(HaveOccurred()) Expect(results).ToNot(BeEmpty()) @@ -276,9 +274,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("TWITTER_ACCOUNTS is not set") } j := types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByProfile, + "type": types.CapSearchByProfile, "query": "NASA_Marshall", }, Timeout: 10 * time.Second, @@ -303,9 +301,9 @@ var _ = Describe("Twitter Scraper", func() { It("should get tweet by ID", func() { res, err := twitterScraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetById, + "type": types.CapGetById, "query": "1881258110712492142", }, Timeout: 10 * time.Second, @@ -313,7 +311,7 @@ var _ = Describe("Twitter Scraper", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var tweet *teetypes.TweetResult + var tweet *types.TweetResult err = res.Unmarshal(&tweet) Expect(err).NotTo(HaveOccurred()) Expect(tweet).NotTo(BeNil()) @@ -326,9 +324,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("TWITTER_ACCOUNTS is not set") } j := types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetReplies, + "type": types.CapGetReplies, "query": "1234567890", }, Timeout: 10 * time.Second, @@ -337,7 +335,7 @@ var _ = Describe("Twitter Scraper", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var replies []*teetypes.TweetResult + var replies []*types.TweetResult err = res.Unmarshal(&replies) Expect(err).NotTo(HaveOccurred()) Expect(replies).ToNot(BeEmpty()) @@ -355,9 +353,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("TWITTER_ACCOUNTS is not set") } j := types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetRetweeters, + "type": types.CapGetRetweeters, "query": "1234567890", "max_results": 5, }, @@ -385,9 +383,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("TWITTER_ACCOUNTS is not set") } j := types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetTweets, + "type": types.CapGetTweets, "query": "NASA", "max_results": 5, }, @@ -397,7 +395,7 @@ var _ = Describe("Twitter Scraper", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var tweets []*teetypes.TweetResult + var tweets []*types.TweetResult err = res.Unmarshal(&tweets) Expect(err).NotTo(HaveOccurred()) Expect(len(tweets)).ToNot(BeZero()) @@ -415,9 +413,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("TWITTER_ACCOUNTS is not set") } res, err := twitterScraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetMedia, + "type": types.CapGetMedia, "query": "NASA", "max_results": 5, }, @@ -426,81 +424,21 @@ var _ = Describe("Twitter 
Scraper", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var media []*teetypes.TweetResult + var media []*types.TweetResult err = res.Unmarshal(&media) Expect(err).NotTo(HaveOccurred()) Expect(media).ToNot(BeEmpty()) Expect(len(media[0].Photos) + len(media[0].Videos)).ToNot(BeZero()) }) - It("should fetch home tweets", func() { - if len(twitterAccounts) == 0 { - Skip("TWITTER_ACCOUNTS is not set") - } - j := types.Job{ - Type: teetypes.TwitterCredentialJob, - Arguments: map[string]interface{}{ - "type": teetypes.CapGetHomeTweets, - "max_results": 5, - }, - Timeout: 10 * time.Second, - } - res, err := twitterScraper.ExecuteJob(j) - Expect(err).NotTo(HaveOccurred()) - Expect(res.Error).To(BeEmpty()) - - var tweets []*teetypes.TweetResult - err = res.Unmarshal(&tweets) - Expect(err).NotTo(HaveOccurred()) - Expect(len(tweets)).ToNot(BeZero()) - Expect(tweets[0].Text).ToNot(BeEmpty()) - - // Wait briefly for asynchronous stats processing to complete - time.Sleep(100 * time.Millisecond) - - Expect(statsCollector.Stats.Stats[j.WorkerID][stats.TwitterScrapes]).To(BeNumerically("==", 1)) - Expect(statsCollector.Stats.Stats[j.WorkerID][stats.TwitterTweets]).To(BeNumerically("==", uint(len(tweets)))) - }) - - It("should fetch for you tweets", func() { - if len(twitterAccounts) == 0 { - Skip("TWITTER_ACCOUNTS is not set") - } - j := types.Job{ - Type: teetypes.TwitterCredentialJob, - Arguments: map[string]interface{}{ - "type": teetypes.CapGetForYouTweets, - "max_results": 5, - }, - Timeout: 10 * time.Second, - } - res, err := twitterScraper.ExecuteJob(j) - - Expect(err).NotTo(HaveOccurred()) - Expect(res.Error).To(BeEmpty()) - - var tweets []*teetypes.TweetResult - err = res.Unmarshal(&tweets) - Expect(err).NotTo(HaveOccurred()) - Expect(len(tweets)).ToNot(BeZero()) - Expect(tweets).ToNot(BeEmpty()) - Expect(tweets[0].Text).ToNot(BeEmpty()) - - // Wait briefly for asynchronous stats processing to complete - time.Sleep(100 * time.Millisecond) - - Expect(statsCollector.Stats.Stats[j.WorkerID][stats.TwitterScrapes]).To(BeNumerically("==", 1)) - Expect(statsCollector.Stats.Stats[j.WorkerID][stats.TwitterTweets]).To(BeNumerically("==", uint(len(tweets)))) - }) - It("should fetch profile by ID", func() { if len(twitterAccounts) == 0 { Skip("TWITTER_ACCOUNTS is not set") } j := types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetProfileById, + "type": types.CapGetProfileById, "query": "44196397", // Elon Musk's Twitter ID }, Timeout: 10 * time.Second, @@ -526,9 +464,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("TWITTER_ACCOUNTS is not set") } j := types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetFollowing, + "type": types.CapGetFollowing, "query": "NASA", "max_results": 5, }, @@ -556,9 +494,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("TWITTER_ACCOUNTS is not set") } j := types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetFollowers, + "type": types.CapGetFollowers, "query": "NASA", }, Timeout: 10 * time.Second, @@ -586,9 +524,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("TWITTER_ACCOUNTS is not set") } j := types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetTrends, + "type": types.CapGetTrends, }, Timeout: 10 * 
time.Second, } @@ -614,9 +552,9 @@ var _ = Describe("Twitter Scraper", func() { "data_dir": tempDir, }, statsCollector) res, err := scraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterApiJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetById, + "type": types.CapGetById, "query": "1881258110712492142", }, Timeout: 10 * time.Second, @@ -625,7 +563,7 @@ var _ = Describe("Twitter Scraper", func() { Expect(res.Error).To(BeEmpty()) // Use the proper TweetResult type (the API converts TwitterXTweetData to TweetResult) - var tweet *teetypes.TweetResult + var tweet *types.TweetResult err = res.Unmarshal(&tweet) Expect(err).NotTo(HaveOccurred()) Expect(tweet).NotTo(BeNil()) @@ -656,9 +594,9 @@ var _ = Describe("Twitter Scraper", func() { "data_dir": tempDir, }, statsCollector) res, err := scraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterApiJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetProfileById, + "type": types.CapGetProfileById, "query": "44196397", // Elon Musk's Twitter ID }, Timeout: 10 * time.Second, @@ -688,9 +626,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("Needs to be constructed to fetch live spaces first - hard to test with hardcoded IDs") res, err := twitterScraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetSpace, + "type": types.CapGetSpace, "query": "1YpKkZEWlBaxj", }, Timeout: 10 * time.Second, @@ -708,7 +646,7 @@ var _ = Describe("Twitter Scraper", func() { Skip("Returns 'job result is empty' even when account has bookmarks") j := types.Job{ - Type: teetypes.TwitterJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ "type": "getbookmarks", // not yet in teetypes until it's supported "max_results": 5, @@ -719,7 +657,7 @@ var _ = Describe("Twitter Scraper", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var bookmarks []*teetypes.TweetResult + var bookmarks []*types.TweetResult err = res.Unmarshal(&bookmarks) Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) @@ -735,9 +673,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("Needs full archive key in TWITTER_API_KEYS to run") j := types.Job{ - Type: teetypes.TwitterApiJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByFullArchive, + "type": types.CapSearchByFullArchive, "query": "AI", "max_results": 2, }, @@ -747,7 +685,7 @@ var _ = Describe("Twitter Scraper", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var results []*teetypes.TweetResult + var results []*types.TweetResult err = res.Unmarshal(&results) Expect(err).NotTo(HaveOccurred()) Expect(results).ToNot(BeEmpty()) @@ -764,9 +702,9 @@ var _ = Describe("Twitter Scraper", func() { Skip("Needs full archive key (elevated) in TWITTER_API_KEYS to run") j := types.Job{ - Type: teetypes.TwitterCredentialJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByFullArchive, + "type": types.CapSearchByFullArchive, "query": "#AI", "max_results": 2, }, @@ -776,7 +714,7 @@ var _ = Describe("Twitter Scraper", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var results []*teetypes.TweetResult + var results []*types.TweetResult err = res.Unmarshal(&results) Expect(err).NotTo(HaveOccurred()) Expect(results).ToNot(BeEmpty()) @@ -799,9 +737,9 @@ var _ = Describe("Twitter Scraper", func() { }, 
statsCollector) j := types.Job{ - Type: teetypes.TwitterApifyJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetFollowers, + "type": types.CapGetFollowers, "query": "elonmusk", "max_results": 200, }, @@ -812,7 +750,7 @@ var _ = Describe("Twitter Scraper", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var followers []*teetypes.ProfileResultApify + var followers []*types.ProfileResultApify err = res.Unmarshal(&followers) Expect(err).NotTo(HaveOccurred()) Expect(followers).ToNot(BeEmpty()) @@ -830,9 +768,9 @@ var _ = Describe("Twitter Scraper", func() { }, statsCollector) j := types.Job{ - Type: teetypes.TwitterApifyJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetFollowing, + "type": types.CapGetFollowing, "query": "elonmusk", "max_results": 200, }, @@ -843,7 +781,7 @@ var _ = Describe("Twitter Scraper", func() { Expect(err).NotTo(HaveOccurred()) Expect(res.Error).To(BeEmpty()) - var following []*teetypes.ProfileResultApify + var following []*types.ProfileResultApify err = res.Unmarshal(&following) Expect(err).NotTo(HaveOccurred()) Expect(following).ToNot(BeEmpty()) @@ -861,9 +799,9 @@ var _ = Describe("Twitter Scraper", func() { "data_dir": tempDir, }, statsCollector) res, err := scraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetFollowers, + "type": types.CapGetFollowers, "query": "elonmusk", "max_results": 200, }, @@ -873,7 +811,7 @@ Expect(res.Error).To(BeEmpty()) // Should return ProfileResultApify (from Apify) not twitterscraper.Profile - var followers []*teetypes.ProfileResultApify + var followers []*types.ProfileResultApify err = res.Unmarshal(&followers) Expect(err).NotTo(HaveOccurred()) Expect(followers).ToNot(BeEmpty()) @@ -884,9 +822,9 @@ var _ = Describe("Twitter Scraper", func() { Context("Error Handling", func() { It("should handle negative count values in job arguments", func() { res, err := twitterScraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "query": "test", "count": -5, // Invalid negative value }, @@ -899,9 +837,9 @@ var _ = Describe("Twitter Scraper", func() { It("should handle negative max_results values in job arguments", func() { res, err := twitterScraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "query": "test", "max_results": -10, // Invalid negative value }, @@ -914,7 +852,7 @@ var _ = Describe("Twitter Scraper", func() { It("should handle invalid capability for job type", func() { res, err := twitterScraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterApiJob, // API job type + Type: types.TwitterJob, Arguments: map[string]interface{}{ "type": "invalidcapability", // Invalid capability "query": "test", }, @@ -928,9 +866,9 @@ var _ = Describe("Twitter Scraper", func() { It("should handle capability not available for specific job type", func() { res, err := twitterScraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterApiJob, // API job type - doesn't support getfollowers + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapGetFollowers, // Valid 
capability but not for TwitterApiJob + "type": types.CapGetFollowers, // Valid capability, but not available for this TwitterJob configuration "query": "test", }, Timeout: 10 * time.Second, @@ -943,9 +881,9 @@ var _ = Describe("Twitter Scraper", func() { It("should handle invalid JSON data structure", func() { // Create a job with arguments that will cause JSON unmarshalling to fail res, err := twitterScraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "query": "test", "max_results": "not_a_number", // String instead of int }, @@ -961,7 +899,7 @@ var _ = Describe("Twitter Scraper", func() { res, err := twitterScraper.ExecuteJob(types.Job{ Type: "unknown-job-type", // Invalid job type Arguments: map[string]interface{}{ - "type": teetypes.CapSearchByQuery, + "type": types.CapSearchByQuery, "query": "test", }, Timeout: 10 * time.Second, @@ -973,7 +911,7 @@ var _ = Describe("Twitter Scraper", func() { It("should handle empty arguments map", func() { res, err := twitterScraper.ExecuteJob(types.Job{ - Type: teetypes.TwitterJob, + Type: types.TwitterJob, Arguments: map[string]interface{}{}, // Empty arguments Timeout: 10 * time.Second, }) diff --git a/internal/jobs/twitterapify/client.go b/internal/jobs/twitterapify/client.go index cbdddca5..27356dd9 100644 --- a/internal/jobs/twitterapify/client.go +++ b/internal/jobs/twitterapify/client.go @@ -4,8 +4,8 @@ import ( "encoding/json" "fmt" - util "github.com/masa-finance/tee-types/pkg/util" - teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/api/types" + util "github.com/masa-finance/tee-worker/pkg/util" "github.com/masa-finance/tee-worker/internal/apify" "github.com/masa-finance/tee-worker/pkg/client" "github.com/sirupsen/logrus" @@ -44,7 +44,7 @@ func (c *TwitterApifyClient) ValidateApiKey() error { } // GetFollowers retrieves followers for a username using Apify -func (c *TwitterApifyClient) GetFollowers(username string, maxResults uint, cursor client.Cursor) ([]*teetypes.ProfileResultApify, client.Cursor, error) { +func (c *TwitterApifyClient) GetFollowers(username string, maxResults uint, cursor client.Cursor) ([]*types.ProfileResultApify, client.Cursor, error) { minimum := uint(200) // Ensure minimum of 200 as required by the actor @@ -63,7 +63,7 @@ func (c *TwitterApifyClient) GetFollowers(username string, maxResults uint, curs } // GetFollowing retrieves following for a username using Apify -func (c *TwitterApifyClient) GetFollowing(username string, cursor client.Cursor, maxResults uint) ([]*teetypes.ProfileResultApify, client.Cursor, error) { +func (c *TwitterApifyClient) GetFollowing(username string, cursor client.Cursor, maxResults uint) ([]*types.ProfileResultApify, client.Cursor, error) { minimum := uint(200) // Ensure minimum of 200 as required by the actor @@ -82,15 +82,15 @@ func (c *TwitterApifyClient) GetFollowing(username string, cursor client.Cursor, } // getProfiles runs the actor and retrieves profiles from the dataset -func (c *TwitterApifyClient) getProfiles(input FollowerActorRunRequest, cursor client.Cursor, limit uint) ([]*teetypes.ProfileResultApify, client.Cursor, error) { +func (c *TwitterApifyClient) getProfiles(input FollowerActorRunRequest, cursor client.Cursor, limit uint) ([]*types.ProfileResultApify, client.Cursor, error) { dataset, nextCursor, err := c.apifyClient.RunActorAndGetResponse(apify.ActorIds.TwitterFollowers, input, cursor, limit) if err != nil { return 
nil, client.EmptyCursor, err } - profiles := make([]*teetypes.ProfileResultApify, 0, len(dataset.Data.Items)) + profiles := make([]*types.ProfileResultApify, 0, len(dataset.Data.Items)) for i, item := range dataset.Data.Items { - var profile teetypes.ProfileResultApify + var profile types.ProfileResultApify if err := json.Unmarshal(item, &profile); err != nil { logrus.Warnf("Failed to unmarshal profile at index %d: %v", i, err) continue diff --git a/internal/jobs/web.go b/internal/jobs/web.go index 75d0d0dc..df7cf370 100644 --- a/internal/jobs/web.go +++ b/internal/jobs/web.go @@ -7,6 +7,10 @@ import ( "github.com/sirupsen/logrus" + "github.com/masa-finance/tee-worker/api/args" + "github.com/masa-finance/tee-worker/api/args/llm" + "github.com/masa-finance/tee-worker/api/args/llm/process" + "github.com/masa-finance/tee-worker/api/args/web" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs/llmapify" @@ -14,14 +18,12 @@ import ( "github.com/masa-finance/tee-worker/internal/jobs/webapify" "github.com/masa-finance/tee-worker/pkg/client" - teeargs "github.com/masa-finance/tee-types/args" - "github.com/masa-finance/tee-types/pkg/util" - teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/pkg/util" ) // WebApifyClient defines the interface for the Web Apify client to allow mocking in tests type WebApifyClient interface { - Scrape(workerID string, args teeargs.WebArguments, cursor client.Cursor) ([]*teetypes.WebScraperResult, string, client.Cursor, error) + Scrape(workerID string, args web.Page, cursor client.Cursor) ([]*types.WebScraperResult, string, client.Cursor, error) } // NewWebApifyClient is a function variable that can be replaced in tests. 
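// A minimal sketch of how a test might swap this function variable to inject
// a mock; the signature is assumed from the declaration below, and
// MockWebApifyClient is a hypothetical stand-in for whatever mock the test
// defines (web_test.go in this diff uses the same pattern):
//
//	original := jobs.NewWebApifyClient
//	jobs.NewWebApifyClient = func(apiKey string, sc *stats.StatsCollector) (jobs.WebApifyClient, error) {
//		return &MockWebApifyClient{}, nil // hypothetical mock satisfying the interface
//	}
//	defer func() { jobs.NewWebApifyClient = original }()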
@@ -33,7 +35,7 @@ var NewWebApifyClient = func(apiKey string, statsCollector *stats.StatsCollector // LLMApify is the interface for the LLM processor client // Only the Process method is required for this flow type LLMApify interface { - Process(workerID string, args teeargs.LLMProcessorArguments, cursor client.Cursor) ([]*teetypes.LLMProcessorResult, client.Cursor, error) + Process(workerID string, args llm.Process, cursor client.Cursor) ([]*types.LLMProcessorResult, client.Cursor, error) } // NewLLMApifyClient is a function variable to allow injection in tests @@ -44,7 +46,7 @@ var NewLLMApifyClient = func(apiKey string, llmConfig config.LlmConfig, statsCol type WebScraper struct { configuration config.WebConfig statsCollector *stats.StatsCollector - capabilities []teetypes.Capability + capabilities []types.Capability } func NewWebScraper(jc config.JobConfiguration, statsCollector *stats.StatsCollector) *WebScraper { @@ -53,7 +55,7 @@ func NewWebScraper(jc config.JobConfiguration, statsCollector *stats.StatsCollec return &WebScraper{ configuration: cfg, statsCollector: statsCollector, - capabilities: teetypes.WebCaps, + capabilities: types.WebCaps, } } @@ -62,17 +64,17 @@ func (w *WebScraper) ExecuteJob(j types.Job) (types.JobResult, error) { // Require Gemini key for LLM processing in Web flow if !w.configuration.GeminiApiKey.IsValid() { - msg := errors.New("Gemini API key is required for Web job") + msg := errors.New("gemini API key is required for Web job") return types.JobResult{Error: msg.Error()}, msg } - jobArgs, err := teeargs.UnmarshalJobArguments(teetypes.JobType(j.Type), map[string]any(j.Arguments)) + jobArgs, err := args.UnmarshalJobArguments(types.JobType(j.Type), map[string]any(j.Arguments)) if err != nil { msg := fmt.Errorf("failed to unmarshal job arguments: %w", err) return types.JobResult{Error: msg.Error()}, msg } - webArgs, ok := jobArgs.(*teeargs.WebArguments) + webArgs, ok := jobArgs.(*web.Page) if !ok { return types.JobResult{Error: "invalid argument type for Web job"}, errors.New("invalid argument type") } @@ -98,11 +100,11 @@ func (w *WebScraper) ExecuteJob(j types.Job) (types.JobResult, error) { return types.JobResult{Error: "error creating LLM Apify client"}, fmt.Errorf("failed to create LLM Apify client: %w", err) } - llmArgs := teeargs.LLMProcessorArguments{ + llmArgs := llm.Process{ DatasetId: datasetId, Prompt: "summarize the content of this webpage, focusing on keywords and topics: ${markdown}", - MaxTokens: teeargs.LLMDefaultMaxTokens, - Temperature: teeargs.LLMDefaultTemperature, + MaxTokens: process.DefaultMaxTokens, + Temperature: process.DefaultTemperature, Items: uint(len(webResp)), } llmResp, _, llmErr := llmClient.Process(j.WorkerID, llmArgs, client.EmptyCursor) @@ -119,7 +121,7 @@ func (w *WebScraper) ExecuteJob(j types.Job) (types.JobResult, error) { data, err := json.Marshal(webResp) if err != nil { - return types.JobResult{Error: fmt.Sprintf("error marshalling Web response")}, fmt.Errorf("error marshalling Web response: %w", err) + return types.JobResult{Error: "error marshalling Web response"}, fmt.Errorf("error marshalling Web response: %w", err) } if w.statsCollector != nil { @@ -132,15 +134,3 @@ func (w *WebScraper) ExecuteJob(j types.Job) (types.JobResult, error) { NextCursor: cursor.String(), }, nil } - -// GetStructuredCapabilities returns the structured capabilities supported by the Web scraper -// based on the available credentials and API keys -func (ws *WebScraper) GetStructuredCapabilities() teetypes.WorkerCapabilities { - 
capabilities := make(teetypes.WorkerCapabilities) - - if ws.configuration.ApifyApiKey != "" && ws.configuration.GeminiApiKey.IsValid() { - capabilities[teetypes.WebJob] = teetypes.WebCaps - } - - return capabilities -} diff --git a/internal/jobs/web_test.go b/internal/jobs/web_test.go index c11de050..db870acd 100644 --- a/internal/jobs/web_test.go +++ b/internal/jobs/web_test.go @@ -8,6 +8,8 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/masa-finance/tee-worker/api/args/llm" + "github.com/masa-finance/tee-worker/api/args/web" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs" @@ -15,17 +17,14 @@ import ( "github.com/masa-finance/tee-worker/internal/jobs/stats" "github.com/masa-finance/tee-worker/internal/jobs/webapify" "github.com/masa-finance/tee-worker/pkg/client" - - teeargs "github.com/masa-finance/tee-types/args" - teetypes "github.com/masa-finance/tee-types/types" ) // MockWebApifyClient is a mock implementation of the WebApifyClient. type MockWebApifyClient struct { - ScrapeFunc func(args teeargs.WebArguments) ([]*teetypes.WebScraperResult, string, client.Cursor, error) + ScrapeFunc func(args web.Page) ([]*types.WebScraperResult, string, client.Cursor, error) } -func (m *MockWebApifyClient) Scrape(_ string, args teeargs.WebArguments, _ client.Cursor) ([]*teetypes.WebScraperResult, string, client.Cursor, error) { +func (m *MockWebApifyClient) Scrape(_ string, args web.Page, _ client.Cursor) ([]*types.WebScraperResult, string, client.Cursor, error) { if m != nil && m.ScrapeFunc != nil { res, datasetId, next, err := m.ScrapeFunc(args) return res, datasetId, next, err @@ -36,14 +35,14 @@ func (m *MockWebApifyClient) Scrape(_ string, args teeargs.WebArguments, _ clien // MockLLMApifyClient is a mock implementation of the LLMApify interface // used to prevent external calls during unit tests. 
type MockLLMApifyClient struct { - ProcessFunc func(workerID string, args teeargs.LLMProcessorArguments, cursor client.Cursor) ([]*teetypes.LLMProcessorResult, client.Cursor, error) + ProcessFunc func(workerID string, args llm.Process, cursor client.Cursor) ([]*types.LLMProcessorResult, client.Cursor, error) } -func (m *MockLLMApifyClient) Process(workerID string, args teeargs.LLMProcessorArguments, cursor client.Cursor) ([]*teetypes.LLMProcessorResult, client.Cursor, error) { +func (m *MockLLMApifyClient) Process(workerID string, args llm.Process, cursor client.Cursor) ([]*types.LLMProcessorResult, client.Cursor, error) { if m != nil && m.ProcessFunc != nil { return m.ProcessFunc(workerID, args, cursor) } - return []*teetypes.LLMProcessorResult{}, client.EmptyCursor, nil + return []*types.LLMProcessorResult{}, client.EmptyCursor, nil } var _ = Describe("WebScraper", func() { @@ -68,9 +67,9 @@ var _ = Describe("WebScraper", func() { scraper = jobs.NewWebScraper(cfg, statsCollector) mockClient = &MockWebApifyClient{} mockLLM = &MockLLMApifyClient{ - ProcessFunc: func(workerID string, args teeargs.LLMProcessorArguments, cursor client.Cursor) ([]*teetypes.LLMProcessorResult, client.Cursor, error) { + ProcessFunc: func(workerID string, args llm.Process, cursor client.Cursor) ([]*types.LLMProcessorResult, client.Cursor, error) { // Return a single empty summary to avoid changing expectations - return []*teetypes.LLMProcessorResult{{LLMResponse: ""}}, client.EmptyCursor, nil + return []*types.LLMProcessorResult{{LLMResponse: ""}}, client.EmptyCursor, nil }, } @@ -84,7 +83,7 @@ var _ = Describe("WebScraper", func() { job = types.Job{ UUID: "test-uuid", - Type: teetypes.WebJob, + Type: types.WebJob, } }) @@ -103,22 +102,22 @@ var _ = Describe("WebScraper", func() { It("should call Scrape and return data and next cursor", func() { job.Arguments = map[string]any{ - "type": teetypes.WebScraper, + "type": types.WebScraper, "url": "https://example.com", "max_depth": 1, "max_pages": 2, } - mockClient.ScrapeFunc = func(args teeargs.WebArguments) ([]*teetypes.WebScraperResult, string, client.Cursor, error) { + mockClient.ScrapeFunc = func(args web.Page) ([]*types.WebScraperResult, string, client.Cursor, error) { Expect(args.URL).To(Equal("https://example.com")) - return []*teetypes.WebScraperResult{{URL: "https://example.com", Markdown: "# Hello"}}, "dataset-123", client.Cursor("next-cursor"), nil + return []*types.WebScraperResult{{URL: "https://example.com", Markdown: "# Hello"}}, "dataset-123", client.Cursor("next-cursor"), nil } result, err := scraper.ExecuteJob(job) Expect(err).NotTo(HaveOccurred()) Expect(result.NextCursor).To(Equal("next-cursor")) - var resp []*teetypes.WebScraperResult + var resp []*types.WebScraperResult err = json.Unmarshal(result.Data, &resp) Expect(err).NotTo(HaveOccurred()) Expect(resp).To(HaveLen(1)) @@ -128,14 +127,14 @@ var _ = Describe("WebScraper", func() { It("should handle errors from the web client", func() { job.Arguments = map[string]any{ - "type": teetypes.WebScraper, + "type": types.WebScraper, "url": "https://example.com", "max_depth": 0, "max_pages": 1, } expectedErr := errors.New("client error") - mockClient.ScrapeFunc = func(args teeargs.WebArguments) ([]*teetypes.WebScraperResult, string, client.Cursor, error) { + mockClient.ScrapeFunc = func(args web.Page) ([]*types.WebScraperResult, string, client.Cursor, error) { return nil, "", client.EmptyCursor, expectedErr } @@ -150,7 +149,7 @@ var _ = Describe("WebScraper", func() { return nil, errors.New("client 
creation failed") } job.Arguments = map[string]any{ - "type": teetypes.WebScraper, + "type": types.WebScraper, "url": "https://example.com", "max_depth": 0, "max_pages": 1, @@ -202,9 +201,9 @@ var _ = Describe("WebScraper", func() { job := types.Job{ UUID: "integration-test-uuid", - Type: teetypes.WebJob, + Type: types.WebJob, Arguments: map[string]any{ - "type": teetypes.WebScraper, + "type": types.WebScraper, "url": "https://docs.learnbittensor.org", "max_depth": maxDepth, "max_pages": maxPages, @@ -216,7 +215,7 @@ var _ = Describe("WebScraper", func() { Expect(result.Error).To(BeEmpty()) Expect(result.Data).NotTo(BeEmpty()) - var resp []*teetypes.WebScraperResult + var resp []*types.WebScraperResult err = json.Unmarshal(result.Data, &resp) Expect(err).NotTo(HaveOccurred()) @@ -230,23 +229,5 @@ var _ = Describe("WebScraper", func() { Expect(resp[i].Text).To(ContainSubstring("Bittensor")) } }) - - It("should expose capabilities only when both APIFY and GEMINI keys are present", func() { - cfg := config.JobConfiguration{ - "apify_api_key": apifyKey, - "gemini_api_key": geminiKey, - } - integrationStatsCollector := stats.StartCollector(128, cfg) - integrationScraper := jobs.NewWebScraper(cfg, integrationStatsCollector) - - caps := integrationScraper.GetStructuredCapabilities() - if apifyKey != "" && geminiKey != "" { - Expect(caps[teetypes.WebJob]).NotTo(BeEmpty()) - } else { - // Expect no capabilities when either key is missing - _, ok := caps[teetypes.WebJob] - Expect(ok).To(BeFalse()) - } - }) }) }) diff --git a/internal/jobs/webapify/client.go b/internal/jobs/webapify/client.go index e59d550d..729b7311 100644 --- a/internal/jobs/webapify/client.go +++ b/internal/jobs/webapify/client.go @@ -4,8 +4,8 @@ import ( "encoding/json" "fmt" - teeargs "github.com/masa-finance/tee-types/args" - teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/api/args/web" + "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/apify" "github.com/masa-finance/tee-worker/internal/jobs/stats" "github.com/masa-finance/tee-worker/pkg/client" @@ -41,12 +41,12 @@ func (c *ApifyClient) ValidateApiKey() error { return c.client.ValidateApiKey() } -func (c *ApifyClient) Scrape(workerID string, args teeargs.WebArguments, cursor client.Cursor) ([]*teetypes.WebScraperResult, string, client.Cursor, error) { +func (c *ApifyClient) Scrape(workerID string, args web.Page, cursor client.Cursor) ([]*types.WebScraperResult, string, client.Cursor, error) { if c.statsCollector != nil { c.statsCollector.Add(workerID, stats.WebQueries, 1) } - input := args.ToWebScraperRequest() + input := args.ToScraperRequest() limit := uint(args.MaxPages) dataset, nextCursor, err := c.client.RunActorAndGetResponse(apify.ActorIds.WebScraper, input, cursor, limit) @@ -57,10 +57,10 @@ func (c *ApifyClient) Scrape(workerID string, args teeargs.WebArguments, cursor return nil, "", client.EmptyCursor, err } - response := make([]*teetypes.WebScraperResult, 0, len(dataset.Data.Items)) + response := make([]*types.WebScraperResult, 0, len(dataset.Data.Items)) for i, item := range dataset.Data.Items { - var resp teetypes.WebScraperResult + var resp types.WebScraperResult if err := json.Unmarshal(item, &resp); err != nil { logrus.Warnf("Failed to unmarshal scrape result at index %d: %v", i, err) continue diff --git a/internal/jobs/webapify/client_test.go b/internal/jobs/webapify/client_test.go index a23867dd..7380c576 100644 --- a/internal/jobs/webapify/client_test.go +++ 
b/internal/jobs/webapify/client_test.go @@ -8,11 +8,10 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/masa-finance/tee-worker/api/args/web" "github.com/masa-finance/tee-worker/internal/apify" "github.com/masa-finance/tee-worker/internal/jobs/webapify" "github.com/masa-finance/tee-worker/pkg/client" - - teeargs "github.com/masa-finance/tee-types/args" ) // MockApifyClient is a mock implementation of the ApifyClient. @@ -66,7 +65,7 @@ var _ = Describe("WebApifyClient", func() { Describe("Scrape", func() { It("should construct the correct actor input", func() { - args := teeargs.WebArguments{ + args := web.Page{ URL: "https://example.com", MaxDepth: 1, MaxPages: 2, @@ -88,7 +87,7 @@ var _ = Describe("WebApifyClient", func() { return nil, "", expectedErr } - args := teeargs.WebArguments{ + args := web.Page{ URL: "https://example.com", MaxDepth: 0, MaxPages: 1, @@ -108,7 +107,7 @@ var _ = Describe("WebApifyClient", func() { return dataset, "next", nil } - args := teeargs.WebArguments{ + args := web.Page{ URL: "https://example.com", MaxDepth: 0, MaxPages: 1, @@ -133,7 +132,7 @@ var _ = Describe("WebApifyClient", func() { return dataset, "next", nil } - args := teeargs.WebArguments{ + args := web.Page{ URL: "https://example.com", MaxDepth: 0, MaxPages: 1, @@ -194,7 +193,7 @@ var _ = Describe("WebApifyClient", func() { realClient, err := webapify.NewClient(apifyKey, nil) Expect(err).NotTo(HaveOccurred()) - args := teeargs.WebArguments{ + args := web.Page{ URL: "https://example.com", MaxDepth: 0, MaxPages: 1, diff --git a/internal/jobserver/jobserver.go b/internal/jobserver/jobserver.go index 1ace244f..3bb3a093 100644 --- a/internal/jobserver/jobserver.go +++ b/internal/jobserver/jobserver.go @@ -8,11 +8,10 @@ import ( "sync" "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" "github.com/google/uuid" - teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" + "github.com/masa-finance/tee-worker/internal/capabilities" "github.com/masa-finance/tee-worker/internal/config" "github.com/masa-finance/tee-worker/internal/jobs" "github.com/masa-finance/tee-worker/internal/jobs/stats" @@ -28,7 +27,7 @@ type JobServer struct { results *ResultCache jobConfiguration config.JobConfiguration - jobWorkers map[teetypes.JobType]*jobWorkerEntry + jobWorkers map[types.JobType]*jobWorkerEntry executedJobs map[string]bool } @@ -80,32 +79,23 @@ func NewJobServer(workers int, jc config.JobConfiguration) *JobServer { // Initialize job workers logrus.Info("Setting up job workers...") - jobworkers := map[teetypes.JobType]*jobWorkerEntry{ - teetypes.WebJob: { + jobworkers := map[types.JobType]*jobWorkerEntry{ + types.WebJob: { w: jobs.NewWebScraper(jc, s), }, - teetypes.TwitterJob: { + types.TwitterJob: { w: jobs.NewTwitterScraper(jc, s), }, - teetypes.TwitterCredentialJob: { - w: jobs.NewTwitterScraper(jc, s), // Uses the same implementation as standard Twitter scraper - }, - teetypes.TwitterApiJob: { - w: jobs.NewTwitterScraper(jc, s), // Uses the same implementation as standard Twitter scraper - }, - teetypes.TwitterApifyJob: { - w: jobs.NewTwitterScraper(jc, s), // Register Apify job type with Twitter scraper - }, - teetypes.TiktokJob: { + types.TiktokJob: { w: jobs.NewTikTokScraper(jc, s), }, - teetypes.RedditJob: { + types.RedditJob: { w: jobs.NewRedditScraper(jc, s), }, - teetypes.LinkedInJob: { + types.LinkedInJob: { w: jobs.NewLinkedInScraper(jc, s), }, - teetypes.TelemetryJob: { + types.TelemetryJob: { w: jobs.NewTelemetryJob(jc, s), 
}, } @@ -154,31 +144,11 @@ func NewJobServer(workers int, jc config.JobConfiguration) *JobServer { return js } -// GetWorkerCapabilities returns the structured capabilities for all registered workers -func (js *JobServer) GetWorkerCapabilities() teetypes.WorkerCapabilities { - // Use a map to deduplicate capabilities by job type - jobTypeCapMap := make(map[teetypes.JobType]map[teetypes.Capability]struct{}) - - for _, workerEntry := range js.jobWorkers { - workerCapabilities := workerEntry.w.GetStructuredCapabilities() - for jobType, capabilities := range workerCapabilities { - if _, exists := jobTypeCapMap[jobType]; !exists { - jobTypeCapMap[jobType] = make(map[teetypes.Capability]struct{}) - } - for _, capability := range capabilities { - jobTypeCapMap[jobType][capability] = struct{}{} - } - } - } - - // Convert to final map format - allCapabilities := make(teetypes.WorkerCapabilities) - for jobType, capabilitySet := range jobTypeCapMap { - capabilities := maps.Keys(capabilitySet) - allCapabilities[jobType] = capabilities - } - - return allCapabilities +// GetWorkerCapabilities returns the structured capabilities using centralized detection +func (js *JobServer) GetWorkerCapabilities() types.WorkerCapabilities { + // Use centralized capability detection instead of aggregating from individual workers + // This ensures consistent, real capability detection across all job types + return capabilities.DetectCapabilities(js.jobConfiguration, js) } func (js *JobServer) Run(ctx context.Context) { @@ -203,7 +173,7 @@ func (js *JobServer) AddJob(j types.Job) (string, error) { return "", errors.New("this job is not for this worker") } - if j.Type != teetypes.TelemetryJob && config.MinersWhiteList != "" { + if j.Type != types.TelemetryJob && config.MinersWhiteList != "" { var miners []string // In standalone mode, we just whitelist ourselves diff --git a/internal/jobserver/jobserver_test.go b/internal/jobserver/jobserver_test.go index b2e64a78..adce7e9d 100644 --- a/internal/jobserver/jobserver_test.go +++ b/internal/jobserver/jobserver_test.go @@ -8,7 +8,6 @@ import ( . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" - teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/config" @@ -24,7 +23,7 @@ var _ = Describe("Jobserver", func() { jobserver := NewJobServer(2, config.JobConfiguration{}) uuid, err := jobserver.AddJob(types.Job{ - Type: teetypes.WebJob, + Type: types.WebJob, Arguments: map[string]any{ "url": "google", }, @@ -51,7 +50,7 @@ var _ = Describe("Jobserver", func() { jobserver := NewJobServer(2, config.JobConfiguration{}) uuid, err := jobserver.AddJob(types.Job{ - Type: teetypes.WebJob, + Type: types.WebJob, Arguments: map[string]any{ "url": "google", }, @@ -64,7 +63,7 @@ var _ = Describe("Jobserver", func() { Expect(err.Error()).To(ContainSubstring("this job is not from a whitelisted miner")) uuid, err = jobserver.AddJob(types.Job{ - Type: teetypes.WebJob, + Type: types.WebJob, WorkerID: "miner1", Arguments: map[string]any{ "url": "google", @@ -81,7 +80,7 @@ var _ = Describe("Jobserver", func() { jobserver := NewJobServer(2, config.JobConfiguration{}) uuid, err := jobserver.AddJob(types.Job{ - Type: teetypes.WebJob, + Type: types.WebJob, Arguments: map[string]any{ "url": "google", }, @@ -96,7 +95,7 @@ var _ = Describe("Jobserver", func() { Expect(exists).ToNot(BeTrue()) uuid, err = jobserver.AddJob(types.Job{ - Type: teetypes.WebJob, + Type: types.WebJob, Arguments: map[string]any{ "url": "google", }, diff --git a/internal/jobserver/worker.go b/internal/jobserver/worker.go index 0c77382b..3d19edbf 100644 --- a/internal/jobserver/worker.go +++ b/internal/jobserver/worker.go @@ -4,7 +4,6 @@ import ( "context" "fmt" - teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" "github.com/sirupsen/logrus" ) @@ -26,7 +25,6 @@ func (js *JobServer) worker(c context.Context) { } type worker interface { - GetStructuredCapabilities() teetypes.WorkerCapabilities ExecuteJob(j types.Job) (types.JobResult, error) } diff --git a/pkg/client/apify_client.go b/pkg/client/apify_client.go index fafad74f..589ecb3f 100644 --- a/pkg/client/apify_client.go +++ b/pkg/client/apify_client.go @@ -331,8 +331,8 @@ func (c *ApifyClient) ValidateApiKey() error { } var ( - ErrActorFailed = errors.New("Actor run failed") - ErrActorAborted = errors.New("Actor run aborted") + ErrActorFailed = errors.New("actor run failed") + ErrActorAborted = errors.New("actor run aborted") ) // runActorAndGetProfiles runs the actor and retrieves profiles from the dataset diff --git a/pkg/client/http.go b/pkg/client/http.go index 7d05b9e2..50ace989 100644 --- a/pkg/client/http.go +++ b/pkg/client/http.go @@ -8,7 +8,6 @@ import ( "net/http" "time" - teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" ) @@ -19,6 +18,13 @@ type Client struct { HTTPClient *http.Client } +// EncryptedRequest represents an encrypted request/response pair +// note, this is copied from api/tee/encrypted.go to avoid TEE dependencies in client code +type EncryptedRequest struct { + EncryptedResult string `json:"encrypted_result"` + EncryptedRequest string `json:"encrypted_request"` +} + // setAPIKeyHeader sets the API key on the request if configured. func (c *Client) setAPIKeyHeader(req *http.Request) { if c.options != nil && c.options.APIKey != "" { @@ -44,7 +50,7 @@ func NewClient(baseURL string, opts ...Option) (*Client, error) { // CreateJobSignature sends a job to the server to generate a job signature. 
// The server will attach its worker ID to the job before generating the signature. -func (c *Client) CreateJobSignature(job teetypes.Job) (JobSignature, error) { +func (c *Client) CreateJobSignature(job types.Job) (JobSignature, error) { jobJSON, err := json.Marshal(job) if err != nil { return JobSignature(""), fmt.Errorf("error marshaling job: %w", err) @@ -115,7 +121,7 @@ func (c *Client) SubmitJob(JobSignature JobSignature) (*JobResult, error) { // Decrypt sends the encrypted result to the server to decrypt it. func (c *Client) Decrypt(JobSignature JobSignature, encryptedResult string) (string, error) { - decryptReq := types.EncryptedRequest{ + decryptReq := EncryptedRequest{ EncryptedResult: encryptedResult, EncryptedRequest: string(JobSignature), } diff --git a/pkg/client/http_test.go b/pkg/client/http_test.go index 9eaeee5f..e49bc3ac 100644 --- a/pkg/client/http_test.go +++ b/pkg/client/http_test.go @@ -5,7 +5,6 @@ import ( "net/http" "net/http/httptest" - teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" . "github.com/masa-finance/tee-worker/pkg/client" . "github.com/onsi/ginkgo/v2" @@ -59,7 +58,7 @@ var _ = Describe("Client", func() { Describe("CreateJobSignature", func() { It("should create a job signature successfully", func() { - job := teetypes.Job{Type: "test-job"} + job := types.Job{Type: "test-job"} signature, err := client.CreateJobSignature(job) Expect(err).NotTo(HaveOccurred()) Expect(signature).To(Equal(JobSignature("mock-signature"))) diff --git a/pkg/util/math.go b/pkg/util/math.go new file mode 100644 index 00000000..4e6f9719 --- /dev/null +++ b/pkg/util/math.go @@ -0,0 +1,27 @@ +package util + +import "golang.org/x/exp/constraints" + +func Min[T constraints.Ordered](elements ...T) T { + ret := elements[0] + + for _, x := range elements { + if x < ret { + ret = x + } + } + + return ret +} + +func Max[T constraints.Ordered](elements ...T) T { + ret := elements[0] + + for _, x := range elements { + if x > ret { + ret = x + } + } + + return ret +} diff --git a/pkg/util/math_test.go b/pkg/util/math_test.go new file mode 100644 index 00000000..3213360a --- /dev/null +++ b/pkg/util/math_test.go @@ -0,0 +1,24 @@ +package util_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/pkg/util" +) + +var _ = Describe("Math functions", func() { + Describe("Min", func() { + It("should calculate the minimum of a series of orderable values regardless of parameter order", func() { + Expect(util.Min(1, 2, 3, 4, 5, 6)).To(Equal(1)) + Expect(util.Min(2, 3, 8, -1, 4, 42)).To(Equal(-1)) + }) + }) + + Describe("Max", func() { + It("should calculate the maximum of a series of orderable values regardless of parameter order", func() { + Expect(util.Max(1, 2, 3, 4, 5, 6)).To(Equal(6)) + Expect(util.Max(2, 3, 8, -12, 4, 42)).To(Equal(42)) + }) + }) +}) diff --git a/pkg/util/set.go b/pkg/util/set.go new file mode 100644 index 00000000..f86bc315 --- /dev/null +++ b/pkg/util/set.go @@ -0,0 +1,103 @@ +package util + +import ( + "iter" + "maps" + "slices" +) + +// Set is a generic collection of unique items. +type Set[T comparable] map[T]struct{} + +// NewSet creates and returns a new Set with the given items, deduplicating them. +func NewSet[T comparable](items ...T) *Set[T] { + ret := make(Set[T], len(items)) + ret.Add(items...) + return &ret +} + +// Contains checks if an item is present in the set. 
+func (s *Set[T]) Contains(item T) bool { + _, exists := (*s)[item] + return exists +} + +// Add inserts the given items into the set, deduplicating them. +func (s *Set[T]) Add(items ...T) *Set[T] { + for _, item := range items { + (*s)[item] = struct{}{} + } + return s +} + +// Delete removes the given items from the set if it contains them. +func (s *Set[T]) Delete(items ...T) *Set[T] { + for _, item := range items { + delete((*s), item) + } + return s +} + +// Length returns the number of items in the set. +func (s *Set[T]) Length() int { + return len(*s) +} + +// Items returns a slice containing all the items in the set. +// The order of items in the slice is not guaranteed. +func (s *Set[T]) Items() []T { + return slices.Collect(s.ItemsSeq()) +} + +// ItemsSeq returns an iterator that yields all the items in the set. +// The order of items is not guaranteed. +func (s *Set[T]) ItemsSeq() iter.Seq[T] { + return maps.Keys(*s) +} + +// Union returns a new set containing all the items from the original set and all the provided sets, deduplicating them. +func (s *Set[T]) Union(sets ...*Set[T]) *Set[T] { + sum := s.Length() + for _, ss := range sets { + sum = sum + ss.Length() + } + + ret := make(map[T]struct{}, sum) + for k := range *s { + ret[k] = struct{}{} + } + for _, ss := range sets { + for k := range *ss { + ret[k] = struct{}{} + } + } + + rs := Set[T](ret) + return &rs +} + +// Intersection returns a new set containing only the items that are present in both the original set and s2. +func (s *Set[T]) Intersection(s2 *Set[T]) *Set[T] { + ret := make(map[T]struct{}, Min(s.Length(), s2.Length())) + for k := range *s { + if s2.Contains(k) { + ret[k] = struct{}{} + } + } + + rs := Set[T](ret) + return &rs +} + +// Difference returns a new set containing items that are in the original set but not in s2. +func (s *Set[T]) Difference(s2 *Set[T]) *Set[T] { + ret := make(map[T]struct{}, s.Length()) + for k := range *s { + if !s2.Contains(k) { + ret[k] = struct{}{} + } + } + + rs := Set[T](ret) + return &rs +} diff --git a/pkg/util/set_test.go b/pkg/util/set_test.go new file mode 100644 index 00000000..a096a451 --- /dev/null +++ b/pkg/util/set_test.go @@ -0,0 +1,73 @@ +package util_test + +import ( + "slices" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-worker/pkg/util" +) + +var _ = Describe("Set", func() { + It("should return a slice of all its elements", func() { + s := util.NewSet(0, 1, 2, 2, 2, 3, 4, 5, 5, 5, 5, 6) + Expect(s.Items()).To(ConsistOf(0, 1, 2, 3, 4, 5, 6)) + Expect(s.Length()).To(Equal(7)) + }) + + It("should check whether an item is included in the Set or not", func() { + s := util.NewSet(1, 2, 3, 4, 5, 6) + Expect(s.Contains(2)).To(BeTrue()) + Expect(s.Contains(42)).To(BeFalse()) + }) + + It("should add items to the set without duplicating", func() { + s := util.NewSet(1, 2, 3, 4, 5, 6) + s.Add(7, 8, 9, 2, 4) + Expect(s.Items()).To(ConsistOf(1, 2, 3, 4, 5, 6, 7, 8, 9)) + }) + + It("should delete items from the set if they exist", func() { + s := util.NewSet(1, 2, 3, 4, 5, 6) + s.Delete(7, 8, 9, 2, 4, 42) + Expect(s.Items()).To(ConsistOf(1, 3, 5, 6)) + }) + + It("should return the number of items in the set", func() { + s := util.NewSet(0, 1, 2, 3, 4, 5, 6) + Expect(s.Length()).To(Equal(7)) + s.Add(7, 8, 9, 2, 4) + Expect(s.Length()).To(Equal(10)) + s.Delete(0, 1) + Expect(s.Length()).To(Equal(8)) + }) + + It("should return a sequence of all its elements", func() { + s := util.NewSet(0, 1, 2, 3, 4, 5, 6) + items := slices.Collect(s.ItemsSeq()) + Expect(items).To(ConsistOf(0, 1, 2, 3, 4, 5, 6)) + }) + + It("should return the union of multiple sets", func() { + s1 := util.NewSet(0, 1, 2, 3, 4) + s2 := util.NewSet(3, 4, 5, 6, 7) + s3 := util.NewSet(8, 9, 0) + s4 := s1.Union(s2, s3) + Expect(s4.Items()).To(ConsistOf(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) + }) + + It("should return the intersection of two sets", func() { + s1 := util.NewSet(0, 1, 2, 3, 4) + s2 := util.NewSet(3, 4, 5, 6, 7) + s3 := s1.Intersection(s2) + Expect(s3.Items()).To(ConsistOf(3, 4)) + }) + + It("should return the difference of two sets", func() { + s1 := util.NewSet(0, 1, 2, 3, 4) + s2 := util.NewSet(3, 4, 5, 6, 7) + s3 := s1.Difference(s2) + Expect(s3.Items()).To(ConsistOf(0, 1, 2)) + }) +}) diff --git a/pkg/util/util_suite_test.go b/pkg/util/util_suite_test.go new file mode 100644 index 00000000..6d6903e9 --- /dev/null +++ b/pkg/util/util_suite_test.go @@ -0,0 +1,13 @@ +package util_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestUtil(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Util Suite") +}