Skip to content

Commit

Permalink
feat: Add max concurrency option
Browse files Browse the repository at this point in the history
  • Loading branch information
caffeine-addictt committed Apr 8, 2024
1 parent 5a0c56a commit e5a85f3
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions src/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ func (e *strategyEnum) Type() string {

// Command stuff
var getFlags struct {
inputFile string
strategy strategyEnum
inputFile string
strategy strategyEnum
maxConcurrency int64
}

var getCommand = &cobra.Command{
Expand Down Expand Up @@ -172,16 +173,42 @@ var getCommand = &cobra.Command{
// Handle downloading
switch getFlags.strategy {
case strategyConcurrent:
fmt.Println("Downloading concurrently...")

var waitGroup sync.WaitGroup
waitGroup.Add(len(argSet))

for url := range argSet {
go func(url string) {
defer waitGroup.Done()
downloadFile(url)
}(url)
// Concurrency with no limit
if getFlags.maxConcurrency == 0 {
fmt.Println("Downloading concurrently... [No limit]")
waitGroup.Add(len(argSet))

for url := range argSet {
go func(url string) {
defer waitGroup.Done()
downloadFile(url)
}(url)
}

// Concurrency with limit
} else {
fmt.Printf("Downloading concurrently... [Max: %d]\n", getFlags.maxConcurrency)
waitGroup.Add(int(getFlags.maxConcurrency))

// Establish channel and workers
ch := make(chan string)
for t := 0; t < int(getFlags.maxConcurrency); t++ {
go func() {
for url := range ch {
downloadFile(url)
}

waitGroup.Done()
}()
}

for url := range argSet {
ch <- url
}

close(ch)
}

waitGroup.Wait()
Expand All @@ -200,6 +227,7 @@ func init() {

rootCommand.AddCommand(getCommand)
getCommand.Flags().StringVarP(&getFlags.inputFile, "file", "f", "", "Path to the input file containing the url(s)")
getCommand.Flags().Int64VarP(&getFlags.maxConcurrency, "max-concurrency", "m", 10, "Maximum number of concurrent downloads [0 = unlimited] (default is 10)")
getCommand.Flags().VarP(&getFlags.strategy, "strategy", "s", "Strategy to use when downloading (default is concurrent)")
if err := getCommand.RegisterFlagCompletionFunc("strategy", strategyCompletion); err != nil {
fmt.Println("Failed to register completion for flag -s in get command")
Expand Down

0 comments on commit e5a85f3

Please sign in to comment.