Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ go.work
*~
*.log
.DS_Store

# LLM-related files
.aider*
GEMINI.md
13 changes: 13 additions & 0 deletions args/args_suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package args_test

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

// TestArgs is the Ginkgo suite entry point for the args package: it wires
// Gomega assertion failures into the testing.T and runs every spec
// registered in this package (e.g. the RedditArguments specs).
func TestArgs(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Args Suite")
}
1 change: 1 addition & 0 deletions args/linkedin.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type LinkedInArguments struct {

// UnmarshalJSON implements custom JSON unmarshaling with validation
func (l *LinkedInArguments) UnmarshalJSON(data []byte) error {
// Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...)
type Alias LinkedInArguments
aux := &struct {
*Alias
Expand Down
156 changes: 156 additions & 0 deletions args/reddit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package args

import (
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
"time"

"github.com/masa-finance/tee-types/pkg/util"
teetypes "github.com/masa-finance/tee-types/types"
)

// Sentinel validation errors returned by RedditArguments.Validate.
// Multiple violations are combined with errors.Join, so callers should
// match individual errors with errors.Is rather than direct comparison.
var (
	ErrRedditInvalidType       = errors.New("invalid type")
	ErrRedditInvalidSort       = errors.New("invalid sort")
	ErrRedditTimeInTheFuture   = errors.New("after field is in the future")
	ErrRedditNoQueries         = errors.New("queries must be provided for all query types except scrapeurls")
	ErrRedditNoUrls            = errors.New("urls must be provided for scrapeurls query type")
	ErrRedditQueriesNotAllowed = errors.New("the scrapeurls query type does not admit queries")
	ErrRedditUrlsNotAllowed    = errors.New("urls can only be provided for the scrapeurls query type")
)

const (
	// These reflect the default values in https://apify.com/trudax/reddit-scraper/input-schema
	// and are applied by RedditArguments.UnmarshalJSON before decoding,
	// so absent JSON fields fall back to them.
	redditDefaultMaxItems       = 10
	redditDefaultMaxPosts       = 10
	redditDefaultMaxComments    = 10
	redditDefaultMaxCommunities = 2
	redditDefaultMaxUsers       = 2
	redditDefaultSort           = teetypes.RedditSortNew
)

// RedditArguments defines args for Reddit scrapes
// see https://apify.com/trudax/reddit-scraper
//
// Zero values for the Max* fields and Sort are replaced with the Apify
// defaults by UnmarshalJSON, which also calls Validate; construct values
// via JSON decoding (or call Validate explicitly) to get a normalized,
// checked value.
type RedditArguments struct {
	QueryType teetypes.RedditQueryType `json:"type"`
	// Queries is required for every query type except scrapeurls,
	// where it must be empty (see Validate).
	Queries []string `json:"queries"`
	// URLs is required for, and only allowed with, the scrapeurls
	// query type (see Validate).
	URLs        []teetypes.RedditStartURL `json:"urls"`
	Sort        teetypes.RedditSortType   `json:"sort"`
	IncludeNSFW bool                      `json:"include_nsfw"`
	SkipPosts   bool                      `json:"skip_posts"` // Valid only for searchusers
	After       time.Time                 `json:"after"`      // valid only for scrapeurls and searchposts
	MaxItems    uint                      `json:"max_items"`  // Max number of items to scrape (total), default 10
	MaxResults  uint                      `json:"max_results"` // Max number of results per page, default MaxItems
	MaxPosts    uint                      `json:"max_posts"`  // Max number of posts per page, default 10
	MaxComments uint                      `json:"max_comments"` // Max number of comments per page, default 10
	MaxCommunities uint                   `json:"max_communities"` // Max number of communities per page, default 2
	MaxUsers    uint                      `json:"max_users"`   // Max number of users per page, default 2
	NextCursor  string                    `json:"next_cursor"` // Opaque pagination cursor; presumably provider-defined — TODO confirm
}

func (r *RedditArguments) UnmarshalJSON(data []byte) error {
// Set default values. They will be overridden if present in the JSON.
r.MaxItems = redditDefaultMaxItems
r.MaxPosts = redditDefaultMaxPosts
r.MaxComments = redditDefaultMaxComments
r.MaxCommunities = redditDefaultMaxCommunities
r.MaxUsers = redditDefaultMaxUsers
r.Sort = redditDefaultSort

// Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...)
type Alias RedditArguments
aux := &struct {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this slightly confuses me, why not unmarshalling on r directly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to #7 (comment), this pattern is a go idiom to prevent infinite recursion w/ custom UnmarshalJSON methods... go will call your custom unmarshal method when using json.Unmarshal, which then uses json.Unmarshal...

I'll add an appropriate comment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. OK cool 👍

*Alias
}{
Alias: (*Alias)(r),
}

if err := json.Unmarshal(data, aux); err != nil {
return fmt.Errorf("failed to unmarshal Reddit arguments: %w", err)
}

if r.MaxResults == 0 {
r.MaxResults = r.MaxItems
}

return r.Validate()
}

// allowedHttpMethods is the set of HTTP methods accepted for a
// RedditStartURL.Method after upper-casing in Validate.
var allowedHttpMethods = util.NewSet("GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS")

// redditDomainSuffix is the domain every start URL host must belong to.
const redditDomainSuffix = "reddit.com"

func (r *RedditArguments) Validate() error {
var errs []error
r.QueryType = teetypes.RedditQueryType(strings.ToLower(string(r.QueryType)))
if !teetypes.AllRedditQueryTypes.Contains(r.QueryType) {
errs = append(errs, ErrRedditInvalidType)
}

r.Sort = teetypes.RedditSortType(strings.ToLower(string(r.Sort)))
if !teetypes.AllRedditSortTypes.Contains(r.Sort) {
errs = append(errs, ErrRedditInvalidSort)
}

if time.Now().Before(r.After) {
errs = append(errs, ErrRedditTimeInTheFuture)
}

if len(errs) > 0 {
return errors.Join(errs...)
}

if r.QueryType == teetypes.RedditScrapeUrls {
if len(r.URLs) == 0 {
errs = append(errs, ErrRedditNoUrls)
}
if len(r.Queries) > 0 {
errs = append(errs, ErrRedditQueriesNotAllowed)
}

for i, q := range r.URLs {
r.URLs[i].Method = strings.ToUpper(q.Method)
if q.Method == "" {
r.URLs[i].Method = "GET"
}
if !allowedHttpMethods.Contains(q.Method) {
errs = append(errs, fmt.Errorf("%s is not a valid HTTP method", q.Method))
}
u, err := url.Parse(q.URL)
if err != nil {
errs = append(errs, fmt.Errorf("%s is not a valid URL", q.URL))
} else {
if !strings.HasSuffix(u.Host, redditDomainSuffix) {
errs = append(errs, fmt.Errorf("invalid Reddit URL %s", q.URL))
Copy link

Copilot AI Aug 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method assignment to q.Method only modifies the local copy since Go passes structs by value. This means the validation passes but the original struct is not updated. Use a pointer to modify the original: r.URLs[i].Method = strings.ToUpper(r.URLs[i].Method)

Suggested change
errs = append(errs, fmt.Errorf("invalid Reddit URL %s", q.URL))
for i := range r.URLs {
r.URLs[i].Method = strings.ToUpper(r.URLs[i].Method)
if r.URLs[i].Method == "" {
r.URLs[i].Method = "GET"
}
if !allowedHttpMethods.Contains(r.URLs[i].Method) {
errs = append(errs, fmt.Errorf("%s is not a valid HTTP method", r.URLs[i].Method))
}
u, err := url.Parse(r.URLs[i].URL)
if err != nil {
errs = append(errs, fmt.Errorf("%s is not a valid URL", r.URLs[i].URL))
} else {
if !strings.HasSuffix(u.Host, redditDomainSuffix) {
errs = append(errs, fmt.Errorf("invalid Reddit URL %s", r.URLs[i].URL))

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Aug 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the previous issue, this assignment only affects the local copy. Use r.URLs[i].Method = "GET" to modify the original struct.

Suggested change
errs = append(errs, fmt.Errorf("invalid Reddit URL %s", q.URL))
for i := range r.URLs {
r.URLs[i].Method = strings.ToUpper(r.URLs[i].Method)
if r.URLs[i].Method == "" {
r.URLs[i].Method = "GET"
}
if !allowedHttpMethods.Contains(r.URLs[i].Method) {
errs = append(errs, fmt.Errorf("%s is not a valid HTTP method", r.URLs[i].Method))
}
u, err := url.Parse(r.URLs[i].URL)
if err != nil {
errs = append(errs, fmt.Errorf("%s is not a valid URL", r.URLs[i].URL))
} else {
if !strings.HasSuffix(u.Host, redditDomainSuffix) {
errs = append(errs, fmt.Errorf("invalid Reddit URL %s", r.URLs[i].URL))

Copilot uses AI. Check for mistakes.
}
}
}
} else {
if len(r.Queries) == 0 {
errs = append(errs, ErrRedditNoQueries)
}
if len(r.URLs) > 0 {
errs = append(errs, ErrRedditUrlsNotAllowed)
}
}

return errors.Join(errs...)
}

// ValidateForJobType validates Reddit arguments for a specific job type.
// It first runs the general Validate checks, then confirms that the
// QueryType is a capability supported by the given job type.
func (r *RedditArguments) ValidateForJobType(jobType teetypes.JobType) error {
	if err := r.Validate(); err != nil {
		return err
	}

	// Validate QueryType against job-specific capabilities
	return jobType.ValidateCapability(teetypes.Capability(r.QueryType))
}

// GetCapability returns the QueryType as a typed Capability.
// Note: it does not normalize case; call Validate first if the value
// may not have been produced by UnmarshalJSON.
func (r *RedditArguments) GetCapability() teetypes.Capability {
	return teetypes.Capability(r.QueryType)
}
178 changes: 178 additions & 0 deletions args/reddit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package args_test

import (
"encoding/json"
"time"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

"github.com/masa-finance/tee-types/args"
"github.com/masa-finance/tee-types/types"
)

// Specs for args.RedditArguments covering JSON unmarshalling (default
// values and overrides) and Validate (per-query-type rules, URL/method
// checks, and the "after" timestamp).
var _ = Describe("RedditArguments", func() {
	Describe("Unmarshalling", func() {
		It("should set default values", func() {
			// Only the required fields are supplied; everything else
			// should come from the Apify defaults.
			redditArgs := &args.RedditArguments{}
			jsonData := `{"type": "searchposts", "queries": ["test"]}`
			err := json.Unmarshal([]byte(jsonData), redditArgs)
			Expect(err).ToNot(HaveOccurred())
			Expect(redditArgs.MaxItems).To(Equal(uint(10)))
			Expect(redditArgs.MaxPosts).To(Equal(uint(10)))
			Expect(redditArgs.MaxComments).To(Equal(uint(10)))
			Expect(redditArgs.MaxCommunities).To(Equal(uint(2)))
			Expect(redditArgs.MaxUsers).To(Equal(uint(2)))
			Expect(redditArgs.Sort).To(Equal(types.RedditSortNew))
			// MaxResults falls back to MaxItems when not supplied.
			Expect(redditArgs.MaxResults).To(Equal(redditArgs.MaxItems))
		})

		It("should override default values", func() {
			redditArgs := &args.RedditArguments{}
			jsonData := `{"type": "searchposts", "queries": ["test"], "max_items": 20, "sort": "top"}`
			err := json.Unmarshal([]byte(jsonData), redditArgs)
			Expect(err).ToNot(HaveOccurred())
			Expect(redditArgs.MaxItems).To(Equal(uint(20)))
			Expect(redditArgs.Sort).To(Equal(types.RedditSortTop))
			// MaxResults tracks the overridden MaxItems.
			Expect(redditArgs.MaxResults).To(Equal(uint(20)))
		})
	})

	Describe("Validation", func() {
		It("should succeed with valid arguments", func() {
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditSearchPosts,
				Queries:   []string{"test"},
				Sort:      types.RedditSortNew,
			}
			err := redditArgs.Validate()
			Expect(err).ToNot(HaveOccurred())
		})

		It("should succeed with valid scrapeurls arguments", func() {
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditScrapeUrls,
				URLs: []types.RedditStartURL{
					{URL: "https://www.reddit.com/r/golang/", Method: "GET"},
				},
				Sort: types.RedditSortNew,
			}
			err := redditArgs.Validate()
			Expect(err).ToNot(HaveOccurred())
		})

		It("should fail with an invalid type", func() {
			redditArgs := &args.RedditArguments{
				QueryType: "invalidtype",
				Queries:   []string{"test"},
				Sort:      types.RedditSortNew,
			}
			err := redditArgs.Validate()
			Expect(err).To(MatchError(args.ErrRedditInvalidType))
		})

		It("should fail with an invalid sort", func() {
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditSearchPosts,
				Queries:   []string{"test"},
				Sort:      "invalidsort",
			}
			err := redditArgs.Validate()
			Expect(err).To(MatchError(args.ErrRedditInvalidSort))
		})

		It("should fail if the after time is in the future", func() {
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditSearchPosts,
				Queries:   []string{"test"},
				Sort:      types.RedditSortNew,
				After:     time.Now().Add(24 * time.Hour),
			}
			err := redditArgs.Validate()
			Expect(err).To(MatchError(args.ErrRedditTimeInTheFuture))
		})

		It("should fail if queries are not provided for searchposts", func() {
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditSearchPosts,
				Sort:      types.RedditSortNew,
			}
			err := redditArgs.Validate()
			Expect(err).To(MatchError(args.ErrRedditNoQueries))
		})

		It("should fail if urls are not provided for scrapeurls", func() {
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditScrapeUrls,
				Sort:      types.RedditSortNew,
			}
			err := redditArgs.Validate()
			Expect(err).To(MatchError(args.ErrRedditNoUrls))
		})

		It("should fail if queries are provided for scrapeurls", func() {
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditScrapeUrls,
				Queries:   []string{"test"},
				URLs: []types.RedditStartURL{
					{URL: "https://www.reddit.com/r/golang/", Method: "GET"},
				},
				Sort: types.RedditSortNew,
			}
			err := redditArgs.Validate()
			Expect(err).To(MatchError(args.ErrRedditQueriesNotAllowed))
		})

		It("should fail if urls are provided for searchposts", func() {
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditSearchPosts,
				Queries:   []string{"test"},
				URLs: []types.RedditStartURL{
					{URL: "https://www.reddit.com/r/golang/", Method: "GET"},
				},
				Sort: types.RedditSortNew,
			}
			err := redditArgs.Validate()
			Expect(err).To(MatchError(args.ErrRedditUrlsNotAllowed))
		})

		It("should fail with an invalid URL", func() {
			// The embedded space makes url.Parse reject this string.
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditScrapeUrls,
				URLs: []types.RedditStartURL{
					{URL: "ht tp://invalid-url.com", Method: "GET"},
				},
				Sort: types.RedditSortNew,
			}
			err := redditArgs.Validate()
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("is not a valid URL"))
		})

		It("should fail with an invalid domain", func() {
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditScrapeUrls,
				URLs: []types.RedditStartURL{
					{URL: "https://www.google.com", Method: "GET"},
				},
				Sort: types.RedditSortNew,
			}
			err := redditArgs.Validate()
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("invalid Reddit URL"))
		})

		It("should fail with an invalid HTTP method", func() {
			redditArgs := &args.RedditArguments{
				QueryType: types.RedditScrapeUrls,
				URLs: []types.RedditStartURL{
					{URL: "https://www.reddit.com/r/golang/", Method: "INVALID"},
				},
				Sort: types.RedditSortNew,
			}
			err := redditArgs.Validate()
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("is not a valid HTTP method"))
		})
	})
})
1 change: 1 addition & 0 deletions args/tiktok.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type TikTokTranscriptionArguments struct {

// UnmarshalJSON implements custom JSON unmarshaling with validation
func (t *TikTokTranscriptionArguments) UnmarshalJSON(data []byte) error {
// Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...)
type Alias TikTokTranscriptionArguments
aux := &struct {
*Alias
Expand Down
1 change: 1 addition & 0 deletions args/twitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type TwitterSearchArguments struct {

// UnmarshalJSON implements custom JSON unmarshaling with validation
func (t *TwitterSearchArguments) UnmarshalJSON(data []byte) error {
// Prevent infinite recursion (you call json.Unmarshal which then calls `UnmarshalJSON`, which then calls `json.Unmarshal`...)
type Alias TwitterSearchArguments
aux := &struct {
*Alias
Expand Down
Loading