Skip to content

Commit

Permalink
feat: implement unthrottled concurrency using task queue
Browse files Browse the repository at this point in the history
- create worker goroutines specified through cli
- goroutines steal incoming tasks from a channel and execute them
- workers consume tasks instead of plain domain name strings
- a task consists of a domain and a provider

This approach spawns n green threads instead of n * len(providers).
Should prevent resource usage from blowing up and help scaling.
  • Loading branch information
lavafroth committed Sep 24, 2023
1 parent 85deaa7 commit 743bfea
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 109 deletions.
83 changes: 38 additions & 45 deletions cmd/gau/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,90 +2,83 @@ package main

import (
"bufio"
"io"
"os"
"sync"

"github.com/lc/gau/v2/pkg/output"
"github.com/lc/gau/v2/runner"
"github.com/lc/gau/v2/runner/flags"
log "github.com/sirupsen/logrus"
"io"
"os"
"sync"
)

func main() {
flag := flags.New()
cfg, err := flag.ReadInConfig()
cfg, err := flags.New().ReadInConfig()
if err != nil {
if cfg.Verbose {
log.Warnf("error reading config: %v", err)
}
}

pMap := make(runner.ProvidersMap)
for _, provider := range cfg.Providers {
pMap[provider] = cfg.Filters
log.Warnf("error reading config: %v", err)
}

config, err := cfg.ProviderConfig()
if err != nil {
log.Fatal(err)
}

gau := &runner.Runner{}
gau := new(runner.Runner)

if err = gau.Init(config, pMap); err != nil {
if err = gau.Init(config, cfg.Providers, cfg.Filters); err != nil {
log.Warn(err)
}

results := make(chan string)

var out io.Writer
// Handle results in background
if config.Output == "" {
out = os.Stdout
} else {
ofp, err := os.OpenFile(config.Output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if config.Output != "" {
out, err := os.OpenFile(config.Output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatalf("Could not open output file: %v\n", err)
}
defer ofp.Close()
out = ofp
defer out.Close()
} else {
out = os.Stdout
}

writeWg := &sync.WaitGroup{}
writeWg := new(sync.WaitGroup)
writeWg.Add(1)
if config.JSON {
go func() {
defer writeWg.Done()
go func(JSON bool) {
defer writeWg.Done()
if JSON {
output.WriteURLsJSON(out, results, config.Blacklist, config.RemoveParameters)
}()
} else {
go func() {
defer writeWg.Done()
if err = output.WriteURLs(out, results, config.Blacklist, config.RemoveParameters); err != nil {
log.Fatalf("error writing results: %v\n", err)
}
}()
}
} else if err = output.WriteURLs(out, results, config.Blacklist, config.RemoveParameters); err != nil {
log.Fatalf("error writing results: %v\n", err)
}
}(config.JSON)

domains := make(chan string)
gau.Start(domains, results)
workChan := make(chan runner.Work)
gau.Start(workChan, results)

if len(flags.Args()) > 0 {
for _, domain := range flags.Args() {
domains <- domain
domains := flags.Args()
if len(domains) > 0 {
for _, provider := range gau.Providers {
for _, domain := range domains {
workChan <- runner.NewWork(domain, provider)
}
}
} else {
sc := bufio.NewScanner(os.Stdin)
for sc.Scan() {
domains <- sc.Text()
}
for _, provider := range gau.Providers {
for sc.Scan() {
workChan <- runner.NewWork(sc.Text(), provider)

if err := sc.Err(); err != nil {
log.Fatal(err)
if err := sc.Err(); err != nil {
log.Fatal(err)
}
}
}

}

close(domains)
close(workChan)

// wait for providers to fetch URLS
gau.Wait()
Expand Down
3 changes: 1 addition & 2 deletions pkg/providers/commoncrawl/commoncrawl.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ func (c *Client) Fetch(ctx context.Context, domain string, results chan string)
return nil
}

paginate:
for page := uint(0); page < p.Pages; page++ {
select {
case <-ctx.Done():
break paginate
return nil
default:
logrus.WithFields(logrus.Fields{"provider": Name, "page": page}).Infof("fetching %s", domain)
apiURL := c.formatURL(domain, page)
Expand Down
6 changes: 2 additions & 4 deletions pkg/providers/otx/otx.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ func (c *Client) Name() string {
}

func (c *Client) Fetch(ctx context.Context, domain string, results chan string) error {
paginate:
for page := uint(1); ; page++ {
select {
case <-ctx.Done():
break paginate
return nil
default:
logrus.WithFields(logrus.Fields{"provider": Name, "page": page - 1}).Infof("fetching %s", domain)
apiURL := c.formatURL(domain, page)
Expand All @@ -68,11 +67,10 @@ paginate:
}

if !result.HasNext {
break paginate
return nil
}
}
}
return nil
}

func (c *Client) formatURL(domain string, page uint) string {
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/valyala/fasthttp"
)

const Version = `2.1.2`
const Version = `2.2.0`

// Provider is a generic interface for all archive fetchers
type Provider interface {
Expand Down
17 changes: 6 additions & 11 deletions pkg/providers/urlscan/urlscan.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ const (
Name = "urlscan"
)

var _ providers.Provider = (*Client)(nil)

type Client struct {
config *providers.Config
}
Expand All @@ -41,11 +39,10 @@ func (c *Client) Fetch(ctx context.Context, domain string, results chan string)
header.Value = c.config.URLScan.APIKey
}

paginate:
for page := uint(0); ; page++ {
select {
case <-ctx.Done():
break paginate
return nil
default:
logrus.WithFields(logrus.Fields{"provider": Name, "page": page}).Infof("fetching %s", domain)
apiURL := c.formatURL(domain, searchAfter)
Expand All @@ -62,7 +59,7 @@ paginate:
// rate limited
if result.Status == 429 {
logrus.WithField("provider", "urlscan").Warnf("urlscan responded with 429, probably being rate limited")
break paginate
return nil
}

total := len(result.Results)
Expand All @@ -73,20 +70,18 @@ paginate:

if i == total-1 {
sortParam := parseSort(res.Sort)
if sortParam != "" {
searchAfter = sortParam
} else {
break paginate
if sortParam == "" {
return nil
}
searchAfter = sortParam
}
}

if !result.HasMore {
break paginate
return nil
}
}
}
return nil
}

func (c *Client) formatURL(domain string, after string) string {
Expand Down
9 changes: 2 additions & 7 deletions runner/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,8 @@ func (o *Options) getFlagValues(c *Config) {
c.RemoveParameters = fp
}

if json {
c.JSON = true
}

if verbose {
c.Verbose = verbose
}
c.JSON = json
c.Verbose = verbose

// get filter flags
mc := o.viper.GetStringSlice("mc")
Expand Down
72 changes: 33 additions & 39 deletions runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,81 +13,75 @@ import (
)

type Runner struct {
providers []providers.Provider
wg sync.WaitGroup
sync.WaitGroup

config *providers.Config
Providers []providers.Provider
threads uint
ctx context.Context
cancelFunc context.CancelFunc
}

type ProvidersMap map[string]providers.Filters

// Init initializes the runner
func (r *Runner) Init(c *providers.Config, providerMap ProvidersMap) error {
r.config = c
func (r *Runner) Init(c *providers.Config, providers []string, filters providers.Filters) error {
r.threads = c.Threads
r.ctx, r.cancelFunc = context.WithCancel(context.Background())

for name, filters := range providerMap {
for _, name := range providers {
switch name {
case "urlscan":
r.providers = append(r.providers, urlscan.New(c))
r.Providers = append(r.Providers, urlscan.New(c))
case "otx":
o := otx.New(c)
r.providers = append(r.providers, o)
r.Providers = append(r.Providers, otx.New(c))
case "wayback":
r.providers = append(r.providers, wayback.New(c, filters))
r.Providers = append(r.Providers, wayback.New(c, filters))
case "commoncrawl":
cc, err := commoncrawl.New(c, filters)
if err != nil {
return fmt.Errorf("error instantiating commoncrawl: %v\n", err)
}
r.providers = append(r.providers, cc)
r.Providers = append(r.Providers, cc)
}
}

return nil
}

// Starts starts the worker
func (r *Runner) Start(domains chan string, results chan string) {
for i := uint(0); i < r.config.Threads; i++ {
r.wg.Add(1)
func (r *Runner) Start(workChan chan Work, results chan string) {
for i := uint(0); i < r.threads; i++ {
r.Add(1)
go func() {
defer r.wg.Done()
r.worker(r.ctx, domains, results)
defer r.Done()
r.worker(r.ctx, workChan, results)
}()
}
}

// Wait waits for the providers to finish fetching
func (r *Runner) Wait() {
r.wg.Wait()
type Work struct {
domain string
provider providers.Provider
}

func NewWork(domain string, provider providers.Provider) Work {
return Work{domain, provider}
}

func (w *Work) Do(ctx context.Context, results chan string) error {
return w.provider.Fetch(ctx, w.domain, results)
}

// worker checks to see if the context is finished and executes the fetching process for each provider
func (r *Runner) worker(ctx context.Context, domains chan string, results chan string) {
work:
func (r *Runner) worker(ctx context.Context, workChan chan Work, results chan string) {
for {
select {
case <-ctx.Done():
break work
case domain, ok := <-domains:
if ok {
var wg sync.WaitGroup
for _, p := range r.providers {
wg.Add(1)
go func(p providers.Provider) {
defer wg.Done()
if err := p.Fetch(ctx, domain, results); err != nil {
logrus.WithField("provider", p.Name()).Warnf("%s - %v", domain, err)
}
}(p)
}
wg.Wait()
}
return
case work, ok := <-workChan:
if !ok {
break work
return
}
if err := work.Do(ctx, results); err != nil {
logrus.WithField("provider", work.provider.Name()).Warnf("%s - %v", work.domain, err)
}
}
}
Expand Down

0 comments on commit 743bfea

Please sign in to comment.