diff --git a/registry/download.go b/registry/download.go index 81ad923082..8767dfb92f 100644 --- a/registry/download.go +++ b/registry/download.go @@ -3,6 +3,7 @@ package registry import ( "archive/zip" "context" + "errors" "fmt" "io" "net/http" @@ -12,7 +13,6 @@ import ( "time" "github.com/avast/retry-go/v4" - "github.com/schollz/progressbar/v3" ) @@ -26,14 +26,26 @@ const ( RetryWaitTime = 1 * time.Second ) +type pluginURL struct { + url string + monorepo bool +} + func DownloadPluginFromGithub(ctx context.Context, localPath string, org string, name string, version string, typ PluginType) error { downloadDir := filepath.Dir(localPath) pluginZipPath := localPath + ".zip" // https://github.com/cloudquery/cloudquery/releases/download/plugins-source-test-v1.1.5/test_darwin_amd64.zip - downloadURL := fmt.Sprintf("https://github.com/cloudquery/cloudquery/releases/download/plugins-%s-%s-%s/%s_%s_%s.zip", typ, name, version, name, runtime.GOOS, runtime.GOARCH) - if org != "cloudquery" { - // https://github.com/yevgenypats/cq-source-test/releases/download/v1.0.1/cq-source-test_darwin_amd64.zip - downloadURL = fmt.Sprintf("https://github.com/%s/cq-%s-%s/releases/download/%s/cq-%s-%s_%s_%s.zip", org, typ, name, version, typ, name, runtime.GOOS, runtime.GOARCH) + urls := []pluginURL{ + // community plugin format + {url: fmt.Sprintf("https://github.com/%s/cq-%s-%s/releases/download/%s/cq-%s-%s_%s_%s.zip", org, typ, name, version, typ, name, runtime.GOOS, runtime.GOARCH)}, + } + if org == "cloudquery" { + urls = append( + // CloudQuery monorepo plugin + []pluginURL{{url: fmt.Sprintf("https://github.com/cloudquery/cloudquery/releases/download/plugins-%s-%s-%s/%s_%s_%s.zip", typ, name, version, name, runtime.GOOS, runtime.GOARCH), monorepo: true}}, + // fall back to community plugin format if the plugin is not found in the monorepo + urls..., + ) } if _, err := os.Stat(localPath); err == nil { @@ -44,7 +56,7 @@ func DownloadPluginFromGithub(ctx context.Context, localPath string, org string, return fmt.Errorf("failed to create plugin directory %s: %w", downloadDir, err) } - err := downloadFile(ctx, pluginZipPath, downloadURL) + used, err := downloadFile(ctx, pluginZipPath, urls...) if err != nil { return fmt.Errorf("failed to download plugin: %w", err) } @@ -55,8 +67,10 @@ func DownloadPluginFromGithub(ctx context.Context, localPath string, org string, } defer archive.Close() - pathInArchive := fmt.Sprintf("plugins/%s/%s", typ, name) - if org != "cloudquery" { + var pathInArchive string + if used.monorepo { + pathInArchive = fmt.Sprintf("plugins/%s/%s", typ, name) + } else { pathInArchive = fmt.Sprintf("cq-%s-%s", typ, name) } pathInArchive = WithBinarySuffix(pathInArchive) @@ -80,57 +94,69 @@ func DownloadPluginFromGithub(ctx context.Context, localPath string, org string, return nil } -func downloadFile(ctx context.Context, localPath string, url string) (err error) { +func downloadFile(ctx context.Context, localPath string, urls ...pluginURL) (used pluginURL, err error) { // Create the file out, err := os.Create(localPath) if err != nil { - return fmt.Errorf("failed to create file %s: %w", localPath, err) + return pluginURL{}, fmt.Errorf("failed to create file %s: %w", localPath, err) } defer out.Close() - // Get the data - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return fmt.Errorf("failed create request %s: %w", url, err) + for _, url := range urls { + err = downloadFileFromURL(ctx, out, url.url) + if err != nil && err.Error() == "not found" { + continue + } + return url, err } + return pluginURL{}, fmt.Errorf("failed downloading from URL %v. Error %w", urls, err) +} - err = retry.Do( - func() error { - // Do http request - resp, err := http.DefaultClient.Do(req) - if err != nil { - return fmt.Errorf("failed to get url %s: %w", url, err) - } - - // Check server response - if resp.StatusCode != http.StatusOK { - fmt.Printf("Failed downloading %s with status code %d. Retrying\n", url, resp.StatusCode) - return fmt.Errorf("statusCode != 200") - } - defer resp.Body.Close() - - fmt.Printf("Downloading %s\n", url) - bar := downloadProgressBar(resp.ContentLength, "Downloading") - - // Writer the body to file - _, err = io.Copy(io.MultiWriter(out, bar), resp.Body) - if err != nil { - return fmt.Errorf("failed to copy body to file %s: %w", localPath, err) - } - - return nil - }, - retry.RetryIf(func(err error) bool { - return err.Error() == "statusCode != 200" - }), +func downloadFileFromURL(ctx context.Context, out *os.File, url string) error { + err := retry.Do(func() error { + // Get the data + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed create request %s: %w", url, err) + } + + // Do http request + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("failed to get url %s: %w", url, err) + } + defer resp.Body.Close() + // Check server response + if resp.StatusCode == http.StatusNotFound { + return errors.New("not found") + } else if resp.StatusCode != http.StatusOK { + fmt.Printf("Failed downloading %s with status code %d. Retrying\n", url, resp.StatusCode) + return errors.New("statusCode != 200") + } + + fmt.Printf("Downloading %s\n", url) + bar := downloadProgressBar(resp.ContentLength, "Downloading") + + // Writer the body to file + _, err = io.Copy(io.MultiWriter(out, bar), resp.Body) + if err != nil { + return fmt.Errorf("failed to copy body to file %s: %w", out.Name(), err) + } + return nil + }, retry.RetryIf(func(err error) bool { + return err.Error() == "statusCode != 200" + }), retry.Attempts(RetryAttempts), retry.Delay(RetryWaitTime), ) - if err != nil { + for _, e := range err.(retry.Error) { + if e.Error() == "not found" { + return e + } + } return fmt.Errorf("failed downloading URL %q. Error %w", url, err) } - return nil } diff --git a/registry/download_test.go b/registry/download_test.go new file mode 100644 index 0000000000..9c359c59a3 --- /dev/null +++ b/registry/download_test.go @@ -0,0 +1,35 @@ +package registry + +import ( + "context" + "path" + "testing" +) + +func TestDownloadPluginFromGithubIntegration(t *testing.T) { + tmp := t.TempDir() + cases := []struct { + name string + org string + plugin string + version string + pluginType PluginType + wantErr bool + }{ + {name: "monorepo source", org: "cloudquery", plugin: "hackernews", version: "v1.1.4", pluginType: PluginTypeSource}, + {name: "many repo source", org: "cloudquery", plugin: "simple-analytics", version: "v1.0.0", pluginType: PluginTypeSource}, + {name: "monorepo destination", org: "cloudquery", plugin: "postgresql", version: "v2.0.7", pluginType: PluginTypeDestination}, + {name: "community source", org: "hermanschaaf", plugin: "simple-analytics", version: "v1.0.0", pluginType: PluginTypeSource}, + {name: "invalid community source", org: "cloudquery", plugin: "invalid-plugin", version: "v0.0.x", pluginType: PluginTypeSource, wantErr: true}, + {name: "invalid monorepo source", org: "not-cloudquery", plugin: "invalid-plugin", version: "v0.0.x", pluginType: PluginTypeSource, wantErr: true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := DownloadPluginFromGithub(context.Background(), path.Join(tmp, tc.name), tc.org, tc.plugin, tc.version, tc.pluginType) + if (err != nil) != tc.wantErr { + t.Errorf("DownloadPluginFromGithub() error = %v, wantErr %v", err, tc.wantErr) + return + } + }) + } +}