diff --git a/.gitignore b/.gitignore index 9954081..f5f7bd6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,7 @@ go.work *~ *.log .DS_Store + +# LLM-related files .aider* +GEMINI.md diff --git a/args/args_suite_test.go b/args/args_suite_test.go new file mode 100644 index 0000000..861e0bf --- /dev/null +++ b/args/args_suite_test.go @@ -0,0 +1,13 @@ +package args_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestArgs(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Args Suite") +} diff --git a/args/linkedin.go b/args/linkedin.go index b7df066..dc3ba93 100644 --- a/args/linkedin.go +++ b/args/linkedin.go @@ -21,6 +21,7 @@ type LinkedInArguments struct { // UnmarshalJSON implements custom JSON unmarshaling with validation func (l *LinkedInArguments) UnmarshalJSON(data []byte) error { + // Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...) type Alias LinkedInArguments aux := &struct { *Alias diff --git a/args/reddit.go b/args/reddit.go new file mode 100644 index 0000000..05bc694 --- /dev/null +++ b/args/reddit.go @@ -0,0 +1,156 @@ +package args + +import ( + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/masa-finance/tee-types/pkg/util" + teetypes "github.com/masa-finance/tee-types/types" +) + +var ( + ErrRedditInvalidType = errors.New("invalid type") + ErrRedditInvalidSort = errors.New("invalid sort") + ErrRedditTimeInTheFuture = errors.New("after field is in the future") + ErrRedditNoQueries = errors.New("queries must be provided for all query types except scrapeurls") + ErrRedditNoUrls = errors.New("urls must be provided for scrapeurls query type") + ErrRedditQueriesNotAllowed = errors.New("the scrapeurls query type does not admit queries") + ErrRedditUrlsNotAllowed = errors.New("urls can only be provided for the scrapeurls query type") +) + +const ( + // These reflect the default values 
in https://apify.com/trudax/reddit-scraper/input-schema + redditDefaultMaxItems = 10 + redditDefaultMaxPosts = 10 + redditDefaultMaxComments = 10 + redditDefaultMaxCommunities = 2 + redditDefaultMaxUsers = 2 + redditDefaultSort = teetypes.RedditSortNew +) + +// RedditArguments defines args for Reddit scrapes +// see https://apify.com/trudax/reddit-scraper +type RedditArguments struct { + QueryType teetypes.RedditQueryType `json:"type"` + Queries []string `json:"queries"` + URLs []teetypes.RedditStartURL `json:"urls"` + Sort teetypes.RedditSortType `json:"sort"` + IncludeNSFW bool `json:"include_nsfw"` + SkipPosts bool `json:"skip_posts"` // Valid only for searchusers + After time.Time `json:"after"` // valid only for scrapeurls and searchposts + MaxItems uint `json:"max_items"` // Max number of items to scrape (total), default 10 + MaxResults uint `json:"max_results"` // Max number of results per page, default MaxItems + MaxPosts uint `json:"max_posts"` // Max number of posts per page, default 10 + MaxComments uint `json:"max_comments"` // Max number of comments per page, default 10 + MaxCommunities uint `json:"max_communities"` // Max number of communities per page, default 2 + MaxUsers uint `json:"max_users"` // Max number of users per page, default 2 + NextCursor string `json:"next_cursor"` +} + +func (r *RedditArguments) UnmarshalJSON(data []byte) error { + // Set default values. They will be overridden if present in the JSON. + r.MaxItems = redditDefaultMaxItems + r.MaxPosts = redditDefaultMaxPosts + r.MaxComments = redditDefaultMaxComments + r.MaxCommunities = redditDefaultMaxCommunities + r.MaxUsers = redditDefaultMaxUsers + r.Sort = redditDefaultSort + + // Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...) 
+	type Alias RedditArguments
+	aux := &struct {
+		*Alias
+	}{
+		Alias: (*Alias)(r),
+	}
+
+	if err := json.Unmarshal(data, aux); err != nil {
+		return fmt.Errorf("failed to unmarshal Reddit arguments: %w", err)
+	}
+
+	if r.MaxResults == 0 {
+		r.MaxResults = r.MaxItems
+	}
+
+	return r.Validate()
+}
+
+var allowedHttpMethods = util.NewSet("GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS")
+
+const redditDomainSuffix = "reddit.com"
+
+func (r *RedditArguments) Validate() error {
+	var errs []error
+	r.QueryType = teetypes.RedditQueryType(strings.ToLower(string(r.QueryType)))
+	if !teetypes.AllRedditQueryTypes.Contains(r.QueryType) {
+		errs = append(errs, ErrRedditInvalidType)
+	}
+
+	r.Sort = teetypes.RedditSortType(strings.ToLower(string(r.Sort)))
+	if !teetypes.AllRedditSortTypes.Contains(r.Sort) {
+		errs = append(errs, ErrRedditInvalidSort)
+	}
+
+	if time.Now().Before(r.After) {
+		errs = append(errs, ErrRedditTimeInTheFuture)
+	}
+
+	if len(errs) > 0 {
+		return errors.Join(errs...)
+	}
+
+	if r.QueryType == teetypes.RedditScrapeUrls {
+		if len(r.URLs) == 0 {
+			errs = append(errs, ErrRedditNoUrls)
+		}
+		if len(r.Queries) > 0 {
+			errs = append(errs, ErrRedditQueriesNotAllowed)
+		}
+
+		for i, q := range r.URLs { // normalize Method first, then validate the normalized value
+			if q.Method = strings.ToUpper(q.Method); q.Method == "" {
+				q.Method = "GET"
+			}
+			r.URLs[i].Method = q.Method
+			if !allowedHttpMethods.Contains(q.Method) {
+				errs = append(errs, fmt.Errorf("%s is not a valid HTTP method", q.Method))
+			}
+			u, err := url.Parse(q.URL)
+			if err != nil {
+				errs = append(errs, fmt.Errorf("%s is not a valid URL", q.URL))
+			} else {
+				if !strings.HasSuffix(u.Host, redditDomainSuffix) {
+					errs = append(errs, fmt.Errorf("invalid Reddit URL %s", q.URL))
+				}
+			}
+		}
+	} else {
+		if len(r.Queries) == 0 {
+			errs = append(errs, ErrRedditNoQueries)
+		}
+		if len(r.URLs) > 0 {
+			errs = append(errs, ErrRedditUrlsNotAllowed)
+		}
+	}
+
+	return errors.Join(errs...)
+} + +// ValidateForJobType validates Twitter arguments for a specific job type +func (r *RedditArguments) ValidateForJobType(jobType teetypes.JobType) error { + if err := r.Validate(); err != nil { + return err + } + + // Validate QueryType against job-specific capabilities + return jobType.ValidateCapability(teetypes.Capability(r.QueryType)) +} + +// GetCapability returns the QueryType as a typed Capability +func (r *RedditArguments) GetCapability() teetypes.Capability { + return teetypes.Capability(r.QueryType) +} diff --git a/args/reddit_test.go b/args/reddit_test.go new file mode 100644 index 0000000..525b032 --- /dev/null +++ b/args/reddit_test.go @@ -0,0 +1,178 @@ +package args_test + +import ( + "encoding/json" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/masa-finance/tee-types/args" + "github.com/masa-finance/tee-types/types" +) + +var _ = Describe("RedditArguments", func() { + Describe("Unmarshalling", func() { + It("should set default values", func() { + redditArgs := &args.RedditArguments{} + jsonData := `{"type": "searchposts", "queries": ["test"]}` + err := json.Unmarshal([]byte(jsonData), redditArgs) + Expect(err).ToNot(HaveOccurred()) + Expect(redditArgs.MaxItems).To(Equal(uint(10))) + Expect(redditArgs.MaxPosts).To(Equal(uint(10))) + Expect(redditArgs.MaxComments).To(Equal(uint(10))) + Expect(redditArgs.MaxCommunities).To(Equal(uint(2))) + Expect(redditArgs.MaxUsers).To(Equal(uint(2))) + Expect(redditArgs.Sort).To(Equal(types.RedditSortNew)) + Expect(redditArgs.MaxResults).To(Equal(redditArgs.MaxItems)) + }) + + It("should override default values", func() { + redditArgs := &args.RedditArguments{} + jsonData := `{"type": "searchposts", "queries": ["test"], "max_items": 20, "sort": "top"}` + err := json.Unmarshal([]byte(jsonData), redditArgs) + Expect(err).ToNot(HaveOccurred()) + Expect(redditArgs.MaxItems).To(Equal(uint(20))) + Expect(redditArgs.Sort).To(Equal(types.RedditSortTop)) + 
Expect(redditArgs.MaxResults).To(Equal(uint(20))) + }) + }) + + Describe("Validation", func() { + It("should succeed with valid arguments", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditSearchPosts, + Queries: []string{"test"}, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should succeed with valid scrapeurls arguments", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + URLs: []types.RedditStartURL{ + {URL: "https://www.reddit.com/r/golang/", Method: "GET"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail with an invalid type", func() { + redditArgs := &args.RedditArguments{ + QueryType: "invalidtype", + Queries: []string{"test"}, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditInvalidType)) + }) + + It("should fail with an invalid sort", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditSearchPosts, + Queries: []string{"test"}, + Sort: "invalidsort", + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditInvalidSort)) + }) + + It("should fail if the after time is in the future", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditSearchPosts, + Queries: []string{"test"}, + Sort: types.RedditSortNew, + After: time.Now().Add(24 * time.Hour), + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditTimeInTheFuture)) + }) + + It("should fail if queries are not provided for searchposts", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditSearchPosts, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditNoQueries)) + }) + + It("should fail if urls are not provided for scrapeurls", func() { + redditArgs := &args.RedditArguments{ + QueryType: 
types.RedditScrapeUrls, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditNoUrls)) + }) + + It("should fail if queries are provided for scrapeurls", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + Queries: []string{"test"}, + URLs: []types.RedditStartURL{ + {URL: "https://www.reddit.com/r/golang/", Method: "GET"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditQueriesNotAllowed)) + }) + + It("should fail if urls are provided for searchposts", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditSearchPosts, + Queries: []string{"test"}, + URLs: []types.RedditStartURL{ + {URL: "https://www.reddit.com/r/golang/", Method: "GET"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(MatchError(args.ErrRedditUrlsNotAllowed)) + }) + + It("should fail with an invalid URL", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + URLs: []types.RedditStartURL{ + {URL: "ht tp://invalid-url.com", Method: "GET"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("is not a valid URL")) + }) + + It("should fail with an invalid domain", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + URLs: []types.RedditStartURL{ + {URL: "https://www.google.com", Method: "GET"}, + }, + Sort: types.RedditSortNew, + } + err := redditArgs.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("invalid Reddit URL")) + }) + + It("should fail with an invalid HTTP method", func() { + redditArgs := &args.RedditArguments{ + QueryType: types.RedditScrapeUrls, + URLs: []types.RedditStartURL{ + {URL: "https://www.reddit.com/r/golang/", Method: "INVALID"}, + }, + Sort: types.RedditSortNew, + } + err := 
redditArgs.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("is not a valid HTTP method")) + }) + }) +}) diff --git a/args/tiktok.go b/args/tiktok.go index 6c487e6..5f3687d 100644 --- a/args/tiktok.go +++ b/args/tiktok.go @@ -18,6 +18,7 @@ type TikTokTranscriptionArguments struct { // UnmarshalJSON implements custom JSON unmarshaling with validation func (t *TikTokTranscriptionArguments) UnmarshalJSON(data []byte) error { + // Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...) type Alias TikTokTranscriptionArguments aux := &struct { *Alias diff --git a/args/twitter.go b/args/twitter.go index a32f024..18c6773 100644 --- a/args/twitter.go +++ b/args/twitter.go @@ -21,6 +21,7 @@ type TwitterSearchArguments struct { // UnmarshalJSON implements custom JSON unmarshaling with validation func (t *TwitterSearchArguments) UnmarshalJSON(data []byte) error { + // Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...) 
type Alias TwitterSearchArguments aux := &struct { *Alias diff --git a/args/unmarshaller.go b/args/unmarshaller.go index 353bf8f..15c8f4d 100644 --- a/args/unmarshaller.go +++ b/args/unmarshaller.go @@ -49,6 +49,12 @@ type LinkedInJobArguments interface { ValidateForJobType(jobType types.JobType) error } +// RedditJobArguments extends JobArguments for Reddit-specific methods +type RedditJobArguments interface { + JobArguments + ValidateForJobType(jobType types.JobType) error +} + // UnmarshalJobArguments unmarshals job arguments from a generic map into the appropriate typed struct // This works with both tee-indexer and tee-worker JobArguments types func UnmarshalJobArguments(jobType types.JobType, args map[string]any) (JobArguments, error) { @@ -65,6 +71,9 @@ func UnmarshalJobArguments(jobType types.JobType, args map[string]any) (JobArgum case types.LinkedInJob: return unmarshalLinkedInArguments(jobType, args) + case types.RedditJob: + return unmarshalRedditArguments(jobType, args) + case types.TelemetryJob: return &TelemetryJobArguments{}, nil @@ -132,6 +141,27 @@ func unmarshalLinkedInArguments(jobType types.JobType, args map[string]any) (*Li return linkedInArgs, nil } +func unmarshalRedditArguments(jobType types.JobType, args map[string]any) (*RedditArguments, error) { + redditArgs := &RedditArguments{} + if err := unmarshalToStruct(args, redditArgs); err != nil { + return nil, fmt.Errorf("failed to unmarshal Reddit job arguments: %w", err) + } + + // If no QueryType is specified, use the default capability for this job type + if redditArgs.QueryType == "" { + if defaultCap, exists := types.JobDefaultCapabilityMap[jobType]; exists { + redditArgs.QueryType = types.RedditQueryType(defaultCap) + } + } + + // Perform job-type-specific validation for Reddit + if err := redditArgs.ValidateForJobType(jobType); err != nil { + return nil, fmt.Errorf("reddit job validation failed: %w", err) + } + + return redditArgs, nil +} + // unmarshalToStruct converts a 
map[string]any to a struct using JSON marshal/unmarshal // This provides the same functionality as the existing JobArguments.Unmarshal methods func unmarshalToStruct(args map[string]any, target any) error { diff --git a/args/unmarshaller_test.go b/args/unmarshaller_test.go new file mode 100644 index 0000000..04e784f --- /dev/null +++ b/args/unmarshaller_test.go @@ -0,0 +1,105 @@ +package args_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/masa-finance/tee-types/args" + "github.com/masa-finance/tee-types/types" +) + +var _ = Describe("Unmarshaller", func() { + Describe("UnmarshalJobArguments", func() { + Context("with a WebJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "url": "https://example.com", + "selector": "h1", + "max_depth": 2, + } + jobArgs, err := args.UnmarshalJobArguments(types.WebJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + webArgs, ok := jobArgs.(*args.WebSearchArguments) + Expect(ok).To(BeTrue()) + Expect(webArgs.URL).To(Equal("https://example.com")) + Expect(webArgs.Selector).To(Equal("h1")) + Expect(webArgs.MaxDepth).To(Equal(2)) + }) + }) + + Context("with a TiktokJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "video_url": "https://www.tiktok.com/@user/video/123", + "language": "en-us", + } + jobArgs, err := args.UnmarshalJobArguments(types.TiktokJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + tiktokArgs, ok := jobArgs.(*args.TikTokTranscriptionArguments) + Expect(ok).To(BeTrue()) + Expect(tiktokArgs.VideoURL).To(Equal("https://www.tiktok.com/@user/video/123")) + Expect(tiktokArgs.Language).To(Equal("en-us")) + }) + }) + + Context("with a TwitterJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "type": "searchbyquery", + "query": "golang", + "count": 10, + } + jobArgs, err := args.UnmarshalJobArguments(types.TwitterJob, 
argsMap) + Expect(err).ToNot(HaveOccurred()) + twitterArgs, ok := jobArgs.(*args.TwitterSearchArguments) + Expect(ok).To(BeTrue()) + Expect(twitterArgs.QueryType).To(Equal("searchbyquery")) + Expect(twitterArgs.Query).To(Equal("golang")) + Expect(twitterArgs.Count).To(Equal(10)) + }) + + It("should set the default capability for TwitterApifyJob", func() { + argsMap := map[string]any{"query": "masa-finance"} + jobArgs, err := args.UnmarshalJobArguments(types.TwitterApifyJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + twitterArgs, ok := jobArgs.(*args.TwitterSearchArguments) + Expect(ok).To(BeTrue()) + Expect(twitterArgs.GetCapability()).To(Equal(types.CapGetFollowers)) + }) + }) + + Context("with a RedditJob", func() { + It("should unmarshal the arguments correctly", func() { + argsMap := map[string]any{ + "type": "searchposts", + "queries": []string{"golang"}, + "sort": "new", + } + jobArgs, err := args.UnmarshalJobArguments(types.RedditJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + redditArgs, ok := jobArgs.(*args.RedditArguments) + Expect(ok).To(BeTrue()) + Expect(redditArgs.QueryType).To(Equal(types.RedditQueryType("searchposts"))) + }) + }) + + Context("with a TelemetryJob", func() { + It("should return a TelemetryJobArguments struct", func() { + argsMap := map[string]any{} + jobArgs, err := args.UnmarshalJobArguments(types.TelemetryJob, argsMap) + Expect(err).ToNot(HaveOccurred()) + _, ok := jobArgs.(*args.TelemetryJobArguments) + Expect(ok).To(BeTrue()) + }) + }) + + Context("with an unknown job type", func() { + It("should return an error", func() { + argsMap := map[string]any{} + _, err := args.UnmarshalJobArguments("unknown", argsMap) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("unknown job type")) + }) + }) + }) +}) diff --git a/args/web.go b/args/web.go index f1f473f..33a466d 100644 --- a/args/web.go +++ b/args/web.go @@ -18,6 +18,7 @@ type WebSearchArguments struct { // UnmarshalJSON implements custom JSON 
unmarshaling with validation func (w *WebSearchArguments) UnmarshalJSON(data []byte) error { + // Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...) type Alias WebSearchArguments aux := &struct { *Alias diff --git a/pkg/util/set.go b/pkg/util/set.go index fe237be..33b907d 100644 --- a/pkg/util/set.go +++ b/pkg/util/set.go @@ -23,17 +23,19 @@ func (s *Set[T]) Contains(item T) bool { } // Add inserts the given items into the set, deduplicating them. -func (s *Set[T]) Add(items ...T) { +func (s *Set[T]) Add(items ...T) *Set[T] { for _, item := range items { (*s)[item] = struct{}{} } + return s } // Delete removes the given items from the set if it contains them. -func (s *Set[T]) Delete(items ...T) { +func (s *Set[T]) Delete(items ...T) *Set[T] { for _, item := range items { delete((*s), item) } + return s } // Length returns the number of items in the set. diff --git a/types/jobs.go b/types/jobs.go index 945eb63..9c1cce5 100644 --- a/types/jobs.go +++ b/types/jobs.go @@ -50,6 +50,7 @@ const ( TwitterApiJob JobType = "twitter-api" // Twitter scraping with API keys TwitterApifyJob JobType = "twitter-apify" // Twitter scraping with Apify LinkedInJob JobType = "linkedin" // LinkedIn scraping, keeping for unmarshalling logic + RedditJob JobType = "reddit" // Reddit scraping with Apify ) // Capability constants - typed to prevent typos and enable discoverability @@ -73,7 +74,13 @@ const ( CapGetFollowers Capability = "getfollowers" CapGetSpace Capability = "getspace" CapGetProfile Capability = "getprofile" // LinkedIn get profile capability - CapEmpty Capability = "" + // Reddit capabilities + CapScrapeUrls Capability = "scrapeurls" + CapSearchPosts Capability = "searchposts" + CapSearchUsers Capability = "searchusers" + CapSearchCommunities Capability = "searchcommunities" + + CapEmpty Capability = "" ) // Capability group constants for easy reuse @@ -104,6 +111,9 @@ var ( // TwitterApifyCaps are 
Twitter capabilities available with Apify TwitterApifyCaps = []Capability{CapGetFollowers, CapGetFollowing, CapEmpty} + + // RedditCaps are all the Reddit capabilities (only available with Apify) + RedditCaps = []Capability{CapScrapeUrls, CapSearchPosts, CapSearchUsers, CapSearchCommunities} ) // JobCapabilityMap defines which capabilities are valid for each job type @@ -128,6 +138,9 @@ var JobCapabilityMap = map[JobType][]Capability{ // TikTok job capabilities TiktokJob: AlwaysAvailableTiktokCaps, + // Reddit job capabilities + RedditJob: RedditCaps, + // Telemetry job capabilities TelemetryJob: AlwaysAvailableTelemetryCaps, } @@ -140,5 +153,6 @@ var JobDefaultCapabilityMap = map[JobType]Capability{ TwitterApifyJob: CapGetFollowers, WebJob: CapScraper, TiktokJob: CapTranscription, + RedditJob: CapScrapeUrls, TelemetryJob: CapTelemetry, } diff --git a/types/reddit.go b/types/reddit.go new file mode 100644 index 0000000..296557e --- /dev/null +++ b/types/reddit.go @@ -0,0 +1,190 @@ +package types + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/masa-finance/tee-types/pkg/util" +) + +type RedditQueryType string + +const ( + RedditScrapeUrls RedditQueryType = "scrapeurls" + RedditSearchPosts RedditQueryType = "searchposts" + RedditSearchUsers RedditQueryType = "searchusers" + RedditSearchCommunities RedditQueryType = "searchcommunities" +) + +var AllRedditQueryTypes = util.NewSet(RedditScrapeUrls, RedditSearchPosts, RedditSearchUsers, RedditSearchCommunities) + +type RedditSortType string + +const ( + RedditSortRelevance RedditSortType = "relevance" + RedditSortHot RedditSortType = "hot" + RedditSortTop RedditSortType = "top" + RedditSortNew RedditSortType = "new" + RedditSortRising RedditSortType = "rising" + RedditSortComments RedditSortType = "comments" +) + +var AllRedditSortTypes = util.NewSet( + RedditSortRelevance, + RedditSortHot, + RedditSortTop, + RedditSortNew, + RedditSortRising, + RedditSortComments, +) + +// RedditStartURL represents a 
single start URL for the Apify Reddit scraper. +type RedditStartURL struct { + URL string `json:"url"` + Method string `json:"method"` +} + +type RedditItemType string + +const ( + RedditUserItem RedditItemType = "user" + RedditPostItem RedditItemType = "post" + RedditCommentItem RedditItemType = "comment" + RedditCommunityItem RedditItemType = "community" +) + +// RedditUser represents the data structure for a Reddit user from the Apify scraper. +type RedditUser struct { + ID string `json:"id"` + URL string `json:"url"` + Username string `json:"username"` + UserIcon string `json:"userIcon"` + PostKarma int `json:"postKarma"` + CommentKarma int `json:"commentKarma"` + Description string `json:"description"` + Over18 bool `json:"over18"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + DataType string `json:"dataType"` +} + +// RedditPost represents the data structure for a Reddit post from the Apify scraper. +type RedditPost struct { + ID string `json:"id"` + ParsedID string `json:"parsedId"` + URL string `json:"url"` + Username string `json:"username"` + Title string `json:"title"` + CommunityName string `json:"communityName"` + ParsedCommunityName string `json:"parsedCommunityName"` + Body string `json:"body"` + HTML *string `json:"html"` + NumberOfComments int `json:"numberOfComments"` + UpVotes int `json:"upVotes"` + IsVideo bool `json:"isVideo"` + IsAd bool `json:"isAd"` + Over18 bool `json:"over18"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + DataType string `json:"dataType"` +} + +// RedditComment represents the data structure for a Reddit comment from the Apify scraper. 
+type RedditComment struct { + ID string `json:"id"` + ParsedID string `json:"parsedId"` + URL string `json:"url"` + ParentID string `json:"parentId"` + Username string `json:"username"` + Category string `json:"category"` + CommunityName string `json:"communityName"` + Body string `json:"body"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + UpVotes int `json:"upVotes"` + NumberOfReplies int `json:"numberOfreplies"` + HTML string `json:"html"` + DataType string `json:"dataType"` +} + +// RedditCommunity represents the data structure for a Reddit community from the Apify scraper. +type RedditCommunity struct { + ID string `json:"id"` + Name string `json:"name"` + Title string `json:"title"` + HeaderImage string `json:"headerImage"` + Description string `json:"description"` + Over18 bool `json:"over18"` + CreatedAt time.Time `json:"createdAt"` + ScrapedAt time.Time `json:"scrapedAt"` + NumberOfMembers int `json:"numberOfMembers"` + URL string `json:"url"` + DataType string `json:"dataType"` +} + +type RedditTypeSwitch struct { + Type RedditItemType `json:"type"` +} + +type RedditItem struct { + TypeSwitch *RedditTypeSwitch + User *RedditUser + Post *RedditPost + Comment *RedditComment + Community *RedditCommunity +} + +func (t *RedditItem) UnmarshalJSON(data []byte) error { + t.TypeSwitch = &RedditTypeSwitch{} + if err := json.Unmarshal(data, &t.TypeSwitch); err != nil { + return fmt.Errorf("failed to unmarshal reddit response type: %w", err) + } + + switch t.TypeSwitch.Type { + case RedditUserItem: + t.User = &RedditUser{} + if err := json.Unmarshal(data, t.User); err != nil { + return fmt.Errorf("failed to unmarshal reddit user: %w", err) + } + case RedditPostItem: + t.Post = &RedditPost{} + if err := json.Unmarshal(data, t.Post); err != nil { + return fmt.Errorf("failed to unmarshal reddit post: %w", err) + } + case RedditCommentItem: + t.Comment = &RedditComment{} + if err := json.Unmarshal(data, t.Comment); err != nil { + 
return fmt.Errorf("failed to unmarshal reddit comment: %w", err) + } + case RedditCommunityItem: + t.Community = &RedditCommunity{} + if err := json.Unmarshal(data, t.Community); err != nil { + return fmt.Errorf("failed to unmarshal reddit community: %w", err) + } + default: + return fmt.Errorf("unknown Reddit response type: %s", t.TypeSwitch.Type) + } + return nil +} + +// MarshalJSON implements the json.Marshaller interface for RedditResponse. +// It unwraps the inner struct (User, Post, Comment, or Community) and marshals it directly. +func (t *RedditItem) MarshalJSON() ([]byte, error) { + if t.TypeSwitch == nil { + return []byte("null"), nil + } + + switch t.TypeSwitch.Type { + case RedditUserItem: + return json.Marshal(t.User) + case RedditPostItem: + return json.Marshal(t.Post) + case RedditCommentItem: + return json.Marshal(t.Comment) + case RedditCommunityItem: + return json.Marshal(t.Community) + default: + return nil, fmt.Errorf("unknown Reddit response type: %s", t.TypeSwitch.Type) + } +} diff --git a/types/reddit_test.go b/types/reddit_test.go new file mode 100644 index 0000000..c6cf7a3 --- /dev/null +++ b/types/reddit_test.go @@ -0,0 +1,93 @@ +package types_test + +import ( + "encoding/json" + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/masa-finance/tee-types/types" +) + +var _ = Describe("RedditResponse", func() { + Describe("Unmarshalling", func() { + It("should unmarshal a user response", func() { + jsonData := `{"type": "user", "id": "user123", "username": "testuser"}` + var resp types.RedditItem + err := json.Unmarshal([]byte(jsonData), &resp) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.User).ToNot(BeNil()) + Expect(resp.Post).To(BeNil()) + Expect(resp.User.ID).To(Equal("user123")) + Expect(resp.User.Username).To(Equal("testuser")) + }) + + It("should unmarshal a post response", func() { + jsonData := `{"type": "post", "id": "post123", "title": "Test Post"}` + var resp types.RedditItem + err := json.Unmarshal([]byte(jsonData), &resp) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Post).ToNot(BeNil()) + Expect(resp.User).To(BeNil()) + Expect(resp.Post.ID).To(Equal("post123")) + Expect(resp.Post.Title).To(Equal("Test Post")) + }) + + It("should return an error for an unknown type", func() { + jsonData := `{"type": "unknown", "id": "123"}` + var resp types.RedditItem + err := json.Unmarshal([]byte(jsonData), &resp) + Expect(err).To(MatchError("unknown Reddit response type: unknown")) + }) + }) + + Describe("Marshalling", func() { + It("should marshal a user response", func() { + now := time.Now() + resp := types.RedditItem{ + TypeSwitch: &types.RedditTypeSwitch{Type: types.RedditUserItem}, + User: &types.RedditUser{ + ID: "user123", + Username: "testuser", + CreatedAt: now, + CommentKarma: 10, + }, + } + + expectedJSON, err := json.Marshal(resp.User) + Expect(err).ToNot(HaveOccurred()) + + actualJSON, err := json.Marshal(&resp) + Expect(err).ToNot(HaveOccurred()) + + Expect(actualJSON).To(MatchJSON(expectedJSON)) + }) + + It("should marshal a post response", func() { + resp := types.RedditItem{ + TypeSwitch: &types.RedditTypeSwitch{Type: types.RedditPostItem}, + Post: &types.RedditPost{ + ID: "post123", + Title: "Test Post", + }, + } + + 
expectedJSON, err := json.Marshal(resp.Post) + Expect(err).ToNot(HaveOccurred()) + + actualJSON, err := json.Marshal(&resp) + Expect(err).ToNot(HaveOccurred()) + + Expect(actualJSON).To(MatchJSON(expectedJSON)) + }) + + It("should return an error for an unknown type", func() { + resp := types.RedditItem{ + TypeSwitch: &types.RedditTypeSwitch{Type: "unknown"}, + } + _, err := json.Marshal(&resp) + Expect(err).To(HaveOccurred()) + }) + }) +}) diff --git a/types/types_suite_test.go b/types/types_suite_test.go new file mode 100644 index 0000000..3356638 --- /dev/null +++ b/types/types_suite_test.go @@ -0,0 +1,13 @@ +package types_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestTypes(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Types Suite") +}