From 693f20100f748d954f6ca682ba157985faafb1f2 Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Wed, 11 May 2022 22:01:22 -0400 Subject: [PATCH 01/13] Fix command injection in go-getter when passing params to hg clone The fix for this is to add -- to the arguments of each hg command, before any user-input. This indicates the end of optional arguments, only positional arguments are allowed. Test Results Before Change ``` ~> go test ./... -run=TestHg -v === RUN TestHgGetter_impl --- PASS: TestHgGetter_impl (0.00s) === RUN TestHgGetter --- PASS: TestHgGetter (0.60s) === RUN TestHgGetter_branch --- PASS: TestHgGetter_branch (0.96s) === RUN TestHgGetter_GetFile --- PASS: TestHgGetter_GetFile (0.61s) === RUN TestHgGetter_HgArgumentsNotAllowed === RUN TestHgGetter_HgArgumentsNotAllowed/arguments_allowed_in_destination get_hg_test.go:144: Expected no err, got: error running /usr/local/bin/hg: === RUN TestHgGetter_HgArgumentsNotAllowed/arguments_passed_into_rev_parameter get_hg_test.go:163: Expected no err, got: /usr/local/bin/hg exited with 1: === RUN TestHgGetter_HgArgumentsNotAllowed/arguments_passed_in_the_repository_URL get_hg_test.go:182: Expected no err, got: /usr/local/bin/hg exited with 255: hg clone: option -U not recognized alias 'clone' resolves to unknown command 'false' --- FAIL: TestHgGetter_HgArgumentsNotAllowed (1.02s) --- FAIL: TestHgGetter_HgArgumentsNotAllowed/arguments_allowed_in_destination (0.15s) --- FAIL: TestHgGetter_HgArgumentsNotAllowed/arguments_passed_into_rev_parameter (0.56s) --- FAIL: TestHgGetter_HgArgumentsNotAllowed/arguments_passed_in_the_repository_URL (0.31s) FAIL ``` Test Results After Change ``` ~> go test ./... -run=TestHg -v === RUN TestHgGetter_impl --- PASS: TestHgGetter_impl (0.00s) === RUN TestHgGetter --- PASS: TestHgGetter (0.61s) === RUN TestHgGetter_branch --- PASS: TestHgGetter_branch (0.99s) === RUN TestHgGetter_GetFile --- PASS: TestHgGetter_GetFile (0.61s) === RUN TestHgGetter_HgArgumentsNotAllowed === RUN TestHgGetter_HgArgumentsNotAllowed/arguments_allowed_in_destination === RUN TestHgGetter_HgArgumentsNotAllowed/arguments_passed_into_rev_parameter === RUN TestHgGetter_HgArgumentsNotAllowed/arguments_passed_in_the_repository_URL --- PASS: TestHgGetter_HgArgumentsNotAllowed (1.37s) --- PASS: TestHgGetter_HgArgumentsNotAllowed/arguments_allowed_in_destination (0.62s) --- PASS: TestHgGetter_HgArgumentsNotAllowed/arguments_passed_into_rev_parameter (0.61s) --- PASS: TestHgGetter_HgArgumentsNotAllowed/arguments_passed_in_the_repository_URL (0.15s) PASS ``` --- get_hg.go | 4 +-- get_hg_test.go | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 2 deletions(-) diff --git a/get_hg.go b/get_hg.go index 8cc4fe4c..ebad2133 100644 --- a/get_hg.go +++ b/get_hg.go @@ -103,7 +103,7 @@ func (g *HgGetter) GetFile(ctx context.Context, req *Request) error { } func (g *HgGetter) clone(dst string, u *url.URL) error { - cmd := exec.Command("hg", "clone", "-U", u.String(), dst) + cmd := exec.Command("hg", "clone", "-U", "--", u.String(), dst) return getRunCommand(cmd) } @@ -116,7 +116,7 @@ func (g *HgGetter) pull(dst string, u *url.URL) error { func (g *HgGetter) update(ctx context.Context, dst string, u *url.URL, rev string) error { args := []string{"update"} if rev != "" { - args = append(args, rev) + args = append(args, "--", rev) } cmd := exec.CommandContext(ctx, "hg", args...) diff --git a/get_hg_test.go b/get_hg_test.go index 41486bc3..74f0e930 100644 --- a/get_hg_test.go +++ b/get_hg_test.go @@ -2,9 +2,11 @@ package getter import ( "context" + "net/url" "os" "os/exec" "path/filepath" + "strings" "testing" testing_helper "github.com/hashicorp/go-getter/v2/helper/testing" @@ -118,3 +120,83 @@ func TestHgGetter_GetFile(t *testing.T) { } testing_helper.AssertContents(t, dst, "Hello\n") } +func TestHgGetter_HgArgumentsNotAllowed(t *testing.T) { + if !testHasHg { + t.Log("hg not found, skipping") + t.Skip() + } + ctx := context.Background() + + tc := []struct { + name string + req Request + errChk func(testing.TB, error) + }{ + { + // If arguments are allowed in the destination, this request to Get will fail + name: "arguments allowed in destination", + req: Request{ + Dst: "--config=alias.clone=!touch ./TEST", + u: testModuleURL("basic-hg"), + }, + errChk: func(t testing.TB, err error) { + if err != nil { + t.Errorf("Expected no err, got: %s", err) + } + }, + }, + { + // Test arguments passed into the `rev` parameter + // This clone call will fail regardless, but an exit code of 1 indicates + // that the `false` command executed + // We are expecting an hg parse error + name: "arguments passed into rev parameter", + req: Request{ + u: testModuleURL("basic-hg?rev=--config=alias.update=!false"), + }, + errChk: func(t testing.TB, err error) { + if err == nil { + return + } + + if !strings.Contains(err.Error(), "hg: parse error") { + t.Errorf("Expected no err, got: %s", err) + } + }, + }, + { + // Test arguments passed in the repository URL + // This Get call will fail regardless, but it should fail + // because the repository can't be found. + // Other failures indicate that hg interpreted the argument passed in the URL + name: "arguments passed in the repository URL", + req: Request{ + u: &url.URL{Path: "--config=alias.clone=false"}, + }, + errChk: func(t testing.TB, err error) { + if err == nil { + return + } + + if !strings.Contains(err.Error(), "repository --config=alias.clone=false not found") { + t.Errorf("Expected no err, got: %s", err) + } + }, + }, + } + for _, tt := range tc { + tt := tt + t.Run(tt.name, func(t *testing.T) { + g := new(HgGetter) + + if tt.req.Dst == "" { + dst := testing_helper.TempDir(t) + tt.req.Dst = dst + } + + defer os.RemoveAll(tt.req.Dst) + err := g.Get(ctx, &tt.req) + tt.errChk(t, err) + }) + } +} From 6b7623c8ee246c67e027d73639197812bdae29f5 Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Wed, 11 May 2022 23:09:58 -0400 Subject: [PATCH 02/13] Remove upwards path traversal in subdirectories, filenames * Prevent arbitrary file read, path traversal via subdirectory extraction Not opt-in or opt-out, just never allowed. Upwards path traversal is not a subdirectory. *Prevent arbitrary file write via `filename` Not opt-in or opt-out, just never allowed. Upwards path traversal is not a filename in a subdirectory. --- client.go | 16 ++++++++++++++++ get_test.go | 21 +++++++++++++++++++++ source.go | 4 +++- 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index fd6a0cba..932d351b 100644 --- a/client.go +++ b/client.go @@ -50,6 +50,18 @@ func (c *Client) Get(ctx context.Context, req *Request) (*GetResult, error) { // and then copy over the proper subdir. req.Src, req.subDir = SourceDirSubdir(req.Src) if req.subDir != "" { + // Check if the subdirectory is attempting to traverse upwards, outside of + // the cloned repository path. + req.subDir = filepath.Clean(req.subDir) + if containsDotDot(req.subDir) { + return nil, fmt.Errorf("subdirectory component contain path traversal out of the repository") + } + + // Prevent absolute paths, remove a leading path separator from the subdirectory + if req.subDir[0] == os.PathSeparator { + req.subDir = req.subDir[1:] + } + td, tdcloser, err := safetemp.Dir("", "getter") if err != nil { return nil, err @@ -199,6 +211,10 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, * filename = v } + if containsDotDot(filename) { + return nil, &getError{true, fmt.Errorf("filename query parameter contain path traversal")} + } + req.Dst = filepath.Join(req.Dst, filename) } } diff --git a/get_test.go b/get_test.go index 1a340c64..d3383443 100644 --- a/get_test.go +++ b/get_test.go @@ -348,6 +348,27 @@ func TestGetFile_archive(t *testing.T) { // Verify the main file exists testing_helper.AssertContents(t, dst, "Hello\n") } +func TestGetFile_filename_path_traversal(t *testing.T) { + dst := testing_helper.TempDir(t) + u := testModule("basic-file/foo.txt") + + u += "?filename=../../../../../../../../../../../../../tmp/bar.txt" + + ctx := context.Background() + op, err := GetAny(ctx, dst, u) + + if op != nil { + t.Fatalf("unexpected op: %v", op) + } + + if err == nil { + t.Fatalf("expected error") + } + + if !strings.Contains(err.Error(), "filename query parameter contain path traversal") { + t.Fatalf("unexpected err: %s", err) + } +} func TestGetFile_archiveChecksum(t *testing.T) { ctx := context.Background() diff --git a/source.go b/source.go index dab6d400..48ac9234 100644 --- a/source.go +++ b/source.go @@ -58,7 +58,9 @@ func SourceDirSubdir(src string) (string, string) { // // The returned path is the full absolute path. func SubdirGlob(dst, subDir string) (string, error) { - matches, err := filepath.Glob(filepath.Join(dst, subDir)) + pattern := filepath.Join(dst, subDir) + + matches, err := filepath.Glob(pattern) if err != nil { return "", err } From 32e23926a0c52ebe603e1d43a90e7a858f969c9c Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Thu, 12 May 2022 00:21:41 -0400 Subject: [PATCH 03/13] Add Timeout option to HgGetter and GitGetter enforced with os/exec.CommandContext --- detect_test.go | 11 ++++++----- get.go | 13 +++++++------ get_git.go | 35 +++++++++++++++++++++++------------ get_git_test.go | 38 +++++++++++++++++++++----------------- get_hg.go | 27 ++++++++++++++++++++------- get_hg_test.go | 26 ++++++++++++++++++++++++-- get_test.go | 2 +- 7 files changed, 102 insertions(+), 50 deletions(-) diff --git a/detect_test.go b/detect_test.go index bc7933a8..c0a555e7 100644 --- a/detect_test.go +++ b/detect_test.go @@ -6,11 +6,12 @@ import ( ) func TestDetect(t *testing.T) { - gitGetter := &GitGetter{[]Detector{ - new(GitDetector), - new(BitBucketDetector), - new(GitHubDetector), - }, + gitGetter := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(BitBucketDetector), + new(GitHubDetector), + }, } cases := []struct { Input string diff --git a/get.go b/get.go index f2641456..3c31932f 100644 --- a/get.go +++ b/get.go @@ -76,12 +76,13 @@ func init() { // The order of the Getters in the list may affect the result // depending if the Request.Src is detected as valid by multiple getters Getters = []Getter{ - &GitGetter{[]Detector{ - new(GitHubDetector), - new(GitDetector), - new(BitBucketDetector), - new(GitLabDetector), - }, + &GitGetter{ + Detectors: []Detector{ + new(GitHubDetector), + new(GitDetector), + new(BitBucketDetector), + new(GitLabDetector), + }, }, new(HgGetter), new(SmbClientGetter), diff --git a/get_git.go b/get_git.go index 31433912..b8afecc9 100644 --- a/get_git.go +++ b/get_git.go @@ -14,6 +14,7 @@ import ( "runtime" "strconv" "strings" + "time" urlhelper "github.com/hashicorp/go-getter/v2/helper/url" safetemp "github.com/hashicorp/go-safetemp" @@ -24,6 +25,10 @@ import ( // a git repository. type GitGetter struct { Detectors []Detector + + // Timeout sets a deadline which all hg CLI operations should + // complete within. Defaults to zero which means no timeout. + Timeout time.Duration } var defaultBranchRegexp = regexp.MustCompile(`\s->\sorigin/(.*)`) @@ -71,10 +76,16 @@ func (g *GitGetter) Get(ctx context.Context, req *Request) error { req.u.RawQuery = q.Encode() } + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + var sshKeyFile string if sshKey != "" { // Check that the git version is sufficiently new. - if err := checkGitVersion("2.3"); err != nil { + if err := checkGitVersion(ctx, "2.3"); err != nil { return fmt.Errorf("Error using ssh key: %v", err) } @@ -121,7 +132,7 @@ func (g *GitGetter) Get(ctx context.Context, req *Request) error { // Next: check out the proper tag/branch if it is specified, and checkout if ref != "" { - if err := g.checkout(req.Dst, ref); err != nil { + if err := g.checkout(ctx, req.Dst, ref); err != nil { return err } } @@ -163,8 +174,8 @@ func (g *GitGetter) GetFile(ctx context.Context, req *Request) error { return fg.GetFile(ctx, req) } -func (g *GitGetter) checkout(dst string, ref string) error { - cmd := exec.Command("git", "checkout", ref) +func (g *GitGetter) checkout(ctx context.Context, dst string, ref string) error { + cmd := exec.CommandContext(ctx, "git", "checkout", ref) cmd.Dir = dst return getRunCommand(cmd) } @@ -192,18 +203,18 @@ func (g *GitGetter) update(ctx context.Context, dst, sshKeyFile, ref string, dep // Not a branch, switch to default branch. This will also catch // non-existent branches, in which case we want to switch to default // and then checkout the proper branch later. - ref = findDefaultBranch(dst) + ref = findDefaultBranch(ctx, dst) } // We have to be on a branch to pull - if err := g.checkout(dst, ref); err != nil { + if err := g.checkout(ctx, dst, ref); err != nil { return err } if depth > 0 { - cmd = exec.Command("git", "pull", "--depth", strconv.Itoa(depth), "--ff-only") + cmd = exec.CommandContext(ctx, "git", "pull", "--depth", strconv.Itoa(depth), "--ff-only") } else { - cmd = exec.Command("git", "pull", "--ff-only") + cmd = exec.CommandContext(ctx, "git", "pull", "--ff-only") } cmd.Dir = dst @@ -226,9 +237,9 @@ func (g *GitGetter) fetchSubmodules(ctx context.Context, dst, sshKeyFile string, // findDefaultBranch checks the repo's origin remote for its default branch // (generally "master"). "master" is returned if an origin default branch // can't be determined. -func findDefaultBranch(dst string) string { +func findDefaultBranch(ctx context.Context, dst string) string { var stdoutbuf bytes.Buffer - cmd := exec.Command("git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD") + cmd := exec.CommandContext(ctx, "git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD") cmd.Dir = dst cmd.Stdout = &stdoutbuf err := cmd.Run() @@ -278,13 +289,13 @@ func setupGitEnv(cmd *exec.Cmd, sshKeyFile string) { // checkGitVersion is used to check the version of git installed on the system // against a known minimum version. Returns an error if the installed version // is older than the given minimum. -func checkGitVersion(min string) error { +func checkGitVersion(ctx context.Context, min string) error { want, err := version.NewVersion(min) if err != nil { return err } - out, err := exec.Command("git", "version").Output() + out, err := exec.CommandContext(ctx, "git", "version").Output() if err != nil { return err } diff --git a/get_git_test.go b/get_git_test.go index cde0ba68..36def9e0 100644 --- a/get_git_test.go +++ b/get_git_test.go @@ -342,12 +342,13 @@ func TestGitGetter_gitVersion(t *testing.T) { os.Setenv("PATH", dir) // Asking for a higher version throws an error - if err := checkGitVersion("2.3"); err == nil { + ctx := context.Background() + if err := checkGitVersion(ctx, "2.3"); err == nil { t.Fatal("expect git version error") } // Passes when version is satisfied - if err := checkGitVersion("1.9"); err != nil { + if err := checkGitVersion(ctx, "1.9"); err != nil { t.Fatal(err) } } @@ -411,11 +412,12 @@ func TestGitGetter_sshSCPStyle(t *testing.T) { GetMode: ModeDir, } - getter := &GitGetter{[]Detector{ - new(GitDetector), - new(BitBucketDetector), - new(GitHubDetector), - }, + getter := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(BitBucketDetector), + new(GitHubDetector), + }, } client := &Client{ Getters: []Getter{getter}, @@ -623,11 +625,12 @@ func TestGitGetter_GitHubDetector(t *testing.T) { } pwd := "/pwd" - f := &GitGetter{[]Detector{ - new(GitDetector), - new(BitBucketDetector), - new(GitHubDetector), - }, + f := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(BitBucketDetector), + new(GitHubDetector), + }, } for i, tc := range cases { req := &Request{ @@ -704,11 +707,12 @@ func TestGitGetter_Detector(t *testing.T) { } pwd := "/pwd" - getter := &GitGetter{[]Detector{ - new(GitDetector), - new(BitBucketDetector), - new(GitHubDetector), - }, + getter := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(BitBucketDetector), + new(GitHubDetector), + }, } for _, tc := range cases { t.Run(tc.Input, func(t *testing.T) { diff --git a/get_hg.go b/get_hg.go index ebad2133..f0979af0 100644 --- a/get_hg.go +++ b/get_hg.go @@ -8,6 +8,7 @@ import ( "os/exec" "path/filepath" "runtime" + "time" urlhelper "github.com/hashicorp/go-getter/v2/helper/url" safetemp "github.com/hashicorp/go-safetemp" @@ -15,7 +16,12 @@ import ( // HgGetter is a Getter implementation that will download a module from // a Mercurial repository. -type HgGetter struct{} +type HgGetter struct { + + // Timeout sets a deadline which all hg CLI operations should + // complete within. Defaults to zero which means no timeout. + Timeout time.Duration +} func (g *HgGetter) Mode(ctx context.Context, _ *url.URL) (Mode, error) { return ModeDir, nil @@ -49,13 +55,20 @@ func (g *HgGetter) Get(ctx context.Context, req *Request) error { if err != nil && !os.IsNotExist(err) { return err } + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + if err != nil { - if err := g.clone(req.Dst, newURL); err != nil { + if err := g.clone(ctx, req.Dst, newURL); err != nil { return err } } - if err := g.pull(req.Dst, newURL); err != nil { + if err := g.pull(ctx, req.Dst, newURL); err != nil { return err } @@ -102,13 +115,13 @@ func (g *HgGetter) GetFile(ctx context.Context, req *Request) error { return fg.GetFile(ctx, req) } -func (g *HgGetter) clone(dst string, u *url.URL) error { - cmd := exec.Command("hg", "clone", "-U", "--", u.String(), dst) +func (g *HgGetter) clone(ctx context.Context, dst string, u *url.URL) error { + cmd := exec.CommandContext(ctx, "hg", "clone", "-U", "--", u.String(), dst) return getRunCommand(cmd) } -func (g *HgGetter) pull(dst string, u *url.URL) error { - cmd := exec.Command("hg", "pull") +func (g *HgGetter) pull(ctx context.Context, dst string, u *url.URL) error { + cmd := exec.CommandContext(ctx, "hg", "pull") cmd.Dir = dst return getRunCommand(cmd) } diff --git a/get_hg_test.go b/get_hg_test.go index 74f0e930..c729a486 100644 --- a/get_hg_test.go +++ b/get_hg_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "strings" "testing" + "time" testing_helper "github.com/hashicorp/go-getter/v2/helper/testing" ) @@ -171,8 +172,7 @@ func TestHgGetter_HgArgumentsNotAllowed(t *testing.T) { // Other failures indicate that hg interpreted the argument passed in the URL name: "arguments passed in the repository URL", req: Request{ - u: &url.URL{Path: "--config=alias.clone=false"}, - }, + u: &url.URL{Path: "--config=alias.clone=false"}}, errChk: func(t testing.TB, err error) { if err == nil { return @@ -200,3 +200,25 @@ func TestHgGetter_HgArgumentsNotAllowed(t *testing.T) { }) } } + +func TestHgGetter_GetWithTimeout(t *testing.T) { + if !testHasHg { + t.Log("hg not found, skipping") + t.Skip() + } + ctx := context.Background() + g := &HgGetter{ + Timeout: 1 * time.Millisecond, + } + + dst := testing_helper.TempDir(t) + defer os.RemoveAll(filepath.Dir(dst)) + req := &Request{ + Dst: dst, + u: testModuleURL("basic-hg/foo.txt"), + } + + if err := g.Get(ctx, req); err == nil { + t.Fatalf("err: %s", err.Error()) + } +} diff --git a/get_test.go b/get_test.go index d3383443..0ee215aa 100644 --- a/get_test.go +++ b/get_test.go @@ -784,7 +784,7 @@ func TestGetFile_inplace_badChecksum(t *testing.T) { } } -func TestgetForcedGetter(t *testing.T) { +func TestGetForcedGetter(t *testing.T) { type args struct { src string } From 77a06dd815b70bd443ce2616d82b2c9a748e834e Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Thu, 12 May 2022 16:50:49 -0400 Subject: [PATCH 04/13] Add DisableSymlinks option to getter request The fix for this is a new client request option, DisableSymlinks. When set to true, symlinks are disabled. This prevents the client, likely in combination with the GitGetter, from following a symlink when the subdirectory selection from the checked out repo is a symlink. * Add custom symlink copy error * Add DisableSymlinks as client option Setting DisableSymlinks per request works but must be set on all request made by a client. Adding it as a top-level client config option allows for setting DisableSymlinks for all client.Get requests. --- client.go | 18 ++++++++- cmd/go-getter/main.go | 7 +++- copy_dir.go | 10 ++++- get_file_copy.go | 15 +++++++- get_git_test.go | 85 +++++++++++++++++++++++++++++++++++++++++++ get_http.go | 2 +- request.go | 4 ++ 7 files changed, 134 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index 932d351b..be9ff3d5 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package getter import ( "context" + "errors" "fmt" "io/ioutil" "os" @@ -14,6 +15,9 @@ import ( safetemp "github.com/hashicorp/go-safetemp" ) +// ErrSymlinkCopy means that a copy of a symlink was encountered on a request with DisableSymlinks enabled. +var ErrSymlinkCopy = errors.New("copying of symlinks has been disabled") + // Client is a client for downloading things. // // Top-level functions such as Get are shortcuts for interacting with a client. @@ -27,6 +31,10 @@ type Client struct { // Getters is the list of protocols supported by this client. If this // is nil, then the default Getters variable will be used. Getters []Getter + + // Disable symlinks is used to prevent copying or writing files through symlinks for Get requests. + // When set to true any copying or writing through symlinks will result in a ErrSymlinkCopy error. + DisableSymlinks bool } // GetResult is the result of a Client.Get @@ -46,9 +54,15 @@ func (c *Client) Get(ctx context.Context, req *Request) (*GetResult, error) { req.GetMode = ModeAny } + // Client setting takes precedence for all requests + if c.DisableSymlinks { + req.DisableSymlinks = true + } + // If there is a subdir component, then we download the root separately // and then copy over the proper subdir. req.Src, req.subDir = SourceDirSubdir(req.Src) + if req.subDir != "" { // Check if the subdirectory is attempting to traverse upwards, outside of // the cloned repository path. @@ -135,7 +149,7 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, * // Determine if we have an archive type archiveV := q.Get("archive") if archiveV != "" { - // Delete the paramter since it is a magic parameter we don't + // Delete the parameter since it is a magic parameter we don't // want to pass on to the Getter q.Del("archive") req.u.RawQuery = q.Encode() @@ -300,7 +314,7 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, * return nil, &getError{true, err} } - err = copyDir(ctx, req.realDst, subDir, false, req.umask()) + err = copyDir(ctx, req.realDst, subDir, false, req.DisableSymlinks, req.umask()) if err != nil { return nil, &getError{false, err} } diff --git a/cmd/go-getter/main.go b/cmd/go-getter/main.go index cef8f714..85dede6e 100644 --- a/cmd/go-getter/main.go +++ b/cmd/go-getter/main.go @@ -16,6 +16,7 @@ import ( func main() { modeRaw := flag.String("mode", "any", "get mode (any, file, dir)") progress := flag.Bool("progress", false, "display terminal progress") + noSymlinks := flag.Bool("disable-symlinks", false, "prevent copying or writing files through symlinks") flag.Parse() args := flag.Args() if len(args) < 2 { @@ -54,12 +55,16 @@ func main() { if *progress { req.ProgressListener = defaultProgressBar } - wg := sync.WaitGroup{} wg.Add(1) client := getter.DefaultClient + // Disable symlinks for all client requests + if *noSymlinks { + client.DisableSymlinks = true + } + getters := getter.Getters getters = append(getters, new(gcs.Getter)) getters = append(getters, new(s3.Getter)) diff --git a/copy_dir.go b/copy_dir.go index cb1abb06..c4675148 100644 --- a/copy_dir.go +++ b/copy_dir.go @@ -11,7 +11,7 @@ import ( // should already exist. // // If ignoreDot is set to true, then dot-prefixed files/folders are ignored. -func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask os.FileMode) error { +func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, disableSymlinks bool, umask os.FileMode) error { src, err := filepath.EvalSymlinks(src) if err != nil { return err @@ -34,6 +34,12 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask } } + if disableSymlinks { + if info.Mode()&os.ModeSymlink == os.ModeSymlink { + return ErrSymlinkCopy + } + } + // The "path" has the src prefixed to it. We need to join our // destination with the path without the src on it. dstPath := filepath.Join(dst, path[len(src):]) @@ -54,7 +60,7 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask } // If we have a file, copy the contents. - _, err = copyFile(ctx, dstPath, path, info.Mode(), umask) + _, err = copyFile(ctx, dstPath, path, disableSymlinks, info.Mode(), umask) return err } diff --git a/get_file_copy.go b/get_file_copy.go index 29abbd1a..be44ff35 100644 --- a/get_file_copy.go +++ b/get_file_copy.go @@ -2,6 +2,7 @@ package getter import ( "context" + "fmt" "io" "os" ) @@ -49,7 +50,19 @@ func copyReader(dst string, src io.Reader, fmode, umask os.FileMode) error { } // copyFile copies a file in chunks from src path to dst path, using umask to create the dst file -func copyFile(ctx context.Context, dst, src string, fmode, umask os.FileMode) (int64, error) { +func copyFile(ctx context.Context, dst, src string, disableSymlinks bool, fmode, umask os.FileMode) (int64, error) { + + if disableSymlinks { + fileInfo, err := os.Lstat(src) + if err != nil { + return 0, fmt.Errorf("failed to check copy file source for symlinks: %w", err) + } + + if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink { + return 0, ErrSymlinkCopy + } + } + srcF, err := os.Open(src) if err != nil { return 0, err diff --git a/get_git_test.go b/get_git_test.go index 36def9e0..acd561d8 100644 --- a/get_git_test.go +++ b/get_git_test.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "encoding/base64" + "errors" + "fmt" "io/ioutil" "net/url" "os" @@ -734,6 +736,89 @@ func TestGitGetter_Detector(t *testing.T) { } } +func TestGitGetter_subdirectory_symlink(t *testing.T) { + dst := testing_helper.TempDir(t) + + repo := testGitRepo(t, "repo-with-symlink") + innerDir := filepath.Join(repo.dir, "this-directory-contains-a-symlink") + if err := os.Mkdir(innerDir, 0700); err != nil { + t.Fatal(err) + } + path := filepath.Join(innerDir, "this-is-a-symlink") + if err := os.Symlink("/etc/passwd", path); err != nil { + t.Fatal(err) + } + repo.git("add", path) + repo.git("commit", "-m", "Adding "+path) + + u, err := url.Parse(fmt.Sprintf("git::%s//this-directory-contains-a-symlink", repo.url.String())) + if err != nil { + t.Fatal(err) + } + + req := &Request{ + Src: u.String(), + Dst: dst, + Pwd: ".", + GetMode: ModeDir, + } + getter := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(GitHubDetector), + }, + } + client := &Client{ + Getters: []Getter{getter}, + DisableSymlinks: true, + } + + ctx := context.Background() + _, err = client.Get(ctx, req) + if err == nil { + t.Fatalf("expected client get to fail") + } + if !errors.Is(err, ErrSymlinkCopy) { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestGitGetter_subdirectory_traversal(t *testing.T) { + dst := testing_helper.TempDir(t) + + repo := testGitRepo(t, "empty-repo") + u, err := url.Parse(fmt.Sprintf("git::%s//../../../../../../etc/passwd", repo.url.String())) + if err != nil { + t.Fatal(err) + } + + req := &Request{ + Src: u.String(), + Dst: dst, + Pwd: ".", + GetMode: ModeDir, + } + + getter := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(GitHubDetector), + }, + } + client := &Client{ + Getters: []Getter{getter}, + } + + ctx := context.Background() + _, err = client.Get(ctx, req) + if err == nil { + t.Fatalf("expected client get to fail") + } + if !strings.Contains(err.Error(), "subdirectory component contain path traversal out of the repository") { + t.Fatalf("unexpected error: %v", err) + } +} + // gitRepo is a helper struct which controls a single temp git repo. type gitRepo struct { t *testing.T diff --git a/get_http.go b/get_http.go index 261db485..97a228be 100644 --- a/get_http.go +++ b/get_http.go @@ -256,7 +256,7 @@ func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir return err } - return copyDir(ctx, req.Dst, sourcePath, false, req.umask()) + return copyDir(ctx, req.Dst, sourcePath, false, req.DisableSymlinks, req.umask()) } // parseMeta looks for the first meta tag in the given reader that diff --git a/request.go b/request.go index 73e1dfeb..5e0ee6e9 100644 --- a/request.go +++ b/request.go @@ -58,6 +58,10 @@ type Request struct { // By default a no op progress listener is used. ProgressListener ProgressTracker + // Disable symlinks is used to prevent copying or writing files through symlinks. + // When set to true any copying or writing through symlinks will result in a ErrSymlinkCopy error. + DisableSymlinks bool + u *url.URL subDir, realDst string } From af532093446cf1a98eaa5aebc65d06293a512ebe Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Fri, 13 May 2022 16:31:22 -0400 Subject: [PATCH 05/13] Update get_http to address various get concerns * Add XTerraformGetLimit and XTerraformGetDisabled * Add Multiple new options to limit resource consumption: DoNotCheckHeadFirst, HeadFirstTimeout, ReadTimeout, MaxBytes * Add getter client to context for reuse * Add setters/getters for storing configured getter.Client in a context * Update HttpGetter to use ClientFromContext when available; otherwise use a limited client for supporting X-Terraform-Get request * Refactor HttpGetter function to make it clear when a configured getter.Client is required * Add security section to README --- README.md | 114 +++++++-- client.go | 3 + client_option.go | 21 ++ get_http.go | 352 ++++++++++++++++++++++---- get_http_test.go | 626 ++++++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 1024 insertions(+), 92 deletions(-) diff --git a/README.md b/README.md index 3f763c50..5309bd38 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,8 @@ URLs. For example: "github.com/hashicorp/go-getter" would turn into a Git URL. Or "./foo" would turn into a file URL. These are extensible. This library is used by [Terraform](https://terraform.io) for -downloading modules and [Nomad](https://nomadproject.io) for downloading -binaries. +downloading modules, [Packer](https://packer.io) for downloading binaries, and +[Nomad](https://nomadproject.io) for downloading binaries. ## Installation and Usage @@ -47,6 +47,16 @@ $ go-getter github.com/foo/bar ./foo The command is useful for verifying URL structures. +## Security +Fetching resources from user-supplied URLs is an inherently dangerous operation and may +leave your application vulnerable to [server side request forgery](https://owasp.org/www-community/attacks/Server_Side_Request_Forgery), +[path traversal](https://owasp.org/www-community/attacks/Path_Traversal), [denial of service](https://owasp.org/www-community/attacks/Denial_of_Service) +or other security flaws. + +go-getter contains mitigations for some of these security issues, but should still be used with +caution in security-critical contexts. See the available [security options](#Security-Options) that +can be configured to mitigate some of these risks. + ## URL Format go-getter uses a single string URL as input to download from a variety of @@ -83,7 +93,7 @@ is built-in by default: file URLs. * GitHub URLs, such as "github.com/mitchellh/vagrant" are automatically changed to Git protocol over HTTP. - * GitLab URLs, such as "gitlab.com/inkscape/inkscape" are automatically + * GitLab URLs, such as "gitlab.com/inkscape/inkscape" are automatically changed to Git protocol over HTTP. * BitBucket URLs, such as "bitbucket.org/mitchellh/vagrant" are automatically changed to a Git or mercurial protocol using the BitBucket API. @@ -178,7 +188,7 @@ checksum string. Examples: ``` ./foo.txt?checksum=file:./foo.txt.sha256sum ``` - + When checksumming from a file - ex: with `checksum=file:url` - go-getter will get the file linked in the URL after `file:` using the same configuration. For example, in `file:http://releases.ubuntu.com/cosmic/MD5SUMS` go-getter will @@ -279,7 +289,7 @@ None from a private key file on disk, you would run `base64 -w0 `. **Note**: Git 2.3+ is required to use this feature. - + * `depth` - The Git clone depth. The provided number specifies the last `n` revisions to clone from the repository. @@ -374,21 +384,21 @@ files from a smb shared folder whenever the url is prefixed with `smb://`. ⚠️ The [`smbclient`](https://www.samba.org/samba/docs/current/man-html/smbclient.1.html) command is available only for Linux. This is the ONLY option for a Linux user and therefore the client must be installed. - + The `smbclient` cli is not available for Windows and MacOS. The go-getter will try to get files using the file system, when this happens the getter uses the FileGetter implementation. -When connecting to a smb server, the OS creates a local mount in a system specific volume folder, and go-getter will +When connecting to a smb server, the OS creates a local mount in a system specific volume folder, and go-getter will try to access the following folders when looking for local mounts. - MacOS: /Volumes/ - Windows: \\\\\\\\ -The following examples work for all the OSes: +The following examples work for all the OSes: - smb://host/shared/dir (downloads directory content) -- smb://host/shared/dir/file (downloads file) +- smb://host/shared/dir/file (downloads file) -The following examples work for Linux: +The following examples work for Linux: - smb://username:password@host/shared/dir (downloads directory content) - smb://username@host/shared/dir - smb://username:password@host/shared/dir/file (downloads file) @@ -396,13 +406,85 @@ The following examples work for Linux: ⚠️ The above examples also work on the other OSes but the authentication is not used to access the file system. - - + + #### SMB Testing The test for `get_smb.go` requires a smb server running which can be started inside a docker container by -running `make start-smb`. Once the container is up the shared folder can be accessed via `smb:///public/` or -`smb://user:password@/private/` by another container or machine in the same network. +running `make start-smb`. Once the container is up the shared folder can be accessed via `smb:///public/` or +`smb://user:password@/private/` by another container or machine in the same network. -To run the tests inside `get_smb_test.go` and `client_test.go`, prepare the environment with `make smbtests-prepare`. On prepare some +To run the tests inside `get_smb_test.go` and `client_test.go`, prepare the environment with `make smbtests-prepare`. On prepare some mock files and directories will be added to the shared folder and a go-getter container will start together with the samba server. -Once the environment for testing is prepared, run `make smbtests` to run the tests. \ No newline at end of file +Once the environment for testing is prepared, run `make smbtests` to run the tests. + +### Security Options + +**Disable Symlinks** + +In your getter client config, we recommend using the `DisableSymlinks` option, +which prevents writing through or copying from symlinks (which may point outside the directory). + +```go +client := getter.Client{ + // This will prevent copying or writing files through symlinks + DisableSymlinks: true, +} +``` + +**Disable or Limit `X-Terraform-Get`** + +Go-Getter supports arbitrary redirects via the `X-Terraform-Get` header. This functionality +exists to support [Terraform use cases](https://www.terraform.io/language/modules/sources#http-urls), +but is likely not needed in most applications. + +For code that uses the `HttpGetter`, add the following configuration options: + +```go +var httpGetter = &getter.HttpGetter{ + // Most clients should disable X-Terraform-Get + // See the note below + XTerraformGetDisabled: true, + // Your software probably doesn’t rely on X-Terraform-Get, but + // if it does, you should set the above field to false, plus + // set XTerraformGet Limit to prevent endless redirects + // XTerraformGetLimit: 10, +} +``` + +**Enforce Timeouts** + +The `HttpGetter` supports timeouts and other resource-constraining configuration options. The `GitGetter` and `HgGetter` +only support timeouts. + +Configuration for the `HttpGetter`: + +```go +var httpGetter = &getter.HttpGetter{ + // Disable pre-fetch HEAD requests + DoNotCheckHeadFirst: true, + + // As an alternative to the above setting, you can + // set a reasonable timeout for HEAD requests + // HeadFirstTimeout: 10 * time.Second, + // Read timeout for HTTP operations + ReadTimeout: 30 * time.Second, + // Set the maximum number of bytes + // that can be read by the getter + MaxBytes: 500000000, // 500 MB +} +``` + +For code that uses the `GitGetter` or `HgGetter`, set the `Timeout` option: +```go +var gitGetter = &getter.GitGetter{ + // Set a reasonable timeout for git operations + Timeout: 5 * time.Minute, +} +``` + +```go +var hgGetter = &getter.HgGetter{ + // Set a reasonable timeout for hg operations + Timeout: 5 * time.Minute, +} +``` diff --git a/client.go b/client.go index be9ff3d5..3aa5dd1d 100644 --- a/client.go +++ b/client.go @@ -49,6 +49,9 @@ func (c *Client) Get(ctx context.Context, req *Request) (*GetResult, error) { return nil, err } + // Pass along the configured Getter client in the context for usage with the X-Terraform-Get feature. + ctx = NewContextWithClient(ctx, c) + // Store this locally since there are cases we swap this if req.GetMode == ModeInvalid { req.GetMode = ModeAny diff --git a/client_option.go b/client_option.go index c17f73fe..77d13cdc 100644 --- a/client_option.go +++ b/client_option.go @@ -1,5 +1,26 @@ package getter +import ( + "context" +) + +type clientContextKey int + +const clientContextValue clientContextKey = 0 + +func NewContextWithClient(ctx context.Context, client *Client) context.Context { + return context.WithValue(ctx, clientContextValue, client) +} + +func ClientFromContext(ctx context.Context) *Client { + // ctx.Value returns nil if ctx has no value for the key; + client, ok := ctx.Value(clientContextValue).(*Client) + if !ok { + return nil + } + return client +} + // configure configures a client with options. func (c *Client) configure() error { // Default decompressor values diff --git a/get_http.go b/get_http.go index 97a228be..5a9658dd 100644 --- a/get_http.go +++ b/get_http.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "strings" + "time" safetemp "github.com/hashicorp/go-safetemp" ) @@ -26,7 +27,9 @@ import ( // wish. The response must be a 2xx. // // First, a header is looked for "X-Terraform-Get" which should contain -// a source URL to download. +// a source URL to download. This source must use one of the configured +// protocols and getters for the client, or "http"/"https" if using +// the HttpGetter directly. // // If the header is not present, then a meta tag is searched for named // "terraform-get" and the content should be a source URL. @@ -49,6 +52,35 @@ type HttpGetter struct { // and as such it needs to be initialized before use, via something like // make(http.Header). Header http.Header + // DoNotCheckHeadFirst configures the client to NOT check if the server + // supports HEAD requests. + DoNotCheckHeadFirst bool + + // HeadFirstTimeout configures the client to enforce a timeout when + // the server supports HEAD requests. + // + // The zero value means no timeout. + HeadFirstTimeout time.Duration + + // ReadTimeout configures the client to enforce a timeout when + // making a request to an HTTP server and reading its response body. + // + // The zero value means no timeout. + ReadTimeout time.Duration + + // MaxBytes limits the number of bytes that will be ready from an HTTP + // response body returned from a server. The zero value means no limit. + MaxBytes int64 + + // XTerraformGetLimit configures how many times the client with follow + // the " X-Terraform-Get" header value. + // + // The zero value means no limit. + XTerraformGetLimit int + + // XTerraformGetDisabled disables the client's usage of the "X-Terraform-Get" + // header value. + XTerraformGetDisabled bool } func (g *HttpGetter) Mode(ctx context.Context, u *url.URL) (Mode, error) { @@ -58,7 +90,112 @@ func (g *HttpGetter) Mode(ctx context.Context, u *url.URL) (Mode, error) { return ModeFile, nil } +type contextKey int + +const ( + xTerraformGetDisable contextKey = 0 + xTerraformGetLimit contextKey = 1 + xTerraformGetLimitCurrentValue contextKey = 2 + httpClientValue contextKey = 3 + httpMaxBytesValue contextKey = 4 +) + +func xTerraformGetDisabled(ctx context.Context) bool { + value, ok := ctx.Value(xTerraformGetDisable).(bool) + if !ok { + return false + } + return value +} + +func xTerraformGetLimitCurrentValueFromContext(ctx context.Context) int { + value, ok := ctx.Value(xTerraformGetLimitCurrentValue).(int) + if !ok { + return 1 + } + return value +} + +func xTerraformGetLimiConfiguredtFromContext(ctx context.Context) int { + value, ok := ctx.Value(xTerraformGetLimit).(int) + if !ok { + return 0 + } + return value +} + +func httpClientFromContext(ctx context.Context) *http.Client { + value, ok := ctx.Value(httpClientValue).(*http.Client) + if !ok { + return nil + } + return value +} + +func httpMaxBytesFromContext(ctx context.Context) int64 { + value, ok := ctx.Value(httpMaxBytesValue).(int64) + if !ok { + return 0 // no limit + } + return value +} + +type limitedWrappedReaderCloser struct { + underlying io.Reader + closeFn func() error +} + +func (l *limitedWrappedReaderCloser) Read(p []byte) (n int, err error) { + return l.underlying.Read(p) +} + +func (l *limitedWrappedReaderCloser) Close() (err error) { + return l.closeFn() +} + +func newLimitedWrappedReaderCloser(r io.ReadCloser, limit int64) io.ReadCloser { + return &limitedWrappedReaderCloser{ + underlying: io.LimitReader(r, limit), + closeFn: r.Close, + } +} + func (g *HttpGetter) Get(ctx context.Context, req *Request) error { + // Optionally disable any X-Terraform-Get redirects. This is recommended for usage of + // this client outside of Terraform's. This feature is likely not required if the + // source server can provider normal HTTP redirects. + if g.XTerraformGetDisabled { + ctx = context.WithValue(ctx, xTerraformGetDisable, g.XTerraformGetDisabled) + } + + // Optionally enforce a limit on X-Terraform-Get redirects. We check this for every + // invocation of this function, because the value is not passed down to subsequent + // client Get function invocations. + if g.XTerraformGetLimit > 0 { + ctx = context.WithValue(ctx, xTerraformGetLimit, g.XTerraformGetLimit) + } + + // If there was a limit on X-Terraform-Get redirects, check what the current count value. + // + // If the value is greater than the limit, return an error. Otherwise, increment the value, + // and include it in the the context to be passed along in all the subsequent client + // Get function invocations. + if limit := xTerraformGetLimiConfiguredtFromContext(ctx); limit > 0 { + currentValue := xTerraformGetLimitCurrentValueFromContext(ctx) + + if currentValue > limit { + return fmt.Errorf("too many X-Terraform-Get redirects: %d", currentValue) + } + + currentValue++ + + ctx = context.WithValue(ctx, xTerraformGetLimitCurrentValue, currentValue) + } + + // Optionally enforce a maxiumum HTTP response body size. + if g.MaxBytes > 0 { + ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes) + } // Copy the URL so we can modify it var newU url.URL = *req.u req.u = &newU @@ -70,17 +207,33 @@ func (g *HttpGetter) Get(ctx context.Context, req *Request) error { } } + // If the HTTP client is nil, check if there is one available in the context, + // otherwise create one using cleanhttp's default transport. if g.Client == nil { - g.Client = httpClient + if client := httpClientFromContext(ctx); client != nil { + g.Client = client + } else { + g.Client = httpClient + } } + // Pass along the configured HTTP client in the context for usage with the X-Terraform-Get feature. + ctx = context.WithValue(ctx, httpClientValue, g.Client) + // Add terraform-get to the parameter. q := req.u.Query() q.Add("terraform-get", "1") req.u.RawQuery = q.Encode() + readCtx := ctx + if g.ReadTimeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout) + defer cancel() + } + // Get the URL - httpReq, err := http.NewRequestWithContext(ctx, "GET", req.u.String(), nil) + httpReq, err := http.NewRequestWithContext(readCtx, "GET", req.u.String(), nil) if err != nil { return err } @@ -92,40 +245,53 @@ func (g *HttpGetter) Get(ctx context.Context, req *Request) error { if err != nil { return err } - defer resp.Body.Close() + + body := resp.Body + if maxBytes := httpMaxBytesFromContext(ctx); maxBytes > 0 { + body = newLimitedWrappedReaderCloser(body, maxBytes) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { return fmt.Errorf("bad response code: %d", resp.StatusCode) } + if disabled := xTerraformGetDisabled(ctx); disabled { + return nil + } + + // Get client with configured Getters from the context + // If the client is nil, we know we're using the HttpGetter directly. In this case, + // we don't know exactly which protocols are configured, but we can make a good guess. + // + // This prevents all default getters from being allowed when only using the + // HttpGetter directly. To enable protocol switching, a client "wrapper" must + // be used. + var getterClient *Client + if v := ClientFromContext(ctx); v != nil { + getterClient = v + } else { + getterClient = &Client{ + Getters: []Getter{g}, + } + } + // Extract the source URL var source string if v := resp.Header.Get("X-Terraform-Get"); v != "" { source = v } else { - source, err = g.parseMeta(resp.Body) + source, err = g.parseMeta(readCtx, body) if err != nil { return err } } + if source == "" { return fmt.Errorf("no source URL was returned") } - // If there is a subdir component, then we download the root separately - // into a temporary directory, then copy over the proper subdir. - source, subDir := SourceDirSubdir(source) - req = &Request{ - GetMode: ModeDir, - Src: source, - Dst: req.Dst, - } - if subDir == "" { - _, err = DefaultClient.Get(ctx, req) - return err - } - // We have a subdir, time to jump some hoops - return g.getSubdir(ctx, req, source, subDir) + return g.getXTerraformSource(ctx, req, source, getterClient) } // GetFile fetches the file from src and stores it at dst. @@ -135,6 +301,11 @@ func (g *HttpGetter) Get(ctx context.Context, req *Request) error { // falsely identified as being replaced, or corrupted with extra bytes // appended. func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error { + // Optionally enforce a maxiumum HTTP response body size. + if g.MaxBytes > 0 { + ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes) + } + if g.Netrc { // Add auth from netrc if we can if err := addAuthFromNetrc(req.u); err != nil { @@ -157,38 +328,67 @@ func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error { } var currentFileSize int64 + var httpReq *http.Request + + if g.DoNotCheckHeadFirst == false { + headCtx := ctx - // We first make a HEAD request so we can check - // if the server supports range queries. If the server/URL doesn't - // support HEAD requests, we just fall back to GET. - httpReq, err := http.NewRequestWithContext(ctx, "HEAD", req.u.String(), nil) + if g.HeadFirstTimeout > 0 { + var cancel context.CancelFunc + + headCtx, cancel = context.WithTimeout(ctx, g.HeadFirstTimeout) + defer cancel() + } + + // We first make a HEAD request so we can check + // if the server supports range queries. If the server/URL doesn't + // support HEAD requests, we just fall back to GET. + httpReq, err = http.NewRequestWithContext(headCtx, "HEAD", req.u.String(), nil) + if err != nil { + return err + } + if g.Header != nil { + httpReq.Header = g.Header.Clone() + } + headResp, err := g.Client.Do(httpReq) + if err == nil { + headResp.Body.Close() + if headResp.StatusCode == 200 { + // If the HEAD request succeeded, then attempt to set the range + // query if we can. + if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 { + if fi, err := f.Stat(); err == nil { + if _, err = f.Seek(0, io.SeekEnd); err == nil { + currentFileSize = fi.Size() + httpReq.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) + if currentFileSize >= headResp.ContentLength { + // file already present + return nil + } + } + } + } + } + } + } + + readCtx := ctx + if g.ReadTimeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout) + defer cancel() + } + + httpReq, err = http.NewRequestWithContext(readCtx, "GET", req.u.String(), nil) if err != nil { return err } if g.Header != nil { httpReq.Header = g.Header.Clone() } - headResp, err := g.Client.Do(httpReq) - if err == nil { - headResp.Body.Close() - if headResp.StatusCode == 200 { - // If the HEAD request succeeded, then attempt to set the range - // query if we can. - if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 { - if fi, err := f.Stat(); err == nil { - if _, err = f.Seek(0, io.SeekEnd); err == nil { - currentFileSize = fi.Size() - httpReq.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) - if currentFileSize >= headResp.ContentLength { - // file already present - return nil - } - } - } - } - } + if currentFileSize > 0 { + httpReq.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) } - httpReq.Method = "GET" resp, err := g.Client.Do(httpReq) if err != nil { @@ -204,6 +404,10 @@ func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error { body := resp.Body + if maxBytes := httpMaxBytesFromContext(readCtx); maxBytes > 0 { + body = newLimitedWrappedReaderCloser(body, maxBytes) + } + if req.ProgressListener != nil { // track download fn := filepath.Base(req.u.EscapedPath()) @@ -211,16 +415,59 @@ func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error { } defer body.Close() - n, err := Copy(ctx, f, body) + n, err := Copy(readCtx, f, body) if err == nil && n < resp.ContentLength { err = io.ErrShortWrite } return err } +// getXTerraformSource downloads the source into the destination +// using a protocol switching capable client. +func (g *HttpGetter) getXTerraformSource(ctx context.Context, req *Request, source string, client *Client) error { + + // If there is a subdir component, then we download the root separately + // into a temporary directory, then copy over the proper subdir. + source, subDir := SourceDirSubdir(source) + req = &Request{ + GetMode: ModeDir, + Src: source, + Dst: req.Dst, + DisableSymlinks: req.DisableSymlinks, + } + + if subDir == "" { + // We have a X-Terraform-Get source lets check for supported Getters + var allowed bool + for _, getter := range client.Getters { + shouldDownload, err := Detect(req, getter) + if err != nil { + return fmt.Errorf("failed to detect the proper Getter to handle %s: %w", source, err) + } + if !shouldDownload { + // the request should not be processed by that getter + continue + } + allowed = true + } + + if !allowed { + protocol := strings.Split(source, ":")[0] + return fmt.Errorf("download not supported for scheme %q", protocol) + } + + _, err := client.Get(ctx, req) + return err + } + + // We have a subdir, time to jump some hoops + return g.getSubdir(ctx, req, source, subDir, client) + +} + // getSubdir downloads the source into the destination, but with // the proper subdir. -func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir string) error { +func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir string, client *Client) error { // Create a temporary directory to store the full source. This has to be // a non-existent directory. td, tdcloser, err := safetemp.Dir("", "getter") @@ -229,8 +476,13 @@ func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir } defer tdcloser.Close() - // Download that into the given directory - if _, err := Get(ctx, td, source); err != nil { + tdReq := &Request{ + Src: source, + Dst: td, + GetMode: ModeDir, + DisableSymlinks: req.DisableSymlinks, + } + if _, err := client.Get(ctx, tdReq); err != nil { return err } @@ -261,13 +513,17 @@ func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir // parseMeta looks for the first meta tag in the given reader that // will give us the source URL. -func (g *HttpGetter) parseMeta(r io.Reader) (string, error) { +func (g *HttpGetter) parseMeta(ctx context.Context, r io.Reader) (string, error) { d := xml.NewDecoder(r) d.CharsetReader = charsetReader d.Strict = false var err error var t xml.Token for { + if ctx.Err() != nil { + return "", fmt.Errorf("context error while parsing meta tag: %w", ctx.Err()) + } + t, err = d.Token() if err != nil { if err == io.EOF { diff --git a/get_http_test.go b/get_http_test.go index b6991411..0ecc259a 100644 --- a/get_http_test.go +++ b/get_http_test.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "net" "net/http" + "net/http/httputil" "net/url" "os" "path/filepath" @@ -16,6 +17,7 @@ import ( "strings" "testing" + cleanhttp "github.com/hashicorp/go-cleanhttp" testing_helper "github.com/hashicorp/go-getter/v2/helper/testing" ) @@ -38,12 +40,26 @@ func TestHttpGetter_header(t *testing.T) { u.Path = "/header" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -103,12 +119,26 @@ func TestHttpGetter_meta(t *testing.T) { u.Path = "/meta" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -134,12 +164,26 @@ func TestHttpGetter_metaSubdir(t *testing.T) { u.Path = "/meta-subdir" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "error downloading") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -165,12 +209,26 @@ func TestHttpGetter_metaSubdirGlob(t *testing.T) { u.Path = "/meta-subdir-glob" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "error downloading") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -361,12 +419,26 @@ func TestHttpGetter_auth(t *testing.T) { u.User = url.UserPassword("foo", "bar") req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -397,12 +469,26 @@ func TestHttpGetter_authNetrc(t *testing.T) { defer tempEnv(t, "NETRC", path)() req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -442,14 +528,29 @@ func TestHttpGetter_cleanhttp(t *testing.T) { u.Path = "/header" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } + } func TestHttpGetter__RespectsContextCanceled(t *testing.T) { @@ -491,6 +592,429 @@ func TestHttpGetter__RespectsContextCanceled(t *testing.T) { } } +func TestHttpGetter__XTerraformGetLimit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetLoop(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/loop" + + dst := testing_helper.TempDir(t) + defer os.RemoveAll(dst) + + g := new(HttpGetter) + g.XTerraformGetLimit = 10 + g.Client = &http.Client{} + + req := Request{ + Dst: dst, + u: &u, + GetMode: ModeDir, + } + + err := g.Get(ctx, &req) + if !strings.Contains(err.Error(), "too many X-Terraform-Get redirects") { + t.Fatalf("too many X-Terraform-Get redirects, got: %v", err) + } +} + +func TestHttpGetter__XTerraformGetDisabled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetLoop(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/loop" + dst := testing_helper.TempDir(t) + + g := new(HttpGetter) + g.XTerraformGetDisabled = true + g.Client = &http.Client{} + + req := Request{ + Dst: dst, + u: &u, + GetMode: ModeDir, + } + + err := g.Get(ctx, &req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} +func TestHttpGetter__XTerraformGetProxyBypass(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetProxyBypass(t) + + proxyLn := testHttpServerProxy(t, ln.Addr().String()) + + t.Logf("starting malicious server on: %v", ln.Addr().String()) + t.Logf("starting proxy on: %v", proxyLn.Addr().String()) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/start" + dst := testing_helper.TempDir(t) + + proxy, err := url.Parse(fmt.Sprintf("http://%s/", proxyLn.Addr().String())) + if err != nil { + t.Fatalf("failed to parse proxy URL: %v", err) + } + + transport := cleanhttp.DefaultTransport() + transport.Proxy = http.ProxyURL(proxy) + + g := new(HttpGetter) + g.XTerraformGetLimit = 10 + g.Client = &http.Client{ + Transport: transport, + } + + client := &Client{ + Getters: []Getter{g}, + } + + req := Request{ + Dst: dst, + Src: u.String(), + } + + _, err = client.Get(ctx, &req) + if err != nil { + t.Logf("client get error: %v", err) + } +} + +func TestHttpGetter__XTerraformGetConfiguredGettersBypass(t *testing.T) { + tc := []struct { + name string + configuredGetters []Getter + errExpected bool + }{ + {name: "configured getter for git protocol switch", configuredGetters: []Getter{new(GitGetter)}, errExpected: false}, + {name: "configured getter for multiple protocol switch", configuredGetters: []Getter{new(HgGetter), new(GitGetter), new(FileGetter)}, errExpected: false}, + {name: "configured getter for file protocol switch", configuredGetters: []Getter{new(FileGetter)}, errExpected: true}, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetConfiguredGettersBypass(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/start" + + for _, tt := range tc { + tt := tt + t.Run(tt.name, func(t *testing.T) { + dst := testing_helper.TempDir(t) + + rt := hookableHTTPRoundTripper{ + before: func(req *http.Request) { + t.Logf("making request") + }, + RoundTripper: http.DefaultTransport, + } + + g := new(HttpGetter) + g.XTerraformGetLimit = 10 + g.Client = &http.Client{ + Transport: &rt, + } + + client := &Client{ + Getters: []Getter{g}, + } + client.Getters = append(client.Getters, tt.configuredGetters...) + + t.Logf("%v", u.String()) + + req := Request{ + Dst: dst, + Src: u.String(), + GetMode: ModeDir, + } + + _, err := client.Get(ctx, &req) + if tt.errExpected && err == nil { + t.Fatalf("error expected") + } + if err != nil { + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("expected download not supported for scheme, got: %v", err) + } + } + }) + } +} + +func TestHttpGetter__endless_body(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithEndlessBody(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/" + dst := testing_helper.TempDir(t) + + g := new(HttpGetter) + g.MaxBytes = 10 + g.DoNotCheckHeadFirst = true + + client := &Client{ + Getters: []Getter{g}, + } + + t.Logf("%v", u.String()) + + req := Request{ + Dst: dst, + Src: u.String(), + GetMode: ModeFile, + } + + _, err := client.Get(ctx, &req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestHttpGetter_subdirLink(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerSubDir(t) + defer ln.Close() + + dst, err := ioutil.TempDir("", "tf") + if err != nil { + t.Fatalf("err: %s", err) + } + + t.Logf("dst: %q", dst) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/regular-subdir//meta-subdir" + + g := new(HttpGetter) + client := &Client{ + Getters: []Getter{g}, + } + + t.Logf("url: %q", u.String()) + + req := Request{ + Dst: dst, + Src: u.String(), + GetMode: ModeAny, + } + + _, err = client.Get(ctx, &req) + if err != nil { + t.Fatalf("get err: %v", err) + } +} + +func testHttpServerWithXTerraformGetLoop(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("http://%v:%v", ln.Addr().String(), "/loop") + + mux := http.NewServeMux() + mux.HandleFunc("/loop", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving loop") + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func testHttpServerWithXTerraformGetProxyBypass(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("http://%v/bypass", ln.Addr().String()) + + mux := http.NewServeMux() + mux.HandleFunc("/start/start", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving start") + }) + + mux.HandleFunc("/bypass", func(w http.ResponseWriter, r *http.Request) { + t.Fail() + t.Logf("bypassed proxy") + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving HTTP server path: %v", r.URL.Path) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func testHttpServerWithXTerraformGetConfiguredGettersBypass(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("git::http://%v/some/repository.git", ln.Addr().String()) + + mux := http.NewServeMux() + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving start") + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving git HTTP server path: %v", r.URL.Path) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func TestHttpGetter_XTerraformWithClientFromContext(t *testing.T) { + tc := []struct { + name string + client *Client + errExpected bool + }{ + { + name: "default getters", + client: &Client{ + Getters: Getters, + }, + errExpected: false, + }, + { + name: "client configured with needed getters", + client: &Client{ + Getters: []Getter{ + new(HttpGetter), + new(FileGetter), + }, + }, + errExpected: false, + }, + { + name: "nil client", + errExpected: true, + }, + } + + for _, tt := range tc { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ln := testHttpServer(t) + defer ln.Close() + ctx := context.Background() + + g := new(HttpGetter) + dst := testing_helper.TempDir(t) + defer os.RemoveAll(dst) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/header" + + req := &Request{ + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, + } + + // Using a client stored in the ctx with a file getter should work + ctx = NewContextWithClient(ctx, tt.client) + + err := g.Get(ctx, req) + if tt.errExpected && err == nil { + t.Fatalf("error expected") + } + + if err != nil { + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("expected download not supported for scheme, got: %v", err) + } + return + } + + // Verify the main file exists + mainPath := filepath.Join(dst, "main.tf") + if _, err := os.Stat(mainPath); err != nil { + t.Fatalf("err: %s", err) + } + }) + } +} + +func testHttpServerProxy(t *testing.T, upstreamHost string) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving proxy: %v: %#+v", r.URL.Path, r.Header) + // create the reverse proxy + proxy := httputil.NewSingleHostReverseProxy(r.URL) + // Note that ServeHttp is non blocking & uses a go routine under the hood + proxy.ServeHTTP(w, r) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + func testHttpServer(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -515,6 +1039,29 @@ func testHttpServer(t *testing.T) net.Listener { return ln } +func testHttpServerWithEndlessBody(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + for { + w.Write([]byte(".\n")) + } + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + func testHttpHandlerExpectHeader(w http.ResponseWriter, r *http.Request) { if expected, ok := r.URL.Query()["expected"]; ok { if r.Header.Get(expected[0]) != "" { @@ -598,6 +1145,29 @@ func testHttpHandlerNoRange(w http.ResponseWriter, r *http.Request) { } } +func testHttpServerSubDir(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + t.Logf("serving: %v: %v: %#+[1]v", r.Method, r.URL.String(), r.Header) + } + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + const testHttpMetaStr = ` From e8be9f4334fb6e23ed79ee36fd9befb42d099756 Mon Sep 17 00:00:00 2001 From: Kent 'picat' Gruber Date: Tue, 17 May 2022 10:33:24 -0400 Subject: [PATCH 06/13] Port changes from hashicorp/eastebry/timeout-for-getters Adding timeout to s3Getter --- s3/get_s3.go | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/s3/get_s3.go b/s3/get_s3.go index 32065bc5..53942e83 100644 --- a/s3/get_s3.go +++ b/s3/get_s3.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -20,9 +21,21 @@ import ( // Getter is a Getter implementation that will download a module from // a S3 bucket. -type Getter struct{} +type Getter struct { + + // Timeout sets a deadline which all S3 operations should + // complete within. Zero value means no timeout. + Timeout time.Duration +} func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL region, bucket, path, _, creds, err := g.parseUrl(u) if err != nil { @@ -40,7 +53,7 @@ func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { Bucket: aws.String(bucket), Prefix: aws.String(path), } - resp, err := client.ListObjects(req) + resp, err := client.ListObjectsWithContext(ctx, req) if err != nil { return 0, err } @@ -64,6 +77,12 @@ func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { func (g *Getter) Get(ctx context.Context, req *getter.Request) error { + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL region, bucket, path, _, creds, err := g.parseUrl(req.URL()) if err != nil { @@ -105,7 +124,7 @@ func (g *Getter) Get(ctx context.Context, req *getter.Request) error { s3Req.Marker = aws.String(lastMarker) } - resp, err := client.ListObjects(s3Req) + resp, err := client.ListObjectsWithContext(ctx, s3Req) if err != nil { return err } @@ -161,7 +180,7 @@ func (g *Getter) getObject(ctx context.Context, client *s3.S3, req *getter.Reque s3req.VersionId = aws.String(version) } - resp, err := client.GetObject(s3req) + resp, err := client.GetObjectWithContext(ctx, s3req) if err != nil { return err } From 6a726826d97ea82461f9cfa750cae5a57c310c26 Mon Sep 17 00:00:00 2001 From: Kent 'picat' Gruber Date: Tue, 17 May 2022 14:32:54 -0400 Subject: [PATCH 07/13] Port changes from from hashicorp/add-missing-timeouts Add missing timeouts to `S3Getter` and `GCSGetter` --- gcs/get_gcs.go | 28 +++++++++++++++++++++++++++- s3/get_s3.go | 7 +++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/gcs/get_gcs.go b/gcs/get_gcs.go index 8781c157..5adf6f84 100644 --- a/gcs/get_gcs.go +++ b/gcs/get_gcs.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "time" "cloud.google.com/go/storage" "github.com/hashicorp/go-getter/v2" @@ -15,10 +16,21 @@ import ( // Getter is a Getter implementation that will download a module from // a GCS bucket. -type Getter struct{} +type Getter struct { + + // Timeout sets a deadline which all GCS operations should + // complete within. Zero value means no timeout. + Timeout time.Duration +} func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, err := g.parseURL(u) if err != nil { @@ -54,6 +66,13 @@ func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { } func (g *Getter) Get(ctx context.Context, req *getter.Request) error { + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, err := g.parseURL(req.URL()) if err != nil { @@ -111,6 +130,13 @@ func (g *Getter) Get(ctx context.Context, req *getter.Request) error { } func (g *Getter) GetFile(ctx context.Context, req *getter.Request) error { + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, err := g.parseURL(req.URL()) if err != nil { diff --git a/s3/get_s3.go b/s3/get_s3.go index 53942e83..827a292f 100644 --- a/s3/get_s3.go +++ b/s3/get_s3.go @@ -158,6 +158,13 @@ func (g *Getter) Get(ctx context.Context, req *getter.Request) error { } func (g *Getter) GetFile(ctx context.Context, req *getter.Request) error { + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + region, bucket, path, version, creds, err := g.parseUrl(req.URL()) if err != nil { return err From 5825b3f85f8b2ef25dc53be2b9ef3037d3d77ec6 Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Thu, 19 May 2022 08:20:19 -0400 Subject: [PATCH 08/13] Update test to work on Windows After change ``` > go test -v ./... -run=TestFile --- PASS: TestFileGetter_dir (0.00s) === RUN TestFileGetter_dirSymlink --- PASS: TestFileGetter_dirSymlink (0.04s) === RUN TestFileGetter_GetFile --- PASS: TestFileGetter_GetFile (0.00s) === RUN TestFileGetter_GetFile_Copy --- PASS: TestFileGetter_GetFile_Copy (0.01s) === RUN TestFileGetter_percent2F --- PASS: TestFileGetter_percent2F (0.06s) === RUN TestFileGetter_Mode_notexist --- PASS: TestFileGetter_Mode_notexist (0.00s) === RUN TestFileGetter_Mode_file --- PASS: TestFileGetter_Mode_file (0.00s) === RUN TestFileGetter_Mode_dir --- PASS: TestFileGetter_Mode_dir (0.00s) PASS > go test -v ./... -run=TestGit === RUN TestGitGetter_subdirectory_symlink get_git_test.go:870: initializing git repo in --- PASS: TestGitGetter_subdirectory_symlink (5.45s) === RUN TestGitGetter_subdirectory_traversal get_git_test.go:870: initializing git repo in --- PASS: TestGitGetter_subdirectory_traversal (0.27s) PASS ``` --- get_file_test.go | 2 + get_file_windows_test.go | 263 +++++++++++++++++++++++++++++++++++++++ get_git_test.go | 29 ++++- helper/testing/utils.go | 2 +- 4 files changed, 290 insertions(+), 6 deletions(-) create mode 100644 get_file_windows_test.go diff --git a/get_file_test.go b/get_file_test.go index 71e05f7c..41dfe93d 100644 --- a/get_file_test.go +++ b/get_file_test.go @@ -1,3 +1,5 @@ +// +build test unix + package getter import ( diff --git a/get_file_windows_test.go b/get_file_windows_test.go new file mode 100644 index 00000000..d10dafbb --- /dev/null +++ b/get_file_windows_test.go @@ -0,0 +1,263 @@ +// +build test windows + +package getter + +import ( + "context" + "os" + "path/filepath" + "testing" + + testing_helper "github.com/hashicorp/go-getter/v2/helper/testing" + urlhelper "github.com/hashicorp/go-getter/v2/helper/url" +) + +func TestFileGetter_impl(t *testing.T) { + var _ Getter = new(FileGetter) +} + +func TestFileGetter(t *testing.T) { + g := new(FileGetter) + dst := testing_helper.TempDir(t) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testModuleURL("basic"), + } + + // With a dir that doesn't exist + if err := g.Get(ctx, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the destination folder is a symlink + fi, err := os.Lstat(dst) + if err != nil { + t.Fatalf("err: %s", err) + } + if fi.Mode()&os.ModeSymlink == 0 { + t.Fatal("destination is not a symlink") + } + + // Verify the main file exists + mainPath := filepath.Join(dst, "main.tf") + if _, err := os.Stat(mainPath); err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestFileGetter_sourceFile(t *testing.T) { + g := new(FileGetter) + dst := testing_helper.TempDir(t) + ctx := context.Background() + + // With a source URL that is a path to a file + u := testModuleURL("basic") + u.Path += "/main.tf" + u.RawPath = u.Path + + req := &Request{ + Dst: dst, + u: u, + } + if err := g.Get(ctx, req); err == nil { + t.Fatal("should error") + } +} + +func TestFileGetter_sourceNoExist(t *testing.T) { + g := new(FileGetter) + dst := testing_helper.TempDir(t) + ctx := context.Background() + + // With a source URL that doesn't exist + u := testModuleURL("basic") + u.Path += "/main" + u.RawPath = u.Path + + req := &Request{ + Dst: dst, + u: u, + } + if err := g.Get(ctx, req); err == nil { + t.Fatal("should error") + } +} + +func TestFileGetter_dir(t *testing.T) { + g := new(FileGetter) + dst := testing_helper.TempDir(t) + ctx := context.Background() + + if err := os.MkdirAll(dst, 0755); err != nil { + t.Fatalf("err: %s", err) + } + + req := &Request{ + Dst: dst, + u: testModuleURL("basic"), + } + // With a dir that exists that isn't a symlink + if err := g.Get(ctx, req); err == nil { + t.Fatal("should error") + } +} + +func TestFileGetter_dirSymlink(t *testing.T) { + g := new(FileGetter) + dst := testing_helper.TempDir(t) + ctx := context.Background() + + dst2 := testing_helper.TempDir(t) + + // Make parents + if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + t.Fatalf("err: %s", err) + } + if err := os.MkdirAll(dst2, 0755); err != nil { + t.Fatalf("err: %s", err) + } + + // Make a symlink + if err := os.Symlink(dst2, dst); err != nil { + t.Fatalf("err: %s", err) + } + + req := &Request{ + Dst: dst, + u: testModuleURL("basic"), + } + + // With a dir that exists that isn't a symlink + if err := g.Get(ctx, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the main file exists + mainPath := filepath.Join(dst, "main.tf") + if _, err := os.Stat(mainPath); err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestFileGetter_GetFile(t *testing.T) { + g := new(FileGetter) + dst := testing_helper.TempTestFile(t) + defer os.RemoveAll(filepath.Dir(dst)) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testModuleURL("basic-file/foo.txt"), + } + + // With a dir that doesn't exist + if err := g.GetFile(ctx, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the destination folder is a symlink + fi, err := os.Lstat(dst) + if err != nil { + t.Fatalf("err: %s", err) + } + if fi.Mode()&os.ModeSymlink == 0 { + t.Fatal("destination is not a symlink") + } + + // Verify the main file exists + testing_helper.AssertContents(t, dst, "Hello\r\n") +} + +func TestFileGetter_GetFile_Copy(t *testing.T) { + g := new(FileGetter) + + dst := testing_helper.TempTestFile(t) + defer os.RemoveAll(filepath.Dir(dst)) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testModuleURL("basic-file/foo.txt"), + Copy: true, + } + + // With a dir that doesn't exist + if err := g.GetFile(ctx, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the destination folder is a symlink + fi, err := os.Lstat(dst) + if err != nil { + t.Fatalf("err: %s", err) + } + if fi.Mode()&os.ModeSymlink != 0 { + t.Fatal("destination is a symlink") + } + + // Verify the main file exists + testing_helper.AssertContents(t, dst, "Hello\r\n") +} + +// https://github.com/hashicorp/terraform/issues/8418 +func TestFileGetter_percent2F(t *testing.T) { + g := new(FileGetter) + dst := testing_helper.TempDir(t) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testModuleURL("basic%2Ftest"), + } + + // With a dir that doesn't exist + if err := g.Get(ctx, req); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the main file exists + mainPath := filepath.Join(dst, "main.tf") + if _, err := os.Stat(mainPath); err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestFileGetter_Mode_notexist(t *testing.T) { + g := new(FileGetter) + ctx := context.Background() + + u := urlhelper.MustParse("nonexistent") + if _, err := g.Mode(ctx, u); err == nil { + t.Fatal("expect source file error") + } +} + +func TestFileGetter_Mode_file(t *testing.T) { + g := new(FileGetter) + ctx := context.Background() + + // Check the client mode when pointed at a file. + mode, err := g.Mode(ctx, testModuleURL("basic-file/foo.txt")) + if err != nil { + t.Fatalf("err: %s", err) + } + if mode != ModeFile { + t.Fatal("expect ModeFile") + } +} + +func TestFileGetter_Mode_dir(t *testing.T) { + g := new(FileGetter) + ctx := context.Background() + + // Check the client mode when pointed at a directory. + mode, err := g.Mode(ctx, testModuleURL("basic")) + if err != nil { + t.Fatalf("err: %s", err) + } + if mode != ModeDir { + t.Fatal("expect ModeDir") + } +} diff --git a/get_git_test.go b/get_git_test.go index acd561d8..16b24fd7 100644 --- a/get_git_test.go +++ b/get_git_test.go @@ -775,11 +775,30 @@ func TestGitGetter_subdirectory_symlink(t *testing.T) { ctx := context.Background() _, err = client.Get(ctx, req) - if err == nil { - t.Fatalf("expected client get to fail") - } - if !errors.Is(err, ErrSymlinkCopy) { - t.Fatalf("unexpected error: %v", err) + if runtime.GOOS == "windows" { + // Windows doesn't handle symlinks as one might expect with git. + // + // https://github.com/git-for-windows/git/wiki/Symbolic-Links + filepath.Walk(dst, func(path string, info os.FileInfo, err error) error { + if strings.Contains(path, "this-is-a-symlink") { + if info.Mode()&os.ModeSymlink == os.ModeSymlink { + // If you see this test fail in the future, you've probably enabled + // symlinks within git on your Windows system. Our CI/CD system does + // not do this, so this is this is the only way we can make this test + // make any sense. + t.Fatalf("windows git should not have cloned a symlink") + } + } + return nil + }) + } else { + // We can rely on POSIX compliant systems running git to do the right thing. + if err == nil { + t.Fatalf("expected client get to fail") + } + if !errors.Is(err, ErrSymlinkCopy) { + t.Fatalf("unexpected error: %v", err) + } } } diff --git a/helper/testing/utils.go b/helper/testing/utils.go index 93eb861d..64e61f8a 100644 --- a/helper/testing/utils.go +++ b/helper/testing/utils.go @@ -34,7 +34,7 @@ func AssertContents(t *testing.T, path string, contents string) { } if !reflect.DeepEqual(data, []byte(contents)) { - t.Fatalf("bad. expected:\n\n%s\n\nGot:\n\n%s", contents, string(data)) + t.Fatalf("bad. expected:\n\n%q\n\nGot:\n\n%q", contents, string(data)) } } From e832f61c53ec14e5d07863ded0fe0f3ca4efe916 Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Thu, 19 May 2022 08:45:39 -0400 Subject: [PATCH 09/13] Update get_http_test to fail only when we expect it to Before Change ``` ~> go test ./... --- FAIL: TestHttpGetter__XTerraformGetConfiguredGettersBypass (0.01s) --- FAIL: TestHttpGetter__XTerraformGetConfiguredGettersBypass/configured_getter_for_git_protocol_switch (0.00s) get_http_test.go:742: http://127.0.0.1:52744/start get_http_test.go:726: making request get_http_test.go:903: serving start get_http_test.go:756: expected download not supported for scheme, got: /usr/local/bin/git exited with -1: --- FAIL: TestHttpGetter__XTerraformGetConfiguredGettersBypass/configured_getter_for_multiple_protocol_switch (0.00s) get_http_test.go:742: http://127.0.0.1:52746/start get_http_test.go:726: making request get_http_test.go:903: serving start get_http_test.go:756: expected download not supported for scheme, got: /usr/local/bin/git exited with -1: FAIL ``` After Change ``` ~> go test ./... ok github.com/hashicorp/go-getter/v2 13.908s ? github.com/hashicorp/go-getter/v2/helper/testing [no test files] ok github.com/hashicorp/go-getter/v2/helper/url (cached) ``` --- get_http_test.go | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/get_http_test.go b/get_http_test.go index 0ecc259a..cdfc821e 100644 --- a/get_http_test.go +++ b/get_http_test.go @@ -702,23 +702,23 @@ func TestHttpGetter__XTerraformGetConfiguredGettersBypass(t *testing.T) { errExpected bool }{ {name: "configured getter for git protocol switch", configuredGetters: []Getter{new(GitGetter)}, errExpected: false}, - {name: "configured getter for multiple protocol switch", configuredGetters: []Getter{new(HgGetter), new(GitGetter), new(FileGetter)}, errExpected: false}, + {name: "configured getter for multiple protocol switch", configuredGetters: []Getter{new(GitGetter), new(HgGetter), new(FileGetter)}, errExpected: false}, {name: "configured getter for file protocol switch", configuredGetters: []Getter{new(FileGetter)}, errExpected: true}, } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ln := testHttpServerWithXTerraformGetConfiguredGettersBypass(t) - - var u url.URL - u.Scheme = "http" - u.Host = ln.Addr().String() - u.Path = "/start" - for _, tt := range tc { tt := tt t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetConfiguredGettersBypass(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/start" + dst := testing_helper.TempDir(t) rt := hookableHTTPRoundTripper{ @@ -747,11 +747,15 @@ func TestHttpGetter__XTerraformGetConfiguredGettersBypass(t *testing.T) { GetMode: ModeDir, } + _, err := client.Get(ctx, &req) + // For configured getters that support git, the git repository doesn't exist so error will not be nil. + // If we get a nil error when we expect one other than the git error git exited with -1 we should fail. if tt.errExpected && err == nil { t.Fatalf("error expected") } - if err != nil { + // We only care about the error messages that indicate that we can download the git header URL + if tt.errExpected && err != nil { if !strings.Contains(err.Error(), "download not supported for scheme") { t.Fatalf("expected download not supported for scheme, got: %v", err) } From 6fb9692327d24e3b6de111807b632b5fe2478b20 Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Thu, 19 May 2022 09:22:07 -0400 Subject: [PATCH 10/13] Update get_git_test.go Co-authored-by: Sylvia Moss --- get_git_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/get_git_test.go b/get_git_test.go index 16b24fd7..a8c5a456 100644 --- a/get_git_test.go +++ b/get_git_test.go @@ -784,7 +784,7 @@ func TestGitGetter_subdirectory_symlink(t *testing.T) { if info.Mode()&os.ModeSymlink == os.ModeSymlink { // If you see this test fail in the future, you've probably enabled // symlinks within git on your Windows system. Our CI/CD system does - // not do this, so this is this is the only way we can make this test + // not do this, so this is the only way we can make this test // make any sense. t.Fatalf("windows git should not have cloned a symlink") } From 8d38cad659085a3190b57091118fee368aadf700 Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Thu, 19 May 2022 09:29:59 -0400 Subject: [PATCH 11/13] Fix fmt error --- get_http_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/get_http_test.go b/get_http_test.go index cdfc821e..a1543dae 100644 --- a/get_http_test.go +++ b/get_http_test.go @@ -747,7 +747,6 @@ func TestHttpGetter__XTerraformGetConfiguredGettersBypass(t *testing.T) { GetMode: ModeDir, } - _, err := client.Get(ctx, &req) // For configured getters that support git, the git repository doesn't exist so error will not be nil. // If we get a nil error when we expect one other than the git error git exited with -1 we should fail. From 23af893fe8d52d605db5d3beda6bb55a031bf067 Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Thu, 19 May 2022 09:40:02 -0400 Subject: [PATCH 12/13] Remove windows test for FileGetter --- get_file_test.go | 2 - get_file_windows_test.go | 263 --------------------------------------- 2 files changed, 265 deletions(-) delete mode 100644 get_file_windows_test.go diff --git a/get_file_test.go b/get_file_test.go index 41dfe93d..71e05f7c 100644 --- a/get_file_test.go +++ b/get_file_test.go @@ -1,5 +1,3 @@ -// +build test unix - package getter import ( diff --git a/get_file_windows_test.go b/get_file_windows_test.go deleted file mode 100644 index d10dafbb..00000000 --- a/get_file_windows_test.go +++ /dev/null @@ -1,263 +0,0 @@ -// +build test windows - -package getter - -import ( - "context" - "os" - "path/filepath" - "testing" - - testing_helper "github.com/hashicorp/go-getter/v2/helper/testing" - urlhelper "github.com/hashicorp/go-getter/v2/helper/url" -) - -func TestFileGetter_impl(t *testing.T) { - var _ Getter = new(FileGetter) -} - -func TestFileGetter(t *testing.T) { - g := new(FileGetter) - dst := testing_helper.TempDir(t) - ctx := context.Background() - - req := &Request{ - Dst: dst, - u: testModuleURL("basic"), - } - - // With a dir that doesn't exist - if err := g.Get(ctx, req); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify the destination folder is a symlink - fi, err := os.Lstat(dst) - if err != nil { - t.Fatalf("err: %s", err) - } - if fi.Mode()&os.ModeSymlink == 0 { - t.Fatal("destination is not a symlink") - } - - // Verify the main file exists - mainPath := filepath.Join(dst, "main.tf") - if _, err := os.Stat(mainPath); err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestFileGetter_sourceFile(t *testing.T) { - g := new(FileGetter) - dst := testing_helper.TempDir(t) - ctx := context.Background() - - // With a source URL that is a path to a file - u := testModuleURL("basic") - u.Path += "/main.tf" - u.RawPath = u.Path - - req := &Request{ - Dst: dst, - u: u, - } - if err := g.Get(ctx, req); err == nil { - t.Fatal("should error") - } -} - -func TestFileGetter_sourceNoExist(t *testing.T) { - g := new(FileGetter) - dst := testing_helper.TempDir(t) - ctx := context.Background() - - // With a source URL that doesn't exist - u := testModuleURL("basic") - u.Path += "/main" - u.RawPath = u.Path - - req := &Request{ - Dst: dst, - u: u, - } - if err := g.Get(ctx, req); err == nil { - t.Fatal("should error") - } -} - -func TestFileGetter_dir(t *testing.T) { - g := new(FileGetter) - dst := testing_helper.TempDir(t) - ctx := context.Background() - - if err := os.MkdirAll(dst, 0755); err != nil { - t.Fatalf("err: %s", err) - } - - req := &Request{ - Dst: dst, - u: testModuleURL("basic"), - } - // With a dir that exists that isn't a symlink - if err := g.Get(ctx, req); err == nil { - t.Fatal("should error") - } -} - -func TestFileGetter_dirSymlink(t *testing.T) { - g := new(FileGetter) - dst := testing_helper.TempDir(t) - ctx := context.Background() - - dst2 := testing_helper.TempDir(t) - - // Make parents - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { - t.Fatalf("err: %s", err) - } - if err := os.MkdirAll(dst2, 0755); err != nil { - t.Fatalf("err: %s", err) - } - - // Make a symlink - if err := os.Symlink(dst2, dst); err != nil { - t.Fatalf("err: %s", err) - } - - req := &Request{ - Dst: dst, - u: testModuleURL("basic"), - } - - // With a dir that exists that isn't a symlink - if err := g.Get(ctx, req); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify the main file exists - mainPath := filepath.Join(dst, "main.tf") - if _, err := os.Stat(mainPath); err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestFileGetter_GetFile(t *testing.T) { - g := new(FileGetter) - dst := testing_helper.TempTestFile(t) - defer os.RemoveAll(filepath.Dir(dst)) - ctx := context.Background() - - req := &Request{ - Dst: dst, - u: testModuleURL("basic-file/foo.txt"), - } - - // With a dir that doesn't exist - if err := g.GetFile(ctx, req); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify the destination folder is a symlink - fi, err := os.Lstat(dst) - if err != nil { - t.Fatalf("err: %s", err) - } - if fi.Mode()&os.ModeSymlink == 0 { - t.Fatal("destination is not a symlink") - } - - // Verify the main file exists - testing_helper.AssertContents(t, dst, "Hello\r\n") -} - -func TestFileGetter_GetFile_Copy(t *testing.T) { - g := new(FileGetter) - - dst := testing_helper.TempTestFile(t) - defer os.RemoveAll(filepath.Dir(dst)) - ctx := context.Background() - - req := &Request{ - Dst: dst, - u: testModuleURL("basic-file/foo.txt"), - Copy: true, - } - - // With a dir that doesn't exist - if err := g.GetFile(ctx, req); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify the destination folder is a symlink - fi, err := os.Lstat(dst) - if err != nil { - t.Fatalf("err: %s", err) - } - if fi.Mode()&os.ModeSymlink != 0 { - t.Fatal("destination is a symlink") - } - - // Verify the main file exists - testing_helper.AssertContents(t, dst, "Hello\r\n") -} - -// https://github.com/hashicorp/terraform/issues/8418 -func TestFileGetter_percent2F(t *testing.T) { - g := new(FileGetter) - dst := testing_helper.TempDir(t) - ctx := context.Background() - - req := &Request{ - Dst: dst, - u: testModuleURL("basic%2Ftest"), - } - - // With a dir that doesn't exist - if err := g.Get(ctx, req); err != nil { - t.Fatalf("err: %s", err) - } - - // Verify the main file exists - mainPath := filepath.Join(dst, "main.tf") - if _, err := os.Stat(mainPath); err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestFileGetter_Mode_notexist(t *testing.T) { - g := new(FileGetter) - ctx := context.Background() - - u := urlhelper.MustParse("nonexistent") - if _, err := g.Mode(ctx, u); err == nil { - t.Fatal("expect source file error") - } -} - -func TestFileGetter_Mode_file(t *testing.T) { - g := new(FileGetter) - ctx := context.Background() - - // Check the client mode when pointed at a file. - mode, err := g.Mode(ctx, testModuleURL("basic-file/foo.txt")) - if err != nil { - t.Fatalf("err: %s", err) - } - if mode != ModeFile { - t.Fatal("expect ModeFile") - } -} - -func TestFileGetter_Mode_dir(t *testing.T) { - g := new(FileGetter) - ctx := context.Background() - - // Check the client mode when pointed at a directory. - mode, err := g.Mode(ctx, testModuleURL("basic")) - if err != nil { - t.Fatalf("err: %s", err) - } - if mode != ModeDir { - t.Fatal("expect ModeDir") - } -} From d8bb00d13449cd8d3c670421989db165a4d0c894 Mon Sep 17 00:00:00 2001 From: Wilken Rivera Date: Thu, 19 May 2022 11:35:01 -0400 Subject: [PATCH 13/13] Change to next-get image --- .circleci/config.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index c25fc3ce..6c028952 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -49,11 +49,11 @@ commands: jobs: linux-tests: docker: - - image: circleci/golang:<< parameters.go-version >> + - image: cimg/go:<< parameters.go-version >> parameters: go-version: type: string - environment: + environment: <<: *ENVIRONMENT parallelism: 4 steps: @@ -104,7 +104,7 @@ jobs: path: *TEST_RESULTS_PATH windows-tests: - executor: + executor: name: win/default shell: bash --login -eo pipefail environment: @@ -115,12 +115,12 @@ jobs: type: string gotestsum-version: type: string - steps: + steps: - run: git config --global core.autocrlf false - checkout - attach_workspace: at: . - - run: + - run: name: Setup (remove pre-installed go) command: | rm -rf "c:\Go" @@ -131,16 +131,16 @@ jobs: - win-golang-<< parameters.go-version >>-cache-v1 - win-gomod-cache-{{ checksum "go.mod" }}-v1 - - run: + - run: name: Install go version << parameters.go-version >> - command: | + command: | if [ ! -d "c:\go" ]; then echo "Cache not found, installing new version of go" curl --fail --location https://dl.google.com/go/go<< parameters.go-version >>.windows-amd64.zip --output go.zip unzip go.zip -d "/c" fi - - run: + - run: command: go mod download - save_cache: @@ -176,7 +176,7 @@ jobs: go-smb-test: docker: - - image: circleci/golang:<< parameters.go-version >> + - image: cimg/go:<< parameters.go-version >> parameters: go-version: type: string