diff --git a/main.go b/main.go index a367d14..4343cc3 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "path/filepath" + "strings" "sync" "time" @@ -133,6 +134,28 @@ func (d *Downloader) downloadFileWithTimeout(userCtx context.Context, u string) } } +func (d *Downloader) getDownloadSize(ctx context.Context, u string) (uint64, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, u, nil) + if err != nil { + return 0, fmt.Errorf("creating the request for %s: %w", u, err) + } + resp, err := d.client.Do(req) + if err != nil { + return 0, fmt.Errorf("sending get http request to %s: %w", u, err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return 0, fmt.Errorf("got unexpected http response status for %s: %s", u, resp.Status) + } + if resp.ContentLength <= 0 && resp.Header.Get("Content-Range") != "" { + var s uint64 + p := strings.Split(resp.Header.Get("Content-Range"), "/") + fmt.Sscan(p[len(p)-1], &s) + return s, nil + } + return uint64(resp.ContentLength), nil +} + func (d *Downloader) downloadFile(ctx context.Context, u string) ([]byte, error) { ch := make(chan []byte, 1) defer close(ch) @@ -193,9 +216,16 @@ func (d *Downloader) DownloadWithContext(ctx context.Context, urls ...string) <- wg.Add(1) go func(u string) { defer wg.Done() - s := DownloadStatus{URL: u} + path := filepath.Join(os.TempDir(), filepath.Base(u)) + s := DownloadStatus{URL: u, DownloadedFilePath: path} defer func() { ch <- s }() - s.DownloadedFilePath = filepath.Join(os.TempDir(), filepath.Base(u)) + t, err := d.getDownloadSize(ctx, u) // TODO: retry + if err != nil { + s.Error = fmt.Errorf("error getting file size: %w", err) + return + } + s.FileSizeBytes = t + ch <- s // send total file size to the user b, err := d.downloadFile(ctx, u) if err != nil { s.Error = err @@ -206,7 +236,6 @@ func (d *Downloader) DownloadWithContext(ctx context.Context, urls ...string) <- return } s.DownloadedFileBytes = uint64(len(b)) - s.FileSizeBytes = uint64(len(b)) }(u) } go func() { diff --git a/main_test.go b/main_test.go index 1a1f473..96f9cec 100644 --- a/main_test.go +++ b/main_test.go @@ -24,6 +24,9 @@ func TestDownload_Error(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { s := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + return + } tc.proc(w) }, )) @@ -36,11 +39,12 @@ func TestDownload_Error(t *testing.T) { WaitBetweenRetries: 0 * time.Second, } ch := d.Download(s.URL) - status := <-ch - if status.Error == nil { + <-ch // discard the first got (just the file size) + got := <-ch + if got.Error == nil { t.Error("expected an error, but got nil") } - if !strings.Contains(status.Error.Error(), "#4") { + if !strings.Contains(got.Error.Error(), "#4") { t.Error("expected #4 (configured number of retries), but did not get it") } if _, ok := <-ch; ok { @@ -59,6 +63,7 @@ func TestDownload_OkWithDefaultDownloader(t *testing.T) { defer s.Close() ch := DefaultDownloader().Download(s.URL) + <-ch // discard the first status (just the file size) got := <-ch defer os.Remove(got.DownloadedFilePath) @@ -99,6 +104,10 @@ func TestDownload_Retry(t *testing.T) { attempts := int32(0) s := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.Header().Add("Content-Length", "2") + return + } if atomic.CompareAndSwapInt32(&attempts, 0, 1) { tc.proc(w) } @@ -115,6 +124,7 @@ func TestDownload_Retry(t *testing.T) { WaitBetweenRetries: 0 * time.Second, } ch := d.Download(s.URL) + <-ch // discard the first status (just the file size) got := <-ch if got.Error != nil { t.Errorf("invalid error. want:nil got:%q", got.Error) @@ -150,6 +160,9 @@ func TestDownloadWithContext_ErrorUserTimeout(t *testing.T) { timeout := 10 * userTimeout s := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + return + } time.Sleep(2 * userTimeout) // this time is greater than the user timeout, but shorter than the timeout per chunk. }, )) @@ -165,11 +178,12 @@ func TestDownloadWithContext_ErrorUserTimeout(t *testing.T) { defer cancFunc() ch := d.DownloadWithContext(userCtx, s.URL) - status := <-ch - if status.Error == nil { + <-ch // discard the first got (just the file size) + got := <-ch + if got.Error == nil { t.Error("expected an error, but got nil") } - if !strings.Contains(status.Error.Error(), "#4") { + if !strings.Contains(got.Error.Error(), "#4") { t.Error("expected #4 (configured number of retries), but did not get it") } if _, ok := <-ch; ok { @@ -202,3 +216,5 @@ func TestDownload_Chunks(t *testing.T) { } } } + +// TODO: add tests for getDownloadSize (success with Content-Length, success with Content-Range, failure)