From 8f9eca99d989e31aaa9d8062f869708d73221d1b Mon Sep 17 00:00:00 2001 From: mcamou Date: Thu, 21 Aug 2025 13:27:44 +0200 Subject: [PATCH 1/9] Add types for Reddit scraper --- .gitignore | 3 + args/args_suite_test.go | 13 +++ args/reddit.go | 150 ++++++++++++++++++++++++++++++ args/reddit_test.go | 178 +++++++++++++++++++++++++++++++++++ args/unmarshaller.go | 30 ++++++ args/unmarshaller_test.go | 105 +++++++++++++++++++++ pkg/util/set.go | 6 +- types/jobs.go | 16 +++- types/reddit.go | 190 ++++++++++++++++++++++++++++++++++++++ types/reddit_test.go | 93 +++++++++++++++++++ types/types_suite_test.go | 13 +++ 11 files changed, 794 insertions(+), 3 deletions(-) create mode 100644 args/args_suite_test.go create mode 100644 args/reddit.go create mode 100644 args/reddit_test.go create mode 100644 args/unmarshaller_test.go create mode 100644 types/reddit.go create mode 100644 types/reddit_test.go create mode 100644 types/types_suite_test.go diff --git a/.gitignore b/.gitignore index 9954081..f5f7bd6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,7 @@ go.work *~ *.log .DS_Store + +# LLM-related files .aider* +GEMINI.md diff --git a/args/args_suite_test.go b/args/args_suite_test.go new file mode 100644 index 0000000..861e0bf --- /dev/null +++ b/args/args_suite_test.go @@ -0,0 +1,13 @@ +package args_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +func TestArgs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Args Suite") +} diff --git a/args/reddit.go b/args/reddit.go new file mode 100644 index 0000000..fad60f6 --- /dev/null +++ b/args/reddit.go @@ -0,0 +1,150 @@ +package args + +import ( + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/masa-finance/tee-types/pkg/util" + teetypes "github.com/masa-finance/tee-types/types" +) + +var ( + ErrRedditInvalidType = errors.New("invalid type") + ErrRedditInvalidSort = errors.New("invalid sort") + ErrRedditTimeInTheFuture = errors.New("after field is in the future") + ErrRedditNoQueries = errors.New("queries must be provided for all query types except scrapeurls") + ErrRedditNoUrls = errors.New("urls must be provided for scrapeurls query type") + ErrRedditQueriesNotAllowed = errors.New("the scrapeurls query type does not admit queries") + ErrRedditUrlsNotAllowed = errors.New("urls can only be provided for the scrapeurls query type") +) + +const ( + // These reflect the default values in https://apify.com/trudax/reddit-scraper/input-schema + redditDefaultMaxItems = 10 + redditDefaultMaxPosts = 10 + redditDefaultMaxComments = 10 + redditDefaultMaxCommunities = 2 + redditDefaultMaxUsers = 2 + redditDefaultSort = teetypes.RedditSortNew +) + +// RedditArguments defines args for Reddit scrapes +// see https://apify.com/trudax/reddit-scraper +type RedditArguments struct { + QueryType teetypes.RedditQueryType `json:"type"` + Queries []string `json:"queries"` + URLs []teetypes.RedditStartURL `json:"urls"` + Sort teetypes.RedditSortType `json:"sort"` + IncludeNSFW bool `json:"include_nsfw"` + SkipPosts bool `json:"skip_posts"` // Valid only for searchusers + After time.Time `json:"after"` // valid only for scrapeurls and searchposts + MaxItems uint `json:"max_items"` // Max number of items to scrape (total), default 10 + MaxResults uint `json:"max_results"` // Max number of results per page, default 
MaxItems + MaxPosts uint `json:"max_posts"` // Max number of posts per page, default 10 + MaxComments uint `json:"max_comments"` // Max number of comments per page, default 10 + MaxCommunities uint `json:"max_communities"` // Max number of communities per page, default 2 + MaxUsers uint `json:"max_users"` // Max number of users per page, default 2 + NextCursor string `json:"next_cursor"` +} + +func (r *RedditArguments) UnmarshalJSON(data []byte) error { + type Alias RedditArguments + + // Set default values. They will be overridden if present in the JSON. + r.MaxItems = redditDefaultMaxItems + r.MaxPosts = redditDefaultMaxPosts + r.MaxComments = redditDefaultMaxComments + r.MaxCommunities = redditDefaultMaxCommunities + r.MaxUsers = redditDefaultMaxUsers + r.Sort = redditDefaultSort + + aux := &struct { + *Alias + }{ + Alias: (*Alias)(r), + } + + if err := json.Unmarshal(data, aux); err != nil { + return fmt.Errorf("failed to unmarshal Reddit arguments: %w", err) + } + + if r.MaxResults == 0 { + r.MaxResults = r.MaxItems + } + + return r.Validate() +} + +var allowedHttpMethods = util.NewSet("GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS") + +const redditDomainSuffix = "reddit.com" + +func (r *RedditArguments) Validate() error { + var errs []error + if !teetypes.AllRedditQueryTypes.Contains(r.QueryType) { + errs = append(errs, ErrRedditInvalidType) + } + + if !teetypes.AllRedditSortTypes.Contains(r.Sort) { + errs = append(errs, ErrRedditInvalidSort) + } + + if time.Now().Before(r.After) { + errs = append(errs, ErrRedditTimeInTheFuture) + } + + if r.QueryType == teetypes.RedditScrapeUrls { + if len(r.URLs) == 0 { + errs = append(errs, ErrRedditNoUrls) + } + if len(r.Queries) > 0 { + errs = append(errs, ErrRedditQueriesNotAllowed) + } + + for _, q := range r.URLs { + if !allowedHttpMethods.Contains(q.Method) { + errs = append(errs, fmt.Errorf("%s is not a valid HTTP method", q.Method)) + } + u, err := url.Parse(q.URL) + if err != nil { + errs = append(errs, 
fmt.Errorf("%s is not a valid URL", q.URL)) + } else { + if !strings.HasSuffix(u.Host, redditDomainSuffix) { + errs = append(errs, fmt.Errorf("Invalid Reddit URL %s", q.URL)) + } + } + } + } else { + if len(r.Queries) == 0 { + errs = append(errs, ErrRedditNoQueries) + } + if len(r.URLs) > 0 { + errs = append(errs, ErrRedditUrlsNotAllowed) + } + } + + if len(errs) > 0 { + return errors.Join(errs...) + } + + return nil +} + +// ValidateForJobType validates Twitter arguments for a specific job type +func (r *RedditArguments) ValidateForJobType(jobType teetypes.JobType) error { + if err := r.Validate(); err != nil { + return err + } + + // Validate QueryType against job-specific capabilities + return jobType.ValidateCapability(teetypes.Capability(r.QueryType)) +} + +// GetCapability returns the QueryType as a typed Capability +func (r *RedditArguments) GetCapability() teetypes.Capability { + return teetypes.Capability(r.QueryType) +} diff --git a/args/reddit_test.go b/args/reddit_test.go new file mode 100644 index 0000000..f158de3 --- /dev/null +++ b/args/reddit_test.go @@ -0,0 +1,178 @@ +package args_test + +import ( + "encoding/json" + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-types/args" + "github.com/masa-finance/tee-types/types" +) + +var _ = Describe("RedditArguments", func() { + Describe("Unmarshalling", func() { + It("should set default values", func() { + redditArgs := &args.RedditArguments{} + jsonData := `{"type": "searchposts", "queries": ["test"]}` + err := json.Unmarshal([]byte(jsonData), redditArgs) + Expect(err).ToNot(HaveOccurred()) + Expect(redditArgs.MaxItems).To(Equal(uint(10))) + Expect(redditArgs.MaxPosts).To(Equal(uint(10))) + Expect(redditArgs.MaxComments).To(Equal(uint(10))) + Expect(redditArgs.MaxCommunities).To(Equal(uint(2))) + Expect(redditArgs.MaxUsers).To(Equal(uint(2))) + Expect(redditArgs.Sort).To(Equal(types.RedditSortNew)) + Expect(redditArgs.MaxResults).To(Equal(redditArgs.MaxItems)) + }) + + It("should override default values", func() { + redditArgs := &args.RedditArguments{} + jsonData := `{"type": "searchposts", "queries": ["test"], "max_items": 20, "sort": "top"}` + err := json.Unmarshal([]byte(jsonData), redditArgs) + Expect(err).ToNot(HaveOccurred()) + Expect(redditArgs.MaxItems).To(Equal(uint(20))) + Expect(redditArgs.Sort).To(Equal(types.RedditSortTop)) + Expect(redditArgs.MaxResults).To(Equal(uint(20))) + }) + }) + + Describe("Validation", func() { + It("should succeed with valid arguments", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditSearchPosts, + Queries: []string{"test"}, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should succeed with valid scrapeurls arguments", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + URLs: []types.RedditStartURL{ + {URL: "https://www.reddit.com/r/golang/", Method: "GET"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail with an invalid type", func() { + redditArgs := &args.RedditArguments{ + 
QueryType: "invalidtype", + Queries: []string{"test"}, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditInvalidType)) + }) + + It("should fail with an invalid sort", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditSearchPosts, + Queries: []string{"test"}, + Sort: "invalidsort", + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditInvalidSort)) + }) + + It("should fail if the after time is in the future", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditSearchPosts, + Queries: []string{"test"}, + Sort: types.RedditSortNew, + After: time.Now().Add(24 * time.Hour), + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditTimeInTheFuture)) + }) + + It("should fail if queries are not provided for searchposts", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditSearchPosts, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditNoQueries)) + }) + + It("should fail if urls are not provided for scrapeurls", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditNoUrls)) + }) + + It("should fail if queries are provided for scrapeurls", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + Queries: []string{"test"}, + URLs: []types.RedditStartURL{ + {URL: "https://www.reddit.com/r/golang/", Method: "GET"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditQueriesNotAllowed)) + }) + + It("should fail if urls are provided for searchposts", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditSearchPosts, + Queries: []string{"test"}, + URLs: []types.RedditStartURL{ + {URL: "https://www.reddit.com/r/golang/", 
Method: "GET"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditUrlsNotAllowed)) + }) + + It("should fail with an invalid URL", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + URLs: []types.RedditStartURL{ + {URL: "ht tp://invalid-url.com", Method: "GET"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("is not a valid URL")) + }) + + It("should fail with an invalid domain", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + URLs: []types.RedditStartURL{ + {URL: "https://www.google.com", Method: "GET"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("Invalid Reddit URL")) + }) + + It("should fail with an invalid HTTP method", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + URLs: []types.RedditStartURL{ + {URL: "https://www.reddit.com/r/golang/", Method: "INVALID"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("is not a valid HTTP method")) + }) + }) +}) diff --git a/args/unmarshaller.go b/args/unmarshaller.go index 353bf8f..15c8f4d 100644 --- a/args/unmarshaller.go +++ b/args/unmarshaller.go @@ -49,6 +49,12 @@ type LinkedInJobArguments interface { ValidateForJobType(jobType types.JobType) error } +// RedditJobArguments extends JobArguments for Reddit-specific methods +type RedditJobArguments interface { + JobArguments + ValidateForJobType(jobType types.JobType) error +} + // UnmarshalJobArguments unmarshals job arguments from a generic map into the appropriate typed struct // This works with both tee-indexer and tee-worker JobArguments types func UnmarshalJobArguments(jobType types.JobType, args 
map[string]any) (JobArguments, error) { @@ -65,6 +71,9 @@ func UnmarshalJobArguments(jobType types.JobType, args map[string]any) (JobArgum case types.LinkedInJob: return unmarshalLinkedInArguments(jobType, args) + case types.RedditJob: + return unmarshalRedditArguments(jobType, args) + case types.TelemetryJob: return &TelemetryJobArguments{}, nil @@ -132,6 +141,27 @@ func unmarshalLinkedInArguments(jobType types.JobType, args map[string]any) (*Li return linkedInArgs, nil } +func unmarshalRedditArguments(jobType types.JobType, args map[string]any) (*RedditArguments, error) { + redditArgs := &RedditArguments{} + if err := unmarshalToStruct(args, redditArgs); err != nil { + return nil, fmt.Errorf("failed to unmarshal Reddit job arguments: %w", err) + } + + // If no QueryType is specified, use the default capability for this job type + if redditArgs.QueryType == "" { + if defaultCap, exists := types.JobDefaultCapabilityMap[jobType]; exists { + redditArgs.QueryType = types.RedditQueryType(defaultCap) + } + } + + // Perform job-type-specific validation for Reddit + if err := redditArgs.ValidateForJobType(jobType); err != nil { + return nil, fmt.Errorf("reddit job validation failed: %w", err) + } + + return redditArgs, nil +} + // unmarshalToStruct converts a map[string]any to a struct using JSON marshal/unmarshal // This provides the same functionality as the existing JobArguments.Unmarshal methods func unmarshalToStruct(args map[string]any, target any) error { diff --git a/args/unmarshaller_test.go b/args/unmarshaller_test.go new file mode 100644 index 0000000..04e784f --- /dev/null +++ b/args/unmarshaller_test.go @@ -0,0 +1,105 @@ +package args_test + +import ( + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-types/args" + "github.com/masa-finance/tee-types/types" +) + +var _ = Describe("Unmarshaller", func() { + Describe("UnmarshalJobArguments", func() { + Context("with a WebJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "url": "https://example.com", + "selector": "h1", + "max_depth": 2, + } + jobArgs, err := args.UnmarshalJobArguments(types.WebJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + webArgs, ok := jobArgs.(*args.WebSearchArguments) + Expect(ok).To(BeTrue()) + Expect(webArgs.URL).To(Equal("https://example.com")) + Expect(webArgs.Selector).To(Equal("h1")) + Expect(webArgs.MaxDepth).To(Equal(2)) + }) + }) + + Context("with a TiktokJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "video_url": "https://www.tiktok.com/@user/video/123", + "language": "en-us", + } + jobArgs, err := args.UnmarshalJobArguments(types.TiktokJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + tiktokArgs, ok := jobArgs.(*args.TikTokTranscriptionArguments) + Expect(ok).To(BeTrue()) + Expect(tiktokArgs.VideoURL).To(Equal("https://www.tiktok.com/@user/video/123")) + Expect(tiktokArgs.Language).To(Equal("en-us")) + }) + }) + + Context("with a TwitterJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "type": "searchbyquery", + "query": "golang", + "count": 10, + } + jobArgs, err := args.UnmarshalJobArguments(types.TwitterJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + twitterArgs, ok := jobArgs.(*args.TwitterSearchArguments) + Expect(ok).To(BeTrue()) + Expect(twitterArgs.QueryType).To(Equal("searchbyquery")) + Expect(twitterArgs.Query).To(Equal("golang")) + Expect(twitterArgs.Count).To(Equal(10)) + }) + + It("should set the default capability for TwitterApifyJob", func() { + argsMap := map[string]any{"query": "masa-finance"} + jobArgs, err := 
args.UnmarshalJobArguments(types.TwitterApifyJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + twitterArgs, ok := jobArgs.(*args.TwitterSearchArguments) + Expect(ok).To(BeTrue()) + Expect(twitterArgs.GetCapability()).To(Equal(types.CapGetFollowers)) + }) + }) + + Context("with a RedditJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "type": "searchposts", + "queries": []string{"golang"}, + "sort": "new", + } + jobArgs, err := args.UnmarshalJobArguments(types.RedditJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + redditArgs, ok := jobArgs.(*args.RedditArguments) + Expect(ok).To(BeTrue()) + Expect(redditArgs.QueryType).To(Equal(types.RedditQueryType("searchposts"))) + }) + }) + + Context("with a TelemetryJob", func() { + It("should return a TelemetryJobArguments struct", func() { + argsMap := map[string]any{} + jobArgs, err := args.UnmarshalJobArguments(types.TelemetryJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + _, ok := jobArgs.(*args.TelemetryJobArguments) + Expect(ok).To(BeTrue()) + }) + }) + + Context("with an unknown job type", func() { + It("should return an error", func() { + argsMap := map[string]any{} + _, err := args.UnmarshalJobArguments("unknown", argsMap) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("unknown job type")) + }) + }) + }) +}) diff --git a/pkg/util/set.go b/pkg/util/set.go index fe237be..33b907d 100644 --- a/pkg/util/set.go +++ b/pkg/util/set.go @@ -23,17 +23,19 @@ func (s *Set[T]) Contains(item T) bool { } // Add inserts the given items into the set, deduplicating them. -func (s *Set[T]) Add(items ...T) { +func (s *Set[T]) Add(items ...T) *Set[T] { for _, item := range items { (*s)[item] = struct{}{} } + return s } // Delete removes the given items from the set if it contains them. 
-func (s *Set[T]) Delete(items ...T) { +func (s *Set[T]) Delete(items ...T) *Set[T] { for _, item := range items { delete((*s), item) } + return s } // Length returns the number of items in the set. diff --git a/types/jobs.go b/types/jobs.go index 945eb63..9c1cce5 100644 --- a/types/jobs.go +++ b/types/jobs.go @@ -50,6 +50,7 @@ const ( TwitterApiJob JobType = "twitter-api" // Twitter scraping with API keys TwitterApifyJob JobType = "twitter-apify" // Twitter scraping with Apify LinkedInJob JobType = "linkedin" // LinkedIn scraping, keeping for unmarshalling logic + RedditJob JobType = "reddit" // Reddit scraping with Apify ) // Capability constants - typed to prevent typos and enable discoverability @@ -73,7 +74,13 @@ const ( CapGetFollowers Capability = "getfollowers" CapGetSpace Capability = "getspace" CapGetProfile Capability = "getprofile" // LinkedIn get profile capability - CapEmpty Capability = "" + // Reddit capabilities + CapScrapeUrls Capability = "scrapeurls" + CapSearchPosts Capability = "searchposts" + CapSearchUsers Capability = "searchusers" + CapSearchCommunities Capability = "searchcommunities" + + CapEmpty Capability = "" ) // Capability group constants for easy reuse @@ -104,6 +111,9 @@ var ( // TwitterApifyCaps are Twitter capabilities available with Apify TwitterApifyCaps = []Capability{CapGetFollowers, CapGetFollowing, CapEmpty} + + // RedditCaps are all the Reddit capabilities (only available with Apify) + RedditCaps = []Capability{CapScrapeUrls, CapSearchPosts, CapSearchUsers, CapSearchCommunities} ) // JobCapabilityMap defines which capabilities are valid for each job type @@ -128,6 +138,9 @@ var JobCapabilityMap = map[JobType][]Capability{ // TikTok job capabilities TiktokJob: AlwaysAvailableTiktokCaps, + // Reddit job capabilities + RedditJob: RedditCaps, + // Telemetry job capabilities TelemetryJob: AlwaysAvailableTelemetryCaps, } @@ -140,5 +153,6 @@ var JobDefaultCapabilityMap = map[JobType]Capability{ TwitterApifyJob: CapGetFollowers, 
WebJob: CapScraper, TiktokJob: CapTranscription, + RedditJob: CapScrapeUrls, TelemetryJob: CapTelemetry, } diff --git a/types/reddit.go b/types/reddit.go new file mode 100644 index 0000000..7b2ed5e --- /dev/null +++ b/types/reddit.go @@ -0,0 +1,190 @@ +package types + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/masa-finance/tee-types/pkg/util" +) + +type RedditQueryType string + +const ( + RedditScrapeUrls RedditQueryType = "scrapeurls" + RedditSearchPosts RedditQueryType = "searchposts" + RedditSearchUsers RedditQueryType = "searchusers" + RedditSearchCommunities RedditQueryType = "searchcommunities" +) + +var AllRedditQueryTypes = util.NewSet(RedditScrapeUrls, RedditSearchPosts, RedditSearchUsers, RedditSearchCommunities) + +type RedditSortType string + +const ( + RedditSortRelevance RedditSortType = "relevance" + RedditSortHot RedditSortType = "hot" + RedditSortTop RedditSortType = "top" + RedditSortNew RedditSortType = "new" + RedditSortRising RedditSortType = "rising" + RedditSortComments RedditSortType = "comments" +) + +var AllRedditSortTypes = util.NewSet( + RedditSortRelevance, + RedditSortHot, + RedditSortTop, + RedditSortNew, + RedditSortRising, + RedditSortComments, +) + +// RedditStartURL represents a single start URL for the Apify Reddit scraper. +type RedditStartURL struct { + URL string `json:"url"` + Method string `json:"method"` +} + +type RedditResponseType string + +const ( + RedditUserResponse RedditResponseType = "user" + RedditPostResponse RedditResponseType = "post" + RedditCommentResponse RedditResponseType = "comment" + RedditCommunityResponse RedditResponseType = "community" +) + +// RedditUser represents the data structure for a Reddit user from the Apify scraper. 
+type RedditUser struct { + ID string `json:"id"` + URL string `json:"url"` + Username string `json:"username"` + UserIcon string `json:"userIcon"` + PostKarma int `json:"postKarma"` + CommentKarma int `json:"commentKarma"` + Description string `json:"description"` + Over18 bool `json:"over18"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + DataType string `json:"dataType"` +} + +// RedditPost represents the data structure for a Reddit post from the Apify scraper. +type RedditPost struct { + ID string `json:"id"` + ParsedID string `json:"parsedId"` + URL string `json:"url"` + Username string `json:"username"` + Title string `json:"title"` + CommunityName string `json:"communityName"` + ParsedCommunityName string `json:"parsedCommunityName"` + Body string `json:"body"` + HTML *string `json:"html"` + NumberOfComments int `json:"numberOfComments"` + UpVotes int `json:"upVotes"` + IsVideo bool `json:"isVideo"` + IsAd bool `json:"isAd"` + Over18 bool `json:"over18"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + DataType string `json:"dataType"` +} + +// RedditComment represents the data structure for a Reddit comment from the Apify scraper. +type RedditComment struct { + ID string `json:"id"` + ParsedID string `json:"parsedId"` + URL string `json:"url"` + ParentID string `json:"parentId"` + Username string `json:"username"` + Category string `json:"category"` + CommunityName string `json:"communityName"` + Body string `json:"body"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + UpVotes int `json:"upVotes"` + NumberOfReplies int `json:"numberOfreplies"` + HTML string `json:"html"` + DataType string `json:"dataType"` +} + +// RedditCommunity represents the data structure for a Reddit community from the Apify scraper. 
+type RedditCommunity struct { + ID string `json:"id"` + Name string `json:"name"` + Title string `json:"title"` + HeaderImage string `json:"headerImage"` + Description string `json:"description"` + Over18 bool `json:"over18"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + NumberOfMembers int `json:"numberOfMembers"` + URL string `json:"url"` + DataType string `json:"dataType"` +} + +type RedditTypeSwitch struct { + Type RedditResponseType `json:"type"` +} + +type RedditResponse struct { + TypeSwitch *RedditTypeSwitch + User *RedditUser + Post *RedditPost + Comment *RedditComment + Community *RedditCommunity +} + +func (t *RedditResponse) UnmarshalJSON(data []byte) error { + t.TypeSwitch = &RedditTypeSwitch{} + if err := json.Unmarshal(data, &t.TypeSwitch); err != nil { + return fmt.Errorf("failed to unmarshal reddit response type: %w", err) + } + + switch t.TypeSwitch.Type { + case RedditUserResponse: + t.User = &RedditUser{} + if err := json.Unmarshal(data, t.User); err != nil { + return fmt.Errorf("failed to unmarshal reddit user: %w", err) + } + case RedditPostResponse: + t.Post = &RedditPost{} + if err := json.Unmarshal(data, t.Post); err != nil { + return fmt.Errorf("failed to unmarshal reddit post: %w", err) + } + case RedditCommentResponse: + t.Comment = &RedditComment{} + if err := json.Unmarshal(data, t.Comment); err != nil { + return fmt.Errorf("failed to unmarshal reddit comment: %w", err) + } + case RedditCommunityResponse: + t.Community = &RedditCommunity{} + if err := json.Unmarshal(data, t.Community); err != nil { + return fmt.Errorf("failed to unmarshal reddit community: %w", err) + } + default: + return fmt.Errorf("unknown Reddit response type: %s", t.TypeSwitch.Type) + } + return nil +} + +// MarshalJSON implements the json.Marshaler interface for RedditResponse. +// It unwraps the inner struct (User, Post, Comment, or Community) and marshals it directly. 
+func (t *RedditResponse) MarshalJSON() ([]byte, error) { + if t.TypeSwitch == nil { + return []byte("null"), nil + } + + switch t.TypeSwitch.Type { + case RedditUserResponse: + return json.Marshal(t.User) + case RedditPostResponse: + return json.Marshal(t.Post) + case RedditCommentResponse: + return json.Marshal(t.Comment) + case RedditCommunityResponse: + return json.Marshal(t.Community) + default: + return nil, fmt.Errorf("unknown Reddit response type: %s", t.TypeSwitch.Type) + } +} diff --git a/types/reddit_test.go b/types/reddit_test.go new file mode 100644 index 0000000..05a7173 --- /dev/null +++ b/types/reddit_test.go @@ -0,0 +1,93 @@ +package types_test + +import ( + "encoding/json" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/masa-finance/tee-types/types" +) + +var _ = Describe("RedditResponse", func() { + Describe("Unmarshalling", func() { + It("should unmarshal a user response", func() { + jsonData := `{"type": "user", "id": "user123", "username": "testuser"}` + var resp types.RedditResponse + err := json.Unmarshal([]byte(jsonData), &resp) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.User).ToNot(BeNil()) + Expect(resp.Post).To(BeNil()) + Expect(resp.User.ID).To(Equal("user123")) + Expect(resp.User.Username).To(Equal("testuser")) + }) + + It("should unmarshal a post response", func() { + jsonData := `{"type": "post", "id": "post123", "title": "Test Post"}` + var resp types.RedditResponse + err := json.Unmarshal([]byte(jsonData), &resp) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Post).ToNot(BeNil()) + Expect(resp.User).To(BeNil()) + Expect(resp.Post.ID).To(Equal("post123")) + Expect(resp.Post.Title).To(Equal("Test Post")) + }) + + It("should return an error for an unknown type", func() { + jsonData := `{"type": "unknown", "id": "123"}` + var resp types.RedditResponse + err := json.Unmarshal([]byte(jsonData), &resp) + Expect(err).To(MatchError("unknown Reddit response type: unknown")) + }) + }) + + 
Describe("Marshalling", func() { + It("should marshal a user response", func() { + now := time.Now() + resp := types.RedditResponse{ + TypeSwitch: &types.RedditTypeSwitch{Type: types.RedditUserResponse}, + User: &types.RedditUser{ + ID: "user123", + Username: "testuser", + CreatedAt: now, + CommentKarma: 10, + }, + } + + expectedJSON, err := json.Marshal(resp.User) + Expect(err).ToNot(HaveOccurred()) + + actualJSON, err := json.Marshal(&resp) + Expect(err).ToNot(HaveOccurred()) + + Expect(actualJSON).To(MatchJSON(expectedJSON)) + }) + + It("should marshal a post response", func() { + resp := types.RedditResponse{ + TypeSwitch: &types.RedditTypeSwitch{Type: types.RedditPostResponse}, + Post: &types.RedditPost{ + ID: "post123", + Title: "Test Post", + }, + } + + expectedJSON, err := json.Marshal(resp.Post) + Expect(err).ToNot(HaveOccurred()) + + actualJSON, err := json.Marshal(&resp) + Expect(err).ToNot(HaveOccurred()) + + Expect(actualJSON).To(MatchJSON(expectedJSON)) + }) + + It("should return an error for an unknown type", func() { + resp := types.RedditResponse{ + TypeSwitch: &types.RedditTypeSwitch{Type: "unknown"}, + } + _, err := json.Marshal(&resp) + Expect(err).To(HaveOccurred()) + }) + }) +}) diff --git a/types/types_suite_test.go b/types/types_suite_test.go new file mode 100644 index 0000000..3356638 --- /dev/null +++ b/types/types_suite_test.go @@ -0,0 +1,13 @@ +package types_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +func TestTypes(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Types Suite") +} From 3a5a22710aa33bb769235c8c5b71c1ea7ab85aad Mon Sep 17 00:00:00 2001 From: mcamou Date: Thu, 21 Aug 2025 13:30:38 +0200 Subject: [PATCH 2/9] lint --- args/reddit.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/args/reddit.go b/args/reddit.go index fad60f6..0d60bb6 100644 --- a/args/reddit.go +++ b/args/reddit.go @@ -114,7 +114,7 @@ func (r *RedditArguments) Validate() error { errs = append(errs, fmt.Errorf("%s is not a valid URL", q.URL)) } else { if !strings.HasSuffix(u.Host, redditDomainSuffix) { - errs = append(errs, fmt.Errorf("Invalid Reddit URL %s", q.URL)) + errs = append(errs, fmt.Errorf("invalid Reddit URL %s", q.URL)) } } } From 49f78bf111b2e2d672527a0e96465df04978a297 Mon Sep 17 00:00:00 2001 From: mcamou Date: Thu, 21 Aug 2025 13:33:54 +0200 Subject: [PATCH 3/9] test fix --- args/reddit_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/args/reddit_test.go b/args/reddit_test.go index f158de3..525b032 100644 --- a/args/reddit_test.go +++ b/args/reddit_test.go @@ -159,7 +159,7 @@ var _ = Describe("RedditArguments", func() { } err := redditArgs.Validate() Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("Invalid Reddit URL")) + Expect(err.Error()).To(ContainSubstring("invalid Reddit URL")) }) It("should fail with an invalid HTTP method", func() { From edd7c8e0937d6a99b6ecfda66c938e7feb29f4df Mon Sep 17 00:00:00 2001 From: mcamou Date: Thu, 21 Aug 2025 13:59:56 +0200 Subject: [PATCH 4/9] Fix Reddit args validation --- args/reddit.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/args/reddit.go b/args/reddit.go index 0d60bb6..f7842e3 100644 --- a/args/reddit.go +++ b/args/reddit.go @@ -97,6 +97,10 @@ func (r *RedditArguments) Validate() error { errs = append(errs, ErrRedditTimeInTheFuture) } + if len(errs) > 0 { + return errors.Join(errs...) 
+ } + if r.QueryType == teetypes.RedditScrapeUrls { if len(r.URLs) == 0 { errs = append(errs, ErrRedditNoUrls) From 570c8fb39313b1ee326df9d553b4df8546bd62a8 Mon Sep 17 00:00:00 2001 From: mcamou Date: Thu, 21 Aug 2025 18:52:53 +0200 Subject: [PATCH 5/9] Add default URL method --- args/reddit.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/args/reddit.go b/args/reddit.go index f7842e3..6dafa8c 100644 --- a/args/reddit.go +++ b/args/reddit.go @@ -110,6 +110,10 @@ func (r *RedditArguments) Validate() error { } for _, q := range r.URLs { + q.Method = strings.ToUpper(q.Method) + if q.Method == "" { + q.Method = "GET" + } if !allowedHttpMethods.Contains(q.Method) { errs = append(errs, fmt.Errorf("%s is not a valid HTTP method", q.Method)) } From 06dd40db19ea7a7bffa17236524d8601a63c276d Mon Sep 17 00:00:00 2001 From: mcamou Date: Thu, 21 Aug 2025 18:55:09 +0200 Subject: [PATCH 6/9] Lowercase query type and sort --- args/reddit.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/args/reddit.go b/args/reddit.go index 6dafa8c..1c7fe53 100644 --- a/args/reddit.go +++ b/args/reddit.go @@ -85,10 +85,12 @@ const redditDomainSuffix = "reddit.com" func (r *RedditArguments) Validate() error { var errs []error + r.QueryType = teetypes.RedditQueryType(strings.ToLower(string(r.QueryType))) if !teetypes.AllRedditQueryTypes.Contains(r.QueryType) { errs = append(errs, ErrRedditInvalidType) } + r.Sort = teetypes.RedditSortType(strings.ToLower(string(r.Sort))) if !teetypes.AllRedditSortTypes.Contains(r.Sort) { errs = append(errs, ErrRedditInvalidSort) } From 3cac2f13bc757d2af855596a89aebc1a9ad1af2b Mon Sep 17 00:00:00 2001 From: mcamou Date: Thu, 21 Aug 2025 19:27:34 +0200 Subject: [PATCH 7/9] Some renaming of the Reddit types, and adding RedditResult --- types/reddit.go | 42 ++++++++++++++++++++++++------------------ types/reddit_test.go | 16 ++++++++-------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/types/reddit.go b/types/reddit.go index 
7b2ed5e..ea4a9d3 100644 --- a/types/reddit.go +++ b/types/reddit.go @@ -39,19 +39,25 @@ var AllRedditSortTypes = util.NewSet( RedditSortComments, ) +// RedditResult represents the response sent back from a Reddit query +type RedditResult struct { + Items []*RedditItem `json:"items"` + NextCursor string `json:"next_cursor"` +} + // RedditStartURL represents a single start URL for the Apify Reddit scraper. type RedditStartURL struct { URL string `json:"url"` Method string `json:"method"` } -type RedditResponseType string +type RedditItemType string const ( - RedditUserResponse RedditResponseType = "user" - RedditPostResponse RedditResponseType = "post" - RedditCommentResponse RedditResponseType = "comment" - RedditCommunityResponse RedditResponseType = "community" + RedditUserItem RedditItemType = "user" + RedditPostItem RedditItemType = "post" + RedditCommentItem RedditItemType = "comment" + RedditCommunityItem RedditItemType = "community" ) // RedditUser represents the data structure for a Reddit user from the Apify scraper. 
@@ -124,10 +130,10 @@ type RedditCommunity struct { } type RedditTypeSwitch struct { - Type RedditResponseType `json:"type"` + Type RedditItemType `json:"type"` } -type RedditResponse struct { +type RedditItem struct { TypeSwitch *RedditTypeSwitch User *RedditUser Post *RedditPost @@ -135,29 +141,29 @@ type RedditResponse struct { Community *RedditCommunity } -func (t *RedditResponse) UnmarshalJSON(data []byte) error { +func (t *RedditItem) UnmarshalJSON(data []byte) error { t.TypeSwitch = &RedditTypeSwitch{} if err := json.Unmarshal(data, &t.TypeSwitch); err != nil { return fmt.Errorf("failed to unmarshal reddit response type: %w", err) } switch t.TypeSwitch.Type { - case RedditUserResponse: + case RedditUserItem: t.User = &RedditUser{} if err := json.Unmarshal(data, t.User); err != nil { return fmt.Errorf("failed to unmarshal reddit user: %w", err) } - case RedditPostResponse: + case RedditPostItem: t.Post = &RedditPost{} if err := json.Unmarshal(data, t.Post); err != nil { return fmt.Errorf("failed to unmarshal reddit post: %w", err) } - case RedditCommentResponse: + case RedditCommentItem: t.Comment = &RedditComment{} if err := json.Unmarshal(data, t.Comment); err != nil { return fmt.Errorf("failed to unmarshal reddit comment: %w", err) } - case RedditCommunityResponse: + case RedditCommunityItem: t.Community = &RedditCommunity{} if err := json.Unmarshal(data, t.Community); err != nil { return fmt.Errorf("failed to unmarshal reddit community: %w", err) @@ -168,21 +174,21 @@ func (t *RedditResponse) UnmarshalJSON(data []byte) error { return nil } -// MarshalJSON implements the json.Marshaler interface for RedditResponse. +// MarshalJSON implements the json.Marshaler interface for RedditItem. // It unwraps the inner struct (User, Post, Comment, or Community) and marshals it directly. 
-func (t *RedditResponse) MarshalJSON() ([]byte, error) { +func (t *RedditItem) MarshalJSON() ([]byte, error) { if t.TypeSwitch == nil { return []byte("null"), nil } switch t.TypeSwitch.Type { - case RedditUserResponse: + case RedditUserItem: return json.Marshal(t.User) - case RedditPostResponse: + case RedditPostItem: return json.Marshal(t.Post) - case RedditCommentResponse: + case RedditCommentItem: return json.Marshal(t.Comment) - case RedditCommunityResponse: + case RedditCommunityItem: return json.Marshal(t.Community) default: return nil, fmt.Errorf("unknown Reddit response type: %s", t.TypeSwitch.Type) diff --git a/types/reddit_test.go b/types/reddit_test.go index 05a7173..c6cf7a3 100644 --- a/types/reddit_test.go +++ b/types/reddit_test.go @@ -14,7 +14,7 @@ var _ = Describe("RedditResponse", func() { Describe("Unmarshalling", func() { It("should unmarshal a user response", func() { jsonData := `{"type": "user", "id": "user123", "username": "testuser"}` - var resp types.RedditResponse + var resp types.RedditItem err := json.Unmarshal([]byte(jsonData), &resp) Expect(err).ToNot(HaveOccurred()) Expect(resp.User).ToNot(BeNil()) @@ -25,7 +25,7 @@ var _ = Describe("RedditResponse", func() { It("should unmarshal a post response", func() { jsonData := `{"type": "post", "id": "post123", "title": "Test Post"}` - var resp types.RedditResponse + var resp types.RedditItem err := json.Unmarshal([]byte(jsonData), &resp) Expect(err).ToNot(HaveOccurred()) Expect(resp.Post).ToNot(BeNil()) @@ -36,7 +36,7 @@ var _ = Describe("RedditResponse", func() { It("should return an error for an unknown type", func() { jsonData := `{"type": "unknown", "id": "123"}` - var resp types.RedditResponse + var resp types.RedditItem err := json.Unmarshal([]byte(jsonData), &resp) Expect(err).To(MatchError("unknown Reddit response type: unknown")) }) @@ -45,8 +45,8 @@ var _ = Describe("RedditResponse", func() { Describe("Marshalling", func() { It("should marshal a user response", func() { now := 
time.Now() - resp := types.RedditResponse{ - TypeSwitch: &types.RedditTypeSwitch{Type: types.RedditUserResponse}, + resp := types.RedditItem{ + TypeSwitch: &types.RedditTypeSwitch{Type: types.RedditUserItem}, User: &types.RedditUser{ ID: "user123", Username: "testuser", @@ -65,8 +65,8 @@ var _ = Describe("RedditResponse", func() { }) It("should marshal a post response", func() { - resp := types.RedditResponse{ - TypeSwitch: &types.RedditTypeSwitch{Type: types.RedditPostResponse}, + resp := types.RedditItem{ + TypeSwitch: &types.RedditTypeSwitch{Type: types.RedditPostItem}, Post: &types.RedditPost{ ID: "post123", Title: "Test Post", @@ -83,7 +83,7 @@ var _ = Describe("RedditResponse", func() { }) It("should return an error for an unknown type", func() { - resp := types.RedditResponse{ + resp := types.RedditItem{ TypeSwitch: &types.RedditTypeSwitch{Type: "unknown"}, } _, err := json.Marshal(&resp) From 63523ac82ad020359e88dc384af2002cba4b04fb Mon Sep 17 00:00:00 2001 From: mcamou Date: Thu, 21 Aug 2025 19:30:53 +0200 Subject: [PATCH 8/9] Remove RedditResult (the cursor is in JobResult) --- types/reddit.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/types/reddit.go b/types/reddit.go index ea4a9d3..296557e 100644 --- a/types/reddit.go +++ b/types/reddit.go @@ -39,12 +39,6 @@ var AllRedditSortTypes = util.NewSet( RedditSortComments, ) -// RedditResult represents the response sent back from a Reddit query -type RedditResult struct { - Items []*RedditItem `json:"items"` - NextCursor string `json:"next_cursor"` -} - // RedditStartURL represents a single start URL for the Apify Reddit scraper. 
type RedditStartURL struct { URL string `json:"url"` From e480a88219291c81b9509c24a68ee2334d9612d1 Mon Sep 17 00:00:00 2001 From: mcamou Date: Fri, 22 Aug 2025 11:39:51 +0200 Subject: [PATCH 9/9] PR comments --- args/linkedin.go | 1 + args/reddit.go | 16 ++++++---------- args/tiktok.go | 1 + args/twitter.go | 1 + args/web.go | 1 + 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/args/linkedin.go b/args/linkedin.go index b7df066..dc3ba93 100644 --- a/args/linkedin.go +++ b/args/linkedin.go @@ -21,6 +21,7 @@ type LinkedInArguments struct { // UnmarshalJSON implements custom JSON unmarshaling with validation func (l *LinkedInArguments) UnmarshalJSON(data []byte) error { + // Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...) type Alias LinkedInArguments aux := &struct { *Alias diff --git a/args/reddit.go b/args/reddit.go index 1c7fe53..05bc694 100644 --- a/args/reddit.go +++ b/args/reddit.go @@ -52,8 +52,6 @@ type RedditArguments struct { } func (r *RedditArguments) UnmarshalJSON(data []byte) error { - type Alias RedditArguments - // Set default values. They will be overridden if present in the JSON. r.MaxItems = redditDefaultMaxItems r.MaxPosts = redditDefaultMaxPosts @@ -62,6 +60,8 @@ func (r *RedditArguments) UnmarshalJSON(data []byte) error { r.MaxUsers = redditDefaultMaxUsers r.Sort = redditDefaultSort + // Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...) 
+ type Alias RedditArguments aux := &struct { *Alias }{ @@ -111,10 +111,10 @@ func (r *RedditArguments) Validate() error { errs = append(errs, ErrRedditQueriesNotAllowed) } - for _, q := range r.URLs { - q.Method = strings.ToUpper(q.Method) + for i, q := range r.URLs { + r.URLs[i].Method = strings.ToUpper(q.Method) if q.Method == "" { - q.Method = "GET" + r.URLs[i].Method = "GET" } if !allowedHttpMethods.Contains(q.Method) { errs = append(errs, fmt.Errorf("%s is not a valid HTTP method", q.Method)) @@ -137,11 +137,7 @@ func (r *RedditArguments) Validate() error { } } - if len(errs) > 0 { - return errors.Join(errs...) - } - - return nil + return errors.Join(errs...) } // ValidateForJobType validates Twitter arguments for a specific job type diff --git a/args/tiktok.go b/args/tiktok.go index 6c487e6..5f3687d 100644 --- a/args/tiktok.go +++ b/args/tiktok.go @@ -18,6 +18,7 @@ type TikTokTranscriptionArguments struct { // UnmarshalJSON implements custom JSON unmarshaling with validation func (t *TikTokTranscriptionArguments) UnmarshalJSON(data []byte) error { + // Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...) type Alias TikTokTranscriptionArguments aux := &struct { *Alias diff --git a/args/twitter.go b/args/twitter.go index a32f024..18c6773 100644 --- a/args/twitter.go +++ b/args/twitter.go @@ -21,6 +21,7 @@ type TwitterSearchArguments struct { // UnmarshalJSON implements custom JSON unmarshaling with validation func (t *TwitterSearchArguments) UnmarshalJSON(data []byte) error { + // Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...) 
type Alias TwitterSearchArguments aux := &struct { *Alias diff --git a/args/web.go b/args/web.go index f1f473f..33a466d 100644 --- a/args/web.go +++ b/args/web.go @@ -18,6 +18,7 @@ type WebSearchArguments struct { // UnmarshalJSON implements custom JSON unmarshaling with validation func (w *WebSearchArguments) UnmarshalJSON(data []byte) error { + // Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...) type Alias WebSearchArguments aux := &struct { *Alias