Skip to content

Commit

Permalink
feat: Add strategey flag to support concurrency or sync
Browse files Browse the repository at this point in the history
  • Loading branch information
caffeine-addictt committed Apr 7, 2024
1 parent 1d90c73 commit 86d0b76
Showing 1 changed file with 113 additions and 55 deletions.
168 changes: 113 additions & 55 deletions src/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package src

import (
"bufio"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -15,8 +16,36 @@ import (
"github.com/spf13/cobra"
)

// Strategy
type strategyEnum string

const (
strategySynchronous strategyEnum = "synchronous"
strategyConcurrent strategyEnum = "concurrent"
)

func (e *strategyEnum) String() string {
return string(*e)
}

func (e *strategyEnum) Set(value string) error {
switch value {
case "concurrent", "synchronous":
*e = strategyEnum(value)
return nil
default:
return errors.New("must be one of 'synchronous' or 'concurrent'")
}
}

func (e *strategyEnum) Type() string {
return "strategy"
}

// Command stuff
var getFlags struct {
inputFile string
strategy strategyEnum
}

var getCommand = &cobra.Command{
Expand Down Expand Up @@ -88,71 +117,100 @@ var getCommand = &cobra.Command{
os.Exit(1)
}

// Handle downloading
var waitGroup sync.WaitGroup
waitGroup.Add(len(argSet))

for url := range argSet {
go func(url string) {
defer waitGroup.Done()

split := strings.Split(url, "/")
downloadLocation := filepath.Clean(filepath.Join(dirPath, split[len(split)-1]))

// Ensure file already does not exist
Info("Checking if " + downloadLocation + " already exists")
if _, err := os.Stat(downloadLocation); err == nil {
fmt.Printf("File already exists for %s\n", downloadLocation)
Debug("File: " + downloadLocation + " already exists for " + url)
return
}
downloadFile := func(url string) {
split := strings.Split(url, "/")
downloadLocation := filepath.Clean(filepath.Join(dirPath, split[len(split)-1]))

// Get File
fmt.Printf("Downloading %s to %s", url, downloadLocation)
Info("Getting url: " + url)
// Ensure file already does not exist
Info("Checking if " + downloadLocation + " already exists")
if _, err := os.Stat(downloadLocation); err == nil {
fmt.Printf("File already exists for %s\n", downloadLocation)
Debug("File: " + downloadLocation + " already exists for " + url)
return
}

request, err := http.NewRequest(http.MethodGet, url, http.NoBody)
if err != nil {
fmt.Println("Failed to create request: " + url)
Debug(err.Error())
return
}
// Get File
fmt.Printf("Downloading %s to %s", url, downloadLocation)
Info("Getting url: " + url)

response, err := http.DefaultClient.Do(request)
if err != nil {
fmt.Println("Failed to get url: " + url)
Debug(err.Error())
return
}
defer response.Body.Close()

// Create file
Info("Creating file at: " + downloadLocation)
out, err := os.Create(downloadLocation)
if err != nil {
fmt.Println("Failed to create file at: " + downloadLocation)
Debug(err.Error())
return
}
defer out.Close()

// Write to file
Info("Writing " + url + " to " + downloadLocation)
if _, err := io.Copy(out, response.Body); err != nil {
fmt.Printf("Failed to write to file at: %s\n", downloadLocation)
Debug(err.Error())
return
}
request, err := http.NewRequest(http.MethodGet, url, http.NoBody)
if err != nil {
fmt.Println("Failed to create request: " + url)
Debug(err.Error())
return
}

fmt.Println("Downloaded " + url + " to " + downloadLocation)
}(url)
response, err := http.DefaultClient.Do(request)
if err != nil {
fmt.Println("Failed to get url: " + url)
Debug(err.Error())
return
}
defer response.Body.Close()

// Create file
Info("Creating file at: " + downloadLocation)
out, err := os.Create(downloadLocation)
if err != nil {
fmt.Println("Failed to create file at: " + downloadLocation)
Debug(err.Error())
return
}
defer out.Close()

// Write to file
Info("Writing " + url + " to " + downloadLocation)
if _, err := io.Copy(out, response.Body); err != nil {
fmt.Printf("Failed to write to file at: %s\n", downloadLocation)
Debug(err.Error())
return
}

fmt.Println("Downloaded " + url + " to " + downloadLocation)
}

waitGroup.Wait()
// Handle downloading
switch getFlags.strategy {
case strategyConcurrent:
fmt.Println("Downloading concurrently...")

var waitGroup sync.WaitGroup
waitGroup.Add(len(argSet))

for url := range argSet {
go func() {
defer waitGroup.Done()
downloadFile(url)
}()
}

waitGroup.Wait()
case strategySynchronous:
fmt.Println("Downloading synchronously...")

for url := range argSet {
downloadFile(url)
}
}
},
}

func init() {
getFlags.strategy = strategyConcurrent

rootCommand.AddCommand(getCommand)
getCommand.Flags().StringVarP(&getFlags.inputFile, "file", "f", "", "Path to the input file containing the url(s)")
getCommand.Flags().VarP(&getFlags.strategy, "strategy", "s", "Strategy to use when downloading (default is concurrent)")
if err := getCommand.RegisterFlagCompletionFunc("strategy", strategyCompletion); err != nil {
fmt.Println("Failed to register completion for flag -s in get command")
Debug(err.Error())
os.Exit(1)
}
}

func strategyCompletion(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
return []string{
"synchronous\tDownload videos sequentially",
"concurrent\tDownload videos concurrently DEFAULT",
}, cobra.ShellCompDirectiveNoFileComp
}

0 comments on commit 86d0b76

Please sign in to comment.