diff --git a/Makefile b/Makefile index b45da9fd..1417711c 100644 --- a/Makefile +++ b/Makefile @@ -77,7 +77,7 @@ test-twitter: docker-build-test @docker run --user root $(ENV_FILE_ARG) -v $(PWD)/.masa:/home/masa -v $(PWD)/coverage:/app/coverage --rm --workdir /app -e DATA_DIR=/home/masa $(TEST_IMAGE) go test -v ./internal/jobs/twitter_test.go ./internal/jobs/jobs_suite_test.go test-tiktok: docker-build-test - @docker run --user root $(ENV_FILE_ARG) -v $(PWD)/.masa:/home/masa -v $(PWD)/coverage:/app/coverage --rm --workdir /app -e DATA_DIR=/home/masa $(TEST_IMAGE) go test -v ./internal/jobs/tiktok_transcription_test.go ./internal/jobs/jobs_suite_test.go + @docker run --user root $(ENV_FILE_ARG) -v $(PWD)/.masa:/home/masa -v $(PWD)/coverage:/app/coverage --rm --workdir /app -e DATA_DIR=/home/masa $(TEST_IMAGE) go test -v ./internal/jobs/tiktok_test.go ./internal/jobs/jobs_suite_test.go test-reddit: docker-build-test @docker run --user root $(ENV_FILE_ARG) -v $(PWD)/.masa:/home/masa -v $(PWD)/coverage:/app/coverage --rm --workdir /app -e DATA_DIR=/home/masa $(TEST_IMAGE) go test -v ./internal/jobs/reddit_test.go ./internal/jobs/redditapify/client_test.go ./api/types/reddit/reddit_suite_test.go diff --git a/internal/capabilities/capabilities_suite_test.go b/internal/capabilities/capabilities_suite_test.go new file mode 100644 index 00000000..224f32a0 --- /dev/null +++ b/internal/capabilities/capabilities_suite_test.go @@ -0,0 +1,13 @@ +package capabilities_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +func TestCapabilities(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Capabilities Suite") +} diff --git a/internal/capabilities/detector.go b/internal/capabilities/detector.go index 2eef55e5..abbafcb5 100644 --- a/internal/capabilities/detector.go +++ b/internal/capabilities/detector.go @@ -4,10 +4,14 @@ import ( "slices" "strings" + "maps" + + util "github.com/masa-finance/tee-types/pkg/util" teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/jobs/twitter" - "maps" + "github.com/masa-finance/tee-worker/pkg/client" + "github.com/sirupsen/logrus" ) // JobServerInterface defines the methods we need from JobServer to avoid circular dependencies @@ -37,7 +41,7 @@ func DetectCapabilities(jc types.JobConfiguration, jobServer JobServerInterface) hasAccounts := len(accounts) > 0 hasApiKeys := len(apiKeys) > 0 - hasApifyKey := apifyApiKey != "" + hasApifyKey := hasValidApifyKey(apifyApiKey) // Add Twitter-specific capabilities based on available authentication if hasAccounts { @@ -62,9 +66,17 @@ func DetectCapabilities(jc types.JobConfiguration, jobServer JobServerInterface) if hasApifyKey { capabilities[teetypes.TwitterApifyJob] = teetypes.TwitterApifyCaps capabilities[teetypes.RedditJob] = teetypes.RedditCaps + + // Merge TikTok search caps with any existing + existing := capabilities[teetypes.TiktokJob] + s := util.NewSet(existing...) + s.Add(teetypes.TiktokSearchCaps...) 
+ capabilities[teetypes.TiktokJob] = s.Items() + } // Add general TwitterJob capability if any Twitter auth is available + // TODO: this will get cleaned up with unique twitter capabilities if hasAccounts || hasApiKeys || hasApifyKey { var twitterJobCaps []teetypes.Capability // Use the most comprehensive capabilities available @@ -123,3 +135,25 @@ func parseApiKeys(apiKeys []string) []*twitter.TwitterApiKey { } return result } + +// hasValidApifyKey reports whether apifyApiKey is non-empty and passes ValidateApiKey on a temporary Apify client +func hasValidApifyKey(apifyApiKey string) bool { + if apifyApiKey == "" { + return false + } + + // Create temporary Apify client and validate the key + apifyClient, err := client.NewApifyClient(apifyApiKey) + if err != nil { + logrus.Errorf("Failed to create Apify client during capability detection: %v", err) + return false + } + + if err := apifyClient.ValidateApiKey(); err != nil { + logrus.Errorf("Apify API key validation failed during capability detection: %v", err) + return false + } + + logrus.Infof("Apify API key validated successfully during capability detection") + return true +} diff --git a/internal/capabilities/detector_test.go b/internal/capabilities/detector_test.go index 7f9e1035..2a2cbb0c 100644 --- a/internal/capabilities/detector_test.go +++ b/internal/capabilities/detector_test.go @@ -1,12 +1,15 @@ -package capabilities +package capabilities_test import ( - "reflect" + "os" "slices" - "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" + .
"github.com/masa-finance/tee-worker/internal/capabilities" ) // MockJobServer implements JobServerInterface for testing @@ -18,17 +21,32 @@ func (m *MockJobServer) GetWorkerCapabilities() teetypes.WorkerCapabilities { return m.capabilities } -func TestDetectCapabilities(t *testing.T) { - tests := []struct { - name string - jc types.JobConfiguration - jobServer JobServerInterface - expected teetypes.WorkerCapabilities - }{ - { - name: "With JobServer - gets capabilities from workers", - jc: types.JobConfiguration{}, - jobServer: &MockJobServer{ +var _ = Describe("DetectCapabilities", func() { + DescribeTable("capability detection scenarios", + func(jc types.JobConfiguration, jobServer JobServerInterface, expected teetypes.WorkerCapabilities) { + got := DetectCapabilities(jc, jobServer) + + // Extract job type keys and sort for consistent comparison + gotKeys := make([]string, 0, len(got)) + for jobType := range got { + gotKeys = append(gotKeys, jobType.String()) + } + + expectedKeys := make([]string, 0, len(expected)) + for jobType := range expected { + expectedKeys = append(expectedKeys, jobType.String()) + } + + // Sort both slices for comparison + slices.Sort(gotKeys) + slices.Sort(expectedKeys) + + // Compare the sorted slices + Expect(gotKeys).To(Equal(expectedKeys)) + }, + Entry("With JobServer - gets capabilities from workers", + types.JobConfiguration{}, + &MockJobServer{ capabilities: teetypes.WorkerCapabilities{ teetypes.WebJob: {teetypes.CapScraper}, teetypes.TelemetryJob: {teetypes.CapTelemetry}, @@ -36,58 +54,54 @@ func TestDetectCapabilities(t *testing.T) { teetypes.TwitterJob: {teetypes.CapSearchByQuery, teetypes.CapGetById, teetypes.CapGetProfileById}, }, }, - expected: teetypes.WorkerCapabilities{ + teetypes.WorkerCapabilities{ teetypes.WebJob: {teetypes.CapScraper}, teetypes.TelemetryJob: {teetypes.CapTelemetry}, teetypes.TiktokJob: {teetypes.CapTranscription}, teetypes.TwitterJob: {teetypes.CapSearchByQuery, teetypes.CapGetById, 
teetypes.CapGetProfileById}, }, - }, - { - name: "Without JobServer - basic capabilities only", - jc: types.JobConfiguration{}, - jobServer: nil, - expected: teetypes.WorkerCapabilities{ + ), + Entry("Without JobServer - basic capabilities only", + types.JobConfiguration{}, + nil, + teetypes.WorkerCapabilities{ teetypes.WebJob: {teetypes.CapScraper}, teetypes.TelemetryJob: {teetypes.CapTelemetry}, teetypes.TiktokJob: {teetypes.CapTranscription}, }, - }, - { - name: "With Twitter accounts - adds credential capabilities", - jc: types.JobConfiguration{ + ), + Entry("With Twitter accounts - adds credential capabilities", + types.JobConfiguration{ "twitter_accounts": []string{"account1", "account2"}, }, - jobServer: nil, - expected: teetypes.WorkerCapabilities{ + nil, + teetypes.WorkerCapabilities{ teetypes.WebJob: {teetypes.CapScraper}, teetypes.TelemetryJob: {teetypes.CapTelemetry}, teetypes.TiktokJob: {teetypes.CapTranscription}, teetypes.TwitterCredentialJob: teetypes.TwitterCredentialCaps, teetypes.TwitterJob: teetypes.TwitterCredentialCaps, }, - }, - { - name: "With Twitter API keys - adds API capabilities", - jc: types.JobConfiguration{ + ), + Entry("With Twitter API keys - adds API capabilities", + types.JobConfiguration{ "twitter_api_keys": []string{"key1", "key2"}, }, - jobServer: nil, - expected: teetypes.WorkerCapabilities{ + nil, + teetypes.WorkerCapabilities{ teetypes.WebJob: {teetypes.CapScraper}, teetypes.TelemetryJob: {teetypes.CapTelemetry}, teetypes.TiktokJob: {teetypes.CapTranscription}, teetypes.TwitterApiJob: teetypes.TwitterAPICaps, teetypes.TwitterJob: teetypes.TwitterAPICaps, }, - }, - { - name: "With mock elevated Twitter API keys - only basic capabilities detected", - jc: types.JobConfiguration{ + ), + Entry("With mock elevated Twitter API keys - only basic capabilities detected", + types.JobConfiguration{ "twitter_api_keys": []string{"Bearer abcd1234-ELEVATED"}, }, - jobServer: nil, - expected: teetypes.WorkerCapabilities{ + nil, + 
teetypes.WorkerCapabilities{ teetypes.WebJob: {teetypes.CapScraper}, teetypes.TelemetryJob: {teetypes.CapTelemetry}, teetypes.TiktokJob: {teetypes.CapTranscription}, @@ -95,82 +109,81 @@ func TestDetectCapabilities(t *testing.T) { teetypes.TwitterApiJob: teetypes.TwitterAPICaps, teetypes.TwitterJob: teetypes.TwitterAPICaps, }, - }, - } + ), + ) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := DetectCapabilities(tt.jc, tt.jobServer) + Context("Scraper Types", func() { + DescribeTable("scraper type detection", + func(jc types.JobConfiguration, expectedKeys []string) { + caps := DetectCapabilities(jc, nil) - // Extract job type keys and sort for consistent comparison - gotKeys := make([]string, 0, len(got)) - for jobType := range got { - gotKeys = append(gotKeys, jobType.String()) - } + jobNames := make([]string, 0, len(caps)) + for jobType := range caps { + jobNames = append(jobNames, jobType.String()) + } - expectedKeys := make([]string, 0, len(tt.expected)) - for jobType := range tt.expected { - expectedKeys = append(expectedKeys, jobType.String()) - } + // Sort both slices for comparison + slices.Sort(jobNames) + expectedSorted := make([]string, len(expectedKeys)) + copy(expectedSorted, expectedKeys) + slices.Sort(expectedSorted) - // Sort both slices for comparison - slices.Sort(gotKeys) - slices.Sort(expectedKeys) + // Compare the sorted slices + Expect(jobNames).To(Equal(expectedSorted)) + }, + Entry("Basic scrapers only", + types.JobConfiguration{}, + []string{"web", "telemetry", "tiktok"}, + ), + Entry("With Twitter accounts", + types.JobConfiguration{ + "twitter_accounts": []string{"user1:pass1"}, + }, + []string{"web", "telemetry", "tiktok", "twitter", "twitter-credential"}, + ), + Entry("With Twitter API keys", + types.JobConfiguration{ + "twitter_api_keys": []string{"key1"}, + }, + []string{"web", "telemetry", "tiktok", "twitter", "twitter-api"}, + ), + ) + }) - // Compare the sorted slices - if !reflect.DeepEqual(gotKeys, 
expectedKeys) { - t.Errorf("DetectCapabilities() job types = %v, want %v", gotKeys, expectedKeys) + Context("Apify Integration", func() { + It("should add enhanced capabilities when valid Apify API key is provided", func() { + apifyKey := os.Getenv("APIFY_API_KEY") + if apifyKey == "" { + Skip("APIFY_API_KEY is not set") } - }) - } -} -func TestDetectCapabilities_ScraperTypes(t *testing.T) { - tests := []struct { - name string - jc types.JobConfiguration - expectedKeys []string // scraper names we expect - }{ - { - name: "Basic scrapers only", - jc: types.JobConfiguration{}, - expectedKeys: []string{"web", "telemetry", "tiktok"}, - }, - { - name: "With Twitter accounts", - jc: types.JobConfiguration{ - "twitter_accounts": []string{"user1:pass1"}, - }, - expectedKeys: []string{"web", "telemetry", "tiktok", "twitter", "twitter-credential"}, - }, - { - name: "With Twitter API keys", - jc: types.JobConfiguration{ - "twitter_api_keys": []string{"key1"}, - }, - expectedKeys: []string{"web", "telemetry", "tiktok", "twitter", "twitter-api"}, - }, - } + jc := types.JobConfiguration{ + "apify_api_key": apifyKey, + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - caps := DetectCapabilities(tt.jc, nil) + caps := DetectCapabilities(jc, nil) - jobNames := make([]string, 0, len(caps)) - for jobType := range caps { - jobNames = append(jobNames, jobType.String()) - } + // TikTok should gain search capabilities with valid key + tiktokCaps, ok := caps[teetypes.TiktokJob] + Expect(ok).To(BeTrue(), "expected tiktok capabilities to be present") + Expect(tiktokCaps).To(ContainElement(teetypes.CapSearchByQuery), "expected tiktok to include CapSearchByQuery capability") + Expect(tiktokCaps).To(ContainElement(teetypes.CapSearchByTrending), "expected tiktok to include CapSearchByTrending capability") - // Sort both slices for comparison - slices.Sort(jobNames) - expectedSorted := make([]string, len(tt.expectedKeys)) - copy(expectedSorted, tt.expectedKeys) - 
slices.Sort(expectedSorted) + // Twitter-Apify job should be present with follower/following capabilities + twitterApifyCaps, ok := caps[teetypes.TwitterApifyJob] + Expect(ok).To(BeTrue(), "expected twitter-apify capabilities to be present") + Expect(twitterApifyCaps).To(ContainElement(teetypes.CapGetFollowers), "expected twitter-apify to include CapGetFollowers capability") + Expect(twitterApifyCaps).To(ContainElement(teetypes.CapGetFollowing), "expected twitter-apify to include CapGetFollowing capability") - // Compare the sorted slices - if !reflect.DeepEqual(jobNames, expectedSorted) { - t.Errorf("Expected capabilities %v, got %v", expectedSorted, jobNames) - } + // Reddit should be present + _, hasReddit := caps[teetypes.RedditJob] + Expect(hasReddit).To(BeTrue(), "expected reddit capabilities to be present") }) - } + }) +}) + +// Helper function to check if a job type exists in capabilities +func hasJobType(capabilities teetypes.WorkerCapabilities, jobName string) bool { + _, exists := capabilities[teetypes.JobType(jobName)] + return exists } diff --git a/internal/jobs/stats/stats.go b/internal/jobs/stats/stats.go index b87c4b0d..8f91bc72 100644 --- a/internal/jobs/stats/stats.go +++ b/internal/jobs/stats/stats.go @@ -24,12 +24,16 @@ const ( TwitterErrors StatType = "twitter_errors" TwitterAuthErrors StatType = "twitter_auth_errors" TwitterRateErrors StatType = "twitter_ratelimit_errors" - TwitterXSearchQueries StatType = "twitterx_search" + TwitterXSearchQueries StatType = "twitterx_search" // TODO: investigate if this is needed or used... 
WebSuccess StatType = "web_success" WebErrors StatType = "web_errors" WebInvalid StatType = "web_invalid" TikTokTranscriptionSuccess StatType = "tiktok_transcription_success" TikTokTranscriptionErrors StatType = "tiktok_transcription_errors" + TikTokVideos StatType = "tiktok_returned_videos" + TikTokQueries StatType = "tiktok_queries" + TikTokErrors StatType = "tiktok_errors" + TikTokAuthErrors StatType = "tiktok_auth_errors" RedditReturnedItems StatType = "reddit_returned_items" RedditQueries StatType = "reddit_queries" RedditErrors StatType = "reddit_errors" diff --git a/internal/jobs/tiktok_transcription.go b/internal/jobs/tiktok.go similarity index 72% rename from internal/jobs/tiktok_transcription.go rename to internal/jobs/tiktok.go index aebbde00..73949af2 100644 --- a/internal/jobs/tiktok_transcription.go +++ b/internal/jobs/tiktok.go @@ -13,6 +13,8 @@ import ( teetypes "github.com/masa-finance/tee-types/types" "github.com/masa-finance/tee-worker/api/types" "github.com/masa-finance/tee-worker/internal/jobs/stats" + "github.com/masa-finance/tee-worker/internal/jobs/tiktokapify" + "github.com/masa-finance/tee-worker/pkg/client" "github.com/sirupsen/logrus" ) @@ -27,6 +29,7 @@ type TikTokTranscriptionConfiguration struct { APIReferer string `json:"tiktok_api_referer,omitempty"` APIUserAgent string `json:"tiktok_api_user_agent,omitempty"` DefaultLanguage string `json:"tiktok_default_language,omitempty"` // e.g., "eng-US" + ApifyApiKey string `json:"apify_api_key,omitempty"` } // TikTokTranscriber is the main job struct for handling TikTok transcriptions. @@ -38,8 +41,13 @@ type TikTokTranscriber struct { // GetStructuredCapabilities returns the structured capabilities supported by the TikTok transcriber func (t *TikTokTranscriber) GetStructuredCapabilities() teetypes.WorkerCapabilities { + caps := make([]teetypes.Capability, 0, len(teetypes.AlwaysAvailableTiktokCaps)+len(teetypes.TiktokSearchCaps)) + caps = append(caps, teetypes.AlwaysAvailableTiktokCaps...) 
+ if t.configuration.ApifyApiKey != "" { + caps = append(caps, teetypes.TiktokSearchCaps...) + } return teetypes.WorkerCapabilities{ - teetypes.TiktokJob: teetypes.AlwaysAvailableTiktokCaps, + teetypes.TiktokJob: caps, } } @@ -55,38 +63,33 @@ func NewTikTokTranscriber(jc types.JobConfiguration, statsCollector *stats.Stats // Get configurable values from job configuration if err := jc.Unmarshal(&config); err != nil { - logrus.WithError(err).Debug("TikTokTranscriber: Could not unmarshal job configuration, using all defaults") - } - - // Set defaults for configurable values if not provided - if config.DefaultLanguage == "" { - config.DefaultLanguage = "eng-US" + logrus.WithError(err).Warn("failed to unmarshal TikTokTranscriptionConfiguration from JobConfiguration, using defaults where applicable") } + // Get Apify key from configuration (validation now handled at startup by capability detection) + config.ApifyApiKey = jc.GetString("apify_api_key", config.ApifyApiKey) + // Note: APIUserAgent is optional, it can be set later or use a default if config.APIUserAgent == "" { - config.APIUserAgent = "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/135.0.0.0 Mobile Safari/537.36" + config.APIUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36" } - // Log the actual configuration values being used - logrus.WithFields(logrus.Fields{ - "transcription_endpoint": config.TranscriptionEndpoint, - "api_origin": config.APIOrigin, - "api_referer": config.APIReferer, - "api_user_agent": config.APIUserAgent, - "default_language": config.DefaultLanguage, - }).Info("TikTokTranscriber initialized with configuration") - - httpClient := &http.Client{ - Timeout: 30 * time.Second, // Sensible default timeout + // Fall back to eng-US when no default language is configured + if config.DefaultLanguage == "" { + config.DefaultLanguage = "eng-US" } return &TikTokTranscriber{
configuration: config, stats: statsCollector, - httpClient: httpClient, + httpClient: &http.Client{Timeout: 30 * time.Second}, } } +// NewTikTokScraper is an alias constructor to align with Twitter's naming pattern +func NewTikTokScraper(jc types.JobConfiguration, statsCollector *stats.StatsCollector) *TikTokTranscriber { + return NewTikTokTranscriber(jc, statsCollector) +} + // APIResponse is used to unmarshal the JSON response from the transcription API. type APIResponse struct { VideoTitle string `json:"videoTitle"` @@ -97,6 +100,28 @@ type APIResponse struct { // ExecuteJob processes a single TikTok transcription job. func (ttt *TikTokTranscriber) ExecuteJob(j types.Job) (types.JobResult, error) { + logrus.WithField("job_uuid", j.UUID).Info("Starting ExecuteJob for TikTok job") + + // Use the centralized type-safe unmarshaller + jobArgs, err := teeargs.UnmarshalJobArguments(teetypes.JobType(j.Type), map[string]any(j.Arguments)) + if err != nil { + return types.JobResult{Error: "Failed to unmarshal job arguments"}, fmt.Errorf("unmarshal job arguments: %w", err) + } + + // Branch by argument type (transcription vs search) + if transcriptionArgs, ok := jobArgs.(*teeargs.TikTokTranscriptionArguments); ok { + return ttt.executeTranscription(j, transcriptionArgs) + } else if searchByQueryArgs, ok := jobArgs.(*teeargs.TikTokSearchByQueryArguments); ok { + return ttt.executeSearchByQuery(j, searchByQueryArgs) + } else if searchByTrendingArgs, ok := jobArgs.(*teeargs.TikTokSearchByTrendingArguments); ok { + return ttt.executeSearchByTrending(j, searchByTrendingArgs) + } else { + return types.JobResult{Error: "invalid argument type for TikTok job"}, fmt.Errorf("invalid argument type") + } +} + +// executeTranscription calls the external transcription service and returns a normalized result +func (ttt *TikTokTranscriber) executeTranscription(j types.Job, a *teeargs.TikTokTranscriptionArguments) (types.JobResult, error) { logrus.WithField("job_uuid", j.UUID).Info("Starting 
ExecuteJob for TikTok transcription") if ttt.configuration.TranscriptionEndpoint == "" { @@ -260,6 +285,68 @@ func (ttt *TikTokTranscriber) ExecuteJob(j types.Job) (types.JobResult, error) { return types.JobResult{Data: jsonData}, nil } +// executeSearchByQuery runs the epctex/tiktok-search-scraper actor and returns results +func (ttt *TikTokTranscriber) executeSearchByQuery(j types.Job, a *teeargs.TikTokSearchByQueryArguments) (types.JobResult, error) { + c, err := tiktokapify.NewTikTokApifyClient(ttt.configuration.ApifyApiKey) + if err != nil { + ttt.stats.Add(j.WorkerID, stats.TikTokAuthErrors, 1) + return types.JobResult{Error: "Failed to create Apify client"}, fmt.Errorf("apify client: %w", err) + } + + limit := a.MaxItems + if limit <= 0 { + limit = 20 + } + + items, next, err := c.SearchByQuery(*a, client.EmptyCursor, limit) + if err != nil { + ttt.stats.Add(j.WorkerID, stats.TikTokErrors, 1) + return types.JobResult{Error: err.Error()}, err + } + + data, err := json.Marshal(items) + if err != nil { + // Do not increment error stats for marshal errors; not the worker's fault + return types.JobResult{Error: "Failed to marshal results"}, fmt.Errorf("marshal results: %w", err) + } + + // Increment returned videos based on the number of items + ttt.stats.Add(j.WorkerID, stats.TikTokVideos, uint(len(items))) + ttt.stats.Add(j.WorkerID, stats.TikTokQueries, 1) + return types.JobResult{Data: data, NextCursor: next.String()}, nil +} + +// executeSearchByTrending runs the lexis-solutions/tiktok-trending-videos-scraper actor and returns results +func (ttt *TikTokTranscriber) executeSearchByTrending(j types.Job, a *teeargs.TikTokSearchByTrendingArguments) (types.JobResult, error) { + c, err := tiktokapify.NewTikTokApifyClient(ttt.configuration.ApifyApiKey) + if err != nil { + ttt.stats.Add(j.WorkerID, stats.TikTokAuthErrors, 1) + return types.JobResult{Error: "Failed to create Apify client"}, fmt.Errorf("apify client: %w", err) + } + + limit := a.MaxItems + if limit <= 
0 { + limit = 20 + } + + items, next, err := c.SearchByTrending(*a, client.EmptyCursor, uint(limit)) + if err != nil { + ttt.stats.Add(j.WorkerID, stats.TikTokErrors, 1) + return types.JobResult{Error: err.Error()}, err + } + + data, err := json.Marshal(items) + if err != nil { + // Do not increment error stats for marshal errors; not the worker's fault + return types.JobResult{Error: "Failed to marshal results"}, fmt.Errorf("marshal results: %w", err) + } + + // Increment returned videos based on the number of items + ttt.stats.Add(j.WorkerID, stats.TikTokVideos, uint(len(items))) + ttt.stats.Add(j.WorkerID, stats.TikTokQueries, 1) + return types.JobResult{Data: data, NextCursor: next.String()}, nil +} + // convertVTTToPlainText parses a VTT string and extracts the dialogue lines. // This is a basic implementation and might need to be made more robust. func convertVTTToPlainText(vttContent string) (string, error) { diff --git a/internal/jobs/tiktok_transcription_test.go b/internal/jobs/tiktok_test.go similarity index 54% rename from internal/jobs/tiktok_transcription_test.go rename to internal/jobs/tiktok_test.go index b9f10191..8490d1ad 100644 --- a/internal/jobs/tiktok_transcription_test.go +++ b/internal/jobs/tiktok_test.go @@ -2,6 +2,8 @@ package jobs_test import ( "encoding/json" + "fmt" + "os" "strings" "time" @@ -15,7 +17,7 @@ import ( "github.com/sirupsen/logrus" ) -var _ = Describe("TikTokTranscriber", func() { +var _ = Describe("TikTok", func() { var statsCollector *stats.StatsCollector var tikTokTranscriber *TikTokTranscriber var jobConfig types.JobConfiguration @@ -43,6 +45,7 @@ var _ = Describe("TikTokTranscriber", func() { It("should successfully transcribe the video and record success stats", func(ctx SpecContext) { videoURL := "https://www.tiktok.com/@.jake.ai/video/7516694182245813509" jobArguments := map[string]interface{}{ + "type": teetypes.CapTranscription, "video_url": videoURL, // default language is eng-US from tee types } @@ -113,6 +116,7 
@@ var _ = Describe("TikTokTranscriber", func() { Context("when arguments are invalid", func() { It("should return an error if VideoURL is empty and not record error stats", func() { jobArguments := map[string]interface{}{ + "type": teetypes.CapTranscription, "video_url": "", // Empty URL } @@ -156,4 +160,172 @@ var _ = Describe("TikTokTranscriber", func() { }, 5*time.Second, 100*time.Millisecond).Should(BeNumerically("==", 0), "TikTokTranscriptionSuccess count should be 0") }) }) + + Context("TikTok Apify search", func() { + It("should search by query via Apify", func() { + apifyKey := os.Getenv("APIFY_API_KEY") + if apifyKey == "" { + Skip("APIFY_API_KEY is not set") + } + + jobConfig := types.JobConfiguration{ + "apify_api_key": apifyKey, + } + t := NewTikTokTranscriber(jobConfig, statsCollector) + + j := types.Job{ + Type: teetypes.TiktokJob, + Arguments: map[string]interface{}{ + "type": teetypes.CapSearchByQuery, + "search": []string{"crypto", "ai"}, + "max_items": 5, + "end_page": 1, + "proxy": map[string]any{"use_apify_proxy": true}, + }, + WorkerID: "tiktok-test-worker-search-query", + Timeout: 60 * time.Second, + } + + res, err := t.ExecuteJob(j) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Error).To(BeEmpty()) + + var items []*teetypes.TikTokSearchByQueryResult + err = json.Unmarshal(res.Data, &items) + Expect(err).NotTo(HaveOccurred()) + Expect(items).NotTo(BeEmpty()) + + for _, item := range items { + fmt.Println("Video: ", item.URL) + } + + expectedCount := uint(len(items)) + Eventually(func() uint { + if statsCollector == nil || statsCollector.Stats == nil || statsCollector.Stats.Stats == nil { + return 0 + } + workerStatsMap := statsCollector.Stats.Stats[j.WorkerID] + if workerStatsMap == nil { + return 0 + } + return workerStatsMap[stats.TikTokVideos] + }, 15*time.Second, 250*time.Millisecond).Should(BeNumerically("==", expectedCount), "TikTokVideos count should equal returned items") + + Eventually(func() uint { + if statsCollector == nil || 
statsCollector.Stats == nil || statsCollector.Stats.Stats == nil { + return 0 + } + workerStatsMap := statsCollector.Stats.Stats[j.WorkerID] + if workerStatsMap == nil { + return 0 + } + return workerStatsMap[stats.TikTokErrors] + }, 5*time.Second, 100*time.Millisecond).Should(BeNumerically("==", 0), "TikTokErrors should be 0 on success") + }) + + It("should search trending via Apify", func() { + apifyKey := os.Getenv("APIFY_API_KEY") + if apifyKey == "" { + Skip("APIFY_API_KEY is not set") + } + + jobConfig := types.JobConfiguration{ + "apify_api_key": apifyKey, + } + t := NewTikTokTranscriber(jobConfig, statsCollector) + + j := types.Job{ + Type: teetypes.TiktokJob, + Arguments: map[string]interface{}{ + "type": teetypes.CapSearchByTrending, + "country_code": "US", + "sort_by": "repost", + "max_items": 5, + "period": "7", + }, + WorkerID: "tiktok-test-worker-search-trending", + Timeout: 60 * time.Second, + } + + res, err := t.ExecuteJob(j) + Expect(err).NotTo(HaveOccurred()) + Expect(res.Error).To(BeEmpty()) + + var items []*teetypes.TikTokSearchByTrending + err = json.Unmarshal(res.Data, &items) + Expect(err).NotTo(HaveOccurred()) + Expect(items).NotTo(BeEmpty()) + + for _, item := range items { + fmt.Println("Video: ", item.Title) + } + + expectedCount := uint(len(items)) + Eventually(func() uint { + if statsCollector == nil || statsCollector.Stats == nil || statsCollector.Stats.Stats == nil { + return 0 + } + workerStatsMap := statsCollector.Stats.Stats[j.WorkerID] + if workerStatsMap == nil { + return 0 + } + return workerStatsMap[stats.TikTokVideos] + }, 15*time.Second, 250*time.Millisecond).Should(BeNumerically("==", expectedCount), "TikTokVideos count should equal returned items") + + Eventually(func() uint { + if statsCollector == nil || statsCollector.Stats == nil || statsCollector.Stats.Stats == nil { + return 0 + } + workerStatsMap := statsCollector.Stats.Stats[j.WorkerID] + if workerStatsMap == nil { + return 0 + } + return 
workerStatsMap[stats.TikTokErrors] + }, 5*time.Second, 100*time.Millisecond).Should(BeNumerically("==", 0), "TikTokErrors should be 0 on success") + }) + + It("should increment TikTokErrors when Apify key is missing", func() { + // No APIFY_API_KEY provided in config + jobConfig := types.JobConfiguration{} + t := NewTikTokTranscriber(jobConfig, statsCollector) + + j := types.Job{ + Type: teetypes.TiktokJob, + Arguments: map[string]interface{}{ + "type": teetypes.CapSearchByQuery, + "search": []string{"tiktok"}, + "max_items": 1, + "end_page": 1, + }, + WorkerID: "tiktok-test-worker-missing-key", + Timeout: 10 * time.Second, + } + + res, err := t.ExecuteJob(j) + Expect(err).To(HaveOccurred()) + Expect(res.Error).NotTo(BeEmpty()) + + Eventually(func() uint { + if statsCollector == nil || statsCollector.Stats == nil || statsCollector.Stats.Stats == nil { + return 0 + } + workerStatsMap := statsCollector.Stats.Stats[j.WorkerID] + if workerStatsMap == nil { + return 0 + } + return workerStatsMap[stats.TikTokErrors] + }, 5*time.Second, 100*time.Millisecond).Should(BeNumerically("==", 1), "TikTokErrors should increment by 1 for missing API key") + + Consistently(func() uint { + if statsCollector == nil || statsCollector.Stats == nil || statsCollector.Stats.Stats == nil { + return 0 + } + workerStatsMap := statsCollector.Stats.Stats[j.WorkerID] + if workerStatsMap == nil { + return 0 + } + return workerStatsMap[stats.TikTokVideos] + }, 1*time.Second, 100*time.Millisecond).Should(BeNumerically("==", 0), "TikTokVideos should remain 0 on error") + }) + }) }) diff --git a/internal/jobs/tiktokapify/client.go b/internal/jobs/tiktokapify/client.go new file mode 100644 index 00000000..63a88a8b --- /dev/null +++ b/internal/jobs/tiktokapify/client.go @@ -0,0 +1,132 @@ +package tiktokapify + +import ( + "encoding/json" + "fmt" + + teeargs "github.com/masa-finance/tee-types/args" + teetypes "github.com/masa-finance/tee-types/types" + "github.com/masa-finance/tee-worker/pkg/client" +) 
+ +const ( + // Actors + SearchActorID = "epctex~tiktok-search-scraper" // must rent this actor from apify explicitly + TrendingActorID = "lexis-solutions~tiktok-trending-videos-scraper" // must rent this actor from apify explicitly +) + +type TikTokSearchByQueryRequest struct { + SearchTerms []string `json:"search"` + StartUrls []string `json:"startUrls"` + MaxItems uint `json:"maxItems"` + EndPage uint `json:"endPage"` + Proxy map[string]any `json:"proxy"` +} + +type TikTokSearchByTrendingRequest struct { + CountryCode string `json:"countryCode"` + SortBy string `json:"sortBy"` + MaxItems uint `json:"maxItems"` + Period string `json:"period"` +} + +type TikTokApifyClient struct { + apify client.Apify +} + +func NewTikTokApifyClient(apiToken string) (*TikTokApifyClient, error) { + apifyClient, err := client.NewApifyClient(apiToken) + if err != nil { + return nil, fmt.Errorf("failed to create Apify client: %w", err) + } + return &TikTokApifyClient{apify: apifyClient}, nil +} + +// ValidateApiKey validates the underlying Apify API token +func (c *TikTokApifyClient) ValidateApiKey() error { + return c.apify.ValidateApiKey() +} + +// SearchByQuery runs the search actor and returns typed results +func (c *TikTokApifyClient) SearchByQuery(input teeargs.TikTokSearchByQueryArguments, cursor client.Cursor, limit uint) ([]*teetypes.TikTokSearchByQueryResult, client.Cursor, error) { + // Map snake_case fields to Apify actor's expected camelCase input + startUrls := input.StartUrls + if startUrls == nil { + startUrls = []string{} + } + searchTerms := input.Search + if searchTerms == nil { + searchTerms = []string{} + } + + // Create structured request using the TikTokSearchByQueryRequest struct + request := TikTokSearchByQueryRequest{ + SearchTerms: searchTerms, + StartUrls: startUrls, + MaxItems: input.MaxItems, + EndPage: input.EndPage, + Proxy: map[string]any{"useApifyProxy": true}, + } + + // Convert struct to map[string]any for Apify client + requestBytes, err := 
json.Marshal(request) + if err != nil { + return nil, "", fmt.Errorf("failed to marshal request: %w", err) + } + + var apifyInput map[string]any + if err := json.Unmarshal(requestBytes, &apifyInput); err != nil { + return nil, "", fmt.Errorf("failed to unmarshal to map: %w", err) + } + + dataset, next, err := c.apify.RunActorAndGetResponse(SearchActorID, apifyInput, cursor, limit) + if err != nil { + return nil, "", fmt.Errorf("apify run (search): %w", err) + } + + var results []*teetypes.TikTokSearchByQueryResult + for _, raw := range dataset.Data.Items { + var item teetypes.TikTokSearchByQueryResult + if err := json.Unmarshal(raw, &item); err != nil { + // Skip any items whose structure doesn't match + continue + } + results = append(results, &item) + } + return results, next, nil +} + +// SearchByTrending runs the trending actor and returns typed results +func (c *TikTokApifyClient) SearchByTrending(input teeargs.TikTokSearchByTrendingArguments, cursor client.Cursor, limit uint) ([]*teetypes.TikTokSearchByTrending, client.Cursor, error) { + request := TikTokSearchByTrendingRequest{ + CountryCode: input.CountryCode, + SortBy: input.SortBy, + MaxItems: uint(input.MaxItems), + Period: input.Period, + } + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, "", fmt.Errorf("failed to marshal request: %w", err) + } + + var apifyInput map[string]any + if err := json.Unmarshal(requestBytes, &apifyInput); err != nil { + return nil, "", fmt.Errorf("failed to unmarshal to map: %w", err) + } + + dataset, next, err := c.apify.RunActorAndGetResponse(TrendingActorID, apifyInput, cursor, limit) + if err != nil { + return nil, "", fmt.Errorf("apify run (trending): %w", err) + } + + var results []*teetypes.TikTokSearchByTrending + for _, raw := range dataset.Data.Items { + var item teetypes.TikTokSearchByTrending + if err := json.Unmarshal(raw, &item); err != nil { + continue + } + results = append(results, &item) + } + return results, next, nil +} diff 
--git a/internal/jobs/twitter.go b/internal/jobs/twitter.go index 8ffe6e71..224d4588 100644 --- a/internal/jobs/twitter.go +++ b/internal/jobs/twitter.go @@ -980,21 +980,6 @@ func NewTwitterScraper(jc types.JobConfiguration, c *stats.StatsCollector) *Twit accountManager := twitter.NewTwitterAccountManager(accounts, apiKeys) accountManager.DetectAllApiKeyTypes() - // Validate Apify API key at startup if provided (similar to API key detection) - // TODO: We should verify whether each of the actors is actually available through this API key - if config.ApifyApiKey != "" { - apifyScraper, err := twitterapify.NewTwitterApifyClient(config.ApifyApiKey) - if err != nil { - logrus.Errorf("Failed to create Apify scraper at startup: %v", err) - // Don't fail startup, just log the error - the key might work later or be temporary - } else if err := apifyScraper.ValidateApiKey(); err != nil { - logrus.Errorf("Apify API key validation failed at startup: %v", err) - // Don't fail startup, just log the error - the key might work later or be temporary - } else { - logrus.Infof("Apify API key validated successfully at startup") - } - } - if os.Getenv("TWITTER_SKIP_LOGIN_VERIFICATION") == "true" { config.SkipLoginVerification = true } diff --git a/internal/jobserver/jobserver.go b/internal/jobserver/jobserver.go index 7030439e..a8fa3307 100644 --- a/internal/jobserver/jobserver.go +++ b/internal/jobserver/jobserver.go @@ -97,7 +97,7 @@ func NewJobServer(workers int, jc types.JobConfiguration) *JobServer { w: jobs.NewTwitterScraper(jc, s), // Register Apify job type with Twitter scraper }, teetypes.TiktokJob: { - w: jobs.NewTikTokTranscriber(jc, s), + w: jobs.NewTikTokScraper(jc, s), }, teetypes.RedditJob: { w: jobs.NewRedditScraper(jc, s),