diff --git a/http/fetch/archive_fetcher.go b/http/fetch/archive_fetcher.go index fa3a82dde..16f70b920 100644 --- a/http/fetch/archive_fetcher.go +++ b/http/fetch/archive_fetcher.go @@ -46,7 +46,7 @@ type ArchiveFetcher struct { retries int maxDownloadSize int fileMode fs.FileMode - untarOpts []tar.TarOption + untarOpts []tar.Option hostnameOverwrite string filename string logger any @@ -75,9 +75,9 @@ func WithMaxDownloadSize(maxDownloadSize int) Option { } // WithUntar tells the ArchiveFetcher to untar the archive expecting it to be a tarball. -func WithUntar(opts ...tar.TarOption) Option { +func WithUntar(opts ...tar.Option) Option { return func(a *ArchiveFetcher) { - a.untarOpts = append([]tar.TarOption{}, opts...) // to make sure a.untarOpts won't be nil + a.untarOpts = append([]tar.Option{}, opts...) // to make sure a.untarOpts won't be nil } } diff --git a/oci/build.go b/oci/build.go index b6a3ddb22..2c9e5e2b6 100644 --- a/oci/build.go +++ b/oci/build.go @@ -17,17 +17,15 @@ limitations under the License. package oci import ( - "archive/tar" - "compress/gzip" "fmt" "io" "os" "path/filepath" "strings" - "time" "github.com/fluxcd/pkg/oci/internal/fs" "github.com/fluxcd/pkg/sourceignore" + "github.com/fluxcd/pkg/tar" ) // Build archives the given directory as a tarball to the given local path. 
@@ -37,17 +35,20 @@ func (c *Client) Build(artifactPath, sourceDir string, ignorePaths []string) (er } func build(artifactPath, sourceDir string, ignorePaths []string) (err error) { - absDir, err := filepath.Abs(sourceDir) + absSrc, err := filepath.Abs(sourceDir) if err != nil { return err } - dirStat, err := os.Stat(absDir) - if os.IsNotExist(err) { - return fmt.Errorf("invalid source dir path: %s", absDir) + srcInfo, err := os.Stat(absSrc) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("source path does not exist: %s", absSrc) + } + return fmt.Errorf("invalid source path %s: %w", absSrc, err) } - tf, err := os.CreateTemp(filepath.Split(absDir)) + tf, err := os.CreateTemp(filepath.Split(absSrc)) if err != nil { return err } @@ -58,110 +59,60 @@ func build(artifactPath, sourceDir string, ignorePaths []string) (err error) { } }() - ignore := strings.Join(ignorePaths, "\n") - domain := strings.Split(filepath.Clean(absDir), string(filepath.Separator)) - ps := sourceignore.ReadPatterns(strings.NewReader(ignore), domain) - matcher := sourceignore.NewMatcher(ps) - filter := func(p string, fi os.FileInfo) bool { - return matcher.Match(strings.Split(p, string(filepath.Separator)), fi.IsDir()) - } - - sz := &writeCounter{} - mw := io.MultiWriter(tf, sz) - - gw := gzip.NewWriter(mw) - tw := tar.NewWriter(gw) - if err := filepath.Walk(absDir, func(p string, fi os.FileInfo, err error) error { - if err != nil { - return err - } - - // Ignore anything that is not a file or directories e.g. symlinks - if m := fi.Mode(); !(m.IsRegular() || m.IsDir()) { - return nil - } - - if len(ignorePaths) > 0 && filter(p, fi) { - return nil + // If the source is a single file, stage it in a temp dir so Tar can + // archive it as a directory tree containing that one entry. 
+ tarDir := absSrc + if !srcInfo.IsDir() { + stage, stageErr := os.MkdirTemp("", "oci-build-") + if stageErr != nil { + tf.Close() + return stageErr } + defer os.RemoveAll(stage) - header, err := tar.FileInfoHeader(fi, p) - if err != nil { - return err - } - if dirStat.IsDir() { - // The name needs to be modified to maintain directory structure - // as tar.FileInfoHeader only has access to the base name of the file. - // Ref: https://golang.org/src/archive/tar/common.go?#L6264 - // - // we only want to do this if a directory was passed in - relFilePath, err := filepath.Rel(absDir, p) - if err != nil { - return err - } - // Normalize file path so it works on windows - header.Name = filepath.ToSlash(relFilePath) - } - - // Remove any environment specific data. - header.Gid = 0 - header.Uid = 0 - header.Uname = "" - header.Gname = "" - header.ModTime = time.Time{} - header.AccessTime = time.Time{} - header.ChangeTime = time.Time{} - - if err := tw.WriteHeader(header); err != nil { - return err - } - - if !fi.Mode().IsRegular() { - return nil - } - f, err := os.Open(p) - if err != nil { - f.Close() - return err - } - if _, err := io.Copy(tw, f); err != nil { - f.Close() + if err := copyFileContents(filepath.Join(stage, srcInfo.Name()), absSrc, srcInfo.Mode()); err != nil { + tf.Close() return err } - return f.Close() - }); err != nil { - tw.Close() - gw.Close() - tf.Close() - return err + tarDir = stage } - if err := tw.Close(); err != nil { - gw.Close() - tf.Close() - return err + ignore := strings.Join(ignorePaths, "\n") + domain := strings.Split(filepath.Clean(tarDir), string(filepath.Separator)) + ps := sourceignore.ReadPatterns(strings.NewReader(ignore), domain) + matcher := sourceignore.NewMatcher(ps) + filter := func(p string, fi os.FileInfo) bool { + return matcher.Match(strings.Split(p, string(filepath.Separator)), fi.IsDir()) } - if err := gw.Close(); err != nil { + + if _, err := tar.Tar(tarDir, tf, tar.WithFilter(filter)); err != nil { tf.Close() return err 
} if err := tf.Close(); err != nil { return err } - if err := os.Chmod(tmpName, 0o640); err != nil { return err } - return fs.RenameWithFallback(tmpName, artifactPath) } -type writeCounter struct { - written int64 -} - -func (wc *writeCounter) Write(p []byte) (int, error) { - n := len(p) - wc.written += int64(n) - return n, nil +func copyFileContents(dst, src string, mode os.FileMode) (err error) { + sf, err := os.Open(src) + if err != nil { + return err + } + defer sf.Close() + df, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode.Perm()) + if err != nil { + return err + } + defer func() { + if closeErr := df.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() + _, err = io.Copy(df, sf) + return err } diff --git a/tar/doc.go b/tar/doc.go new file mode 100644 index 000000000..c7860b3a3 --- /dev/null +++ b/tar/doc.go @@ -0,0 +1,96 @@ +/* +Copyright 2026 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package tar provides utilities for creating and extracting tar +// archives, with optional gzip compression. Tar writes a sanitized +// archive of a directory tree, skipping symlinks and other non-regular, +// non-directory entries; use ResolveSymlinks (or the confined +// ResolveSymlinksRoot) to materialize symlink targets before archiving. +// Untar safely extracts a tar archive into a target directory, +// rejecting path traversal and capping the total decompressed size. 
+// +// # Creating an archive +// +// Archive a directory tree to a file as a gzip-compressed tarball: +// +// f, err := os.Create("archive.tar.gz") +// if err != nil { +// return err +// } +// defer f.Close() +// +// if _, err := tar.Tar("/path/to/dir", f); err != nil { +// return err +// } +// +// Exclude entries with a filter and write a plain (non-gzipped) tar: +// +// skipHidden := func(p string, fi os.FileInfo) bool { +// return strings.HasPrefix(fi.Name(), ".") +// } +// _, err := tar.Tar("/path/to/dir", f, +// tar.WithFilter(skipHidden), +// tar.WithSkipGzip(), +// ) +// +// # Extracting an archive +// +// Extract a gzip-compressed tarball into a directory: +// +// f, err := os.Open("archive.tar.gz") +// if err != nil { +// return err +// } +// defer f.Close() +// +// if err := tar.Untar(f, "/path/to/target"); err != nil { +// return err +// } +// +// Raise the size limit and tolerate symlinks in the archive: +// +// err := tar.Untar(f, "/path/to/target", +// tar.WithMaxUntarSize(500<<20), // 500 MiB +// tar.WithSkipSymlinks(), +// ) +// +// # Archiving with symlinks resolved +// +// By default Tar skips symlinks. For inputs where files live behind +// symlinks (for example, manifest trees generated by Nix), stage the +// source into a caller-owned directory with ResolveSymlinks first, +// then archive the resolved tree: +// +// tmpDir, err := os.MkdirTemp("", "resolve-") +// if err != nil { +// return err +// } +// defer os.RemoveAll(tmpDir) +// +// if err := tar.ResolveSymlinks("/path/to/dir", tmpDir); err != nil { +// return err +// } +// +// if _, err := tar.Tar(tmpDir, w); err != nil { +// return err +// } +// +// For untrusted source trees, use ResolveSymlinksRoot to confine every +// symlink target inside a caller-supplied rootDir. 
Targets that resolve +// outside rootDir cause the call to fail without materializing them: +// +// err := tar.ResolveSymlinksRoot("/path/to/root", "/path/to/root/src", tmpDir) +package tar diff --git a/tar/options.go b/tar/options.go new file mode 100644 index 000000000..d5575e891 --- /dev/null +++ b/tar/options.go @@ -0,0 +1,81 @@ +/* +Copyright 2026 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tar + +import "os" + +// Option configures the behavior of Tar and Untar. Options are +// silently ignored by operations they do not apply to. +type Option func(*tarOpts) + +type tarOpts struct { + // maxUntarSize represents the limit size (bytes) for archives being decompressed by Untar. + // When max is a negative value the size checks are disabled. + maxUntarSize int + + // skipSymlinks ignores symlinks instead of failing the decompression. + skipSymlinks bool + + // skipGzip disables gzip compression: Tar writes a plain tar stream, + // and Untar reads one. + skipGzip bool + + // filter is called for each entry during archiving or extraction. + // If it returns true, the entry is excluded. + filter func(path string, fi os.FileInfo) bool +} + +// WithMaxUntarSize sets the limit size for archives being decompressed by Untar. +// When max is equal or less than 0 disables size checks. 
+func WithMaxUntarSize(max int) Option { + return func(t *tarOpts) { + t.maxUntarSize = max + } +} + +// WithSkipSymlinks allows for symlinks to be present +// in the tarball and skips them when decompressing. +func WithSkipSymlinks() Option { + return func(t *tarOpts) { + t.skipSymlinks = true + } +} + +// WithSkipGzip disables gzip compression: Tar writes a plain tar stream, +// and Untar reads one. +func WithSkipGzip() Option { + return func(t *tarOpts) { + t.skipGzip = true + } +} + +// WithFilter sets a predicate called for each entry during archiving +// or extraction. Entries for which fn returns true are excluded. During +// Tar the path is the absolute filesystem path; during Untar it is the +// slash-separated name from the tar header. +func WithFilter(fn func(path string, fi os.FileInfo) bool) Option { + return func(t *tarOpts) { + t.filter = fn + } +} + +// applyOpts applies the given Option to t. +func (t *tarOpts) applyOpts(opts ...Option) { + for _, opt := range opts { + opt(t) + } +} diff --git a/tar/symlink_test.go b/tar/symlink_test.go deleted file mode 100644 index af190f5fe..000000000 --- a/tar/symlink_test.go +++ /dev/null @@ -1,132 +0,0 @@ -/* -Copyright 2023 The Flux authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package tar - -import ( - "archive/tar" - "bytes" - "compress/gzip" - "io" - "os" - "path" - "path/filepath" - "testing" -) - -func TestSkipSymlinks(t *testing.T) { - tmpDir := t.TempDir() - - symlinkTarget := filepath.Join(tmpDir, "symlink.target") - err := os.WriteFile(symlinkTarget, geRandomContent(256), os.ModePerm) - if err != nil { - t.Fatal(err) - } - - symlink := filepath.Join(tmpDir, "symlink") - err = os.Symlink(symlinkTarget, symlink) - if err != nil { - t.Fatal(err) - } - - tgzFileName := filepath.Join(t.TempDir(), "test.tgz") - var buf bytes.Buffer - err = tgzWithSymlinks(tmpDir, &buf) - if err != nil { - t.Fatal(err) - } - - tgzFile, err := os.OpenFile(tgzFileName, os.O_CREATE|os.O_RDWR, os.ModePerm) - if err != nil { - t.Fatal(err) - } - if _, err := io.Copy(tgzFile, &buf); err != nil { - t.Fatal(err) - } - if err = tgzFile.Close(); err != nil { - t.Fatal(err) - } - - targetDirOutput := filepath.Join(t.TempDir(), "output") - f1, err := os.Open(tgzFileName) - if err != nil { - t.Fatal(err) - } - - err = Untar(f1, targetDirOutput, WithMaxUntarSize(-1)) - if err == nil { - t.Errorf("wanted error: unsupported symlink") - } - - f2, err := os.Open(tgzFileName) - if err != nil { - t.Fatal(err) - } - - err = Untar(f2, targetDirOutput, WithMaxUntarSize(-1), WithSkipSymlinks()) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - if _, err := os.Open(path.Join(targetDirOutput, "symlink.target")); err != nil { - t.Errorf("regular file not found: %v", err) - } -} - -func tgzWithSymlinks(src string, buf io.Writer) error { - absDir, err := filepath.Abs(src) - if err != nil { - return err - } - - zr := gzip.NewWriter(buf) - tw := tar.NewWriter(zr) - if err := filepath.Walk(absDir, func(file string, fi os.FileInfo, err error) error { - if err != nil { - return err - } - - header, err := tar.FileInfoHeader(fi, file) - if err != nil { - return err - } - if err := tw.WriteHeader(header); err != nil { - return err - } - - if fi.Mode().IsRegular() { - 
f, err := os.Open(file) - if err != nil { - return err - } - if _, err := io.Copy(tw, f); err != nil { - return err - } - return f.Close() - } - - return nil - }); err != nil { - return err - } - if err := tw.Close(); err != nil { - return err - } - if err := zr.Close(); err != nil { - return err - } - return nil -} diff --git a/tar/symlinks.go b/tar/symlinks.go new file mode 100644 index 000000000..5a41f8c2b --- /dev/null +++ b/tar/symlinks.go @@ -0,0 +1,275 @@ +/* +Copyright 2026 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tar + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +// ResolveSymlinks stages the contents of srcDir into dstDir, following +// every symlink and copying the target's contents. Both paths must be +// existing directories. +// +// Symlink cycles are detected and skipped. Non-regular, non-directory +// entries (devices, sockets, etc.) are skipped. +// +// This is a staging helper intended to be paired with Tar for inputs +// that contain symlinks (for example, manifest trees generated by Nix +// where the actual files live outside the source directory). +// +// Security: symlinks may point outside srcDir. Their targets are +// materialized in dstDir under the link's name. Only use with trusted +// input. 
+func ResolveSymlinks(srcDir, dstDir string) error { + absSrc, err := filepath.Abs(srcDir) + if err != nil { + return err + } + realSrc, err := filepath.EvalSymlinks(absSrc) + if err != nil { + return fmt.Errorf("resolving srcDir: %w", err) + } + + info, err := os.Stat(realSrc) + if err != nil { + return err + } + if !info.IsDir() { + return fmt.Errorf("srcDir %s is not a directory", absSrc) + } + + if err := checkDstDir(dstDir); err != nil { + return err + } + + return copyResolvedDir(realSrc, dstDir, make(map[string]bool)) +} + +// ResolveSymlinksRoot is the confined variant of ResolveSymlinks: every +// symlink target must resolve within rootDir. A symlink whose final +// resolved path is outside rootDir causes the function to fail. +// +// srcDir must be within (or equal to) rootDir. dstDir is not required +// to be within rootDir. +// +// Symlink cycles are detected and skipped. Non-regular, non-directory +// entries are skipped. +// +// Note: containment is checked against the resolved absolute path, so +// intermediate "../" navigation within rootDir is allowed (common in +// Nix-style trees). Symlink resolution uses filepath.EvalSymlinks; this +// leaves a small TOCTOU window and should not be used to defend against +// an adversary with concurrent write access to the source tree. 
+func ResolveSymlinksRoot(rootDir, srcDir, dstDir string) error { + absRoot, err := filepath.Abs(rootDir) + if err != nil { + return err + } + absSrc, err := filepath.Abs(srcDir) + if err != nil { + return err + } + + realRoot, err := filepath.EvalSymlinks(absRoot) + if err != nil { + return fmt.Errorf("resolving rootDir: %w", err) + } + realSrc, err := filepath.EvalSymlinks(absSrc) + if err != nil { + return fmt.Errorf("resolving srcDir: %w", err) + } + + if !isWithin(realRoot, realSrc) { + return fmt.Errorf("srcDir %s is not within rootDir %s", absSrc, absRoot) + } + + info, err := os.Stat(realSrc) + if err != nil { + return err + } + if !info.IsDir() { + return fmt.Errorf("srcDir %s is not a directory", absSrc) + } + + if err := checkDstDir(dstDir); err != nil { + return err + } + + return copyConfinedDir(realRoot, realSrc, dstDir, make(map[string]bool)) +} + +// checkDstDir verifies that dstDir exists and is a directory. +func checkDstDir(dstDir string) error { + info, err := os.Stat(dstDir) + if err != nil { + return fmt.Errorf("stat dstDir: %w", err) + } + if !info.IsDir() { + return fmt.Errorf("dstDir %s is not a directory", dstDir) + } + return nil +} + +// copyResolvedDir recursively copies srcDir (already resolved via +// EvalSymlinks) into dstDir. visited tracks resolved directory paths +// currently on the call stack so that a re-entry via a symlink does +// not loop. Entries are removed when the call returns, so the same +// directory may be copied again through a different symlink — this is +// intentional (both link sites need the content). 
+func copyResolvedDir(srcDir, dstDir string, visited map[string]bool) error { + if visited[srcDir] { + return nil + } + visited[srcDir] = true + defer delete(visited, srcDir) + + entries, err := os.ReadDir(srcDir) + if err != nil { + return err + } + + for _, entry := range entries { + srcPath := filepath.Join(srcDir, entry.Name()) + dstPath := filepath.Join(dstDir, entry.Name()) + + realPath, err := filepath.EvalSymlinks(srcPath) + if err != nil { + return fmt.Errorf("resolving symlink %s: %w", srcPath, err) + } + realInfo, err := os.Stat(realPath) + if err != nil { + return fmt.Errorf("stat resolved path %s: %w", realPath, err) + } + + if realInfo.IsDir() { + if err := os.MkdirAll(dstPath, realInfo.Mode()); err != nil { + return err + } + if err := copyResolvedDir(realPath, dstPath, visited); err != nil { + return err + } + continue + } + + if !realInfo.Mode().IsRegular() { + continue + } + + if err := copyResolvedFile(realPath, dstPath, realInfo.Mode()); err != nil { + return err + } + } + return nil +} + +// copyConfinedDir is the root-confined equivalent of copyResolvedDir. +// srcDir is assumed already resolved and already verified as within +// realRoot. visited is a stack-based cycle breaker (see copyResolvedDir). +func copyConfinedDir(realRoot, srcDir, dstDir string, visited map[string]bool) error { + if visited[srcDir] { + return nil + } + visited[srcDir] = true + defer delete(visited, srcDir) + + entries, err := os.ReadDir(srcDir) + if err != nil { + return err + } + + for _, entry := range entries { + srcPath := filepath.Join(srcDir, entry.Name()) + dstPath := filepath.Join(dstDir, entry.Name()) + + realPath, err := filepath.EvalSymlinks(srcPath) + if err != nil { + return fmt.Errorf("resolving %s: %w", srcPath, err) + } + // Report the logical path of the offending symlink, not the + // resolved target, to avoid leaking filesystem layout. 
+ if !isWithin(realRoot, realPath) { + return fmt.Errorf("symlink %s resolves outside rootDir", srcPath) + } + realInfo, err := os.Stat(realPath) + if err != nil { + return fmt.Errorf("stat %s: %w", realPath, err) + } + + if realInfo.IsDir() { + if err := os.MkdirAll(dstPath, realInfo.Mode()); err != nil { + return err + } + if err := copyConfinedDir(realRoot, realPath, dstPath, visited); err != nil { + return err + } + continue + } + + if !realInfo.Mode().IsRegular() { + continue + } + + if err := copyResolvedFile(realPath, dstPath, realInfo.Mode()); err != nil { + return err + } + } + return nil +} + +// isWithin reports whether path is equal to or inside root. Both +// arguments should be absolute, cleaned paths; the caller is +// responsible for resolving symlinks in root ahead of time. +func isWithin(root, path string) bool { + rel, err := filepath.Rel(root, path) + if err != nil { + return false + } + return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) +} + +// copyResolvedFile copies the contents of src to dst with the given +// mode. src is expected to be a regular file (EvalSymlinks'd by the +// caller) and filepath.Dir(dst) is expected to already exist — the +// walk creates each subdirectory via os.MkdirAll before descending, so +// the parent of dst is guaranteed to be present. 
+func copyResolvedFile(src, dst string, mode os.FileMode) (err error) { + in, err := os.Open(src) + if err != nil { + return err + } + defer func() { + if closeErr := in.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() + + out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode) + if err != nil { + return err + } + defer func() { + if closeErr := out.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() + _, err = io.Copy(out, in) + return err +} diff --git a/tar/symlinks_test.go b/tar/symlinks_test.go new file mode 100644 index 000000000..a357422f3 --- /dev/null +++ b/tar/symlinks_test.go @@ -0,0 +1,373 @@ +/* +Copyright 2026 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tar + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +func TestResolveSymlinks_directoryWithSymlinkToFile(t *testing.T) { + // external/real.txt <-- symlink target outside source + // src/link.txt -> ../external/real.txt + // src/regular.txt + root := t.TempDir() + external := filepath.Join(root, "external") + src := filepath.Join(root, "src") + if err := os.MkdirAll(external, 0o755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(src, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(external, "real.txt"), []byte("hello"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(src, "regular.txt"), []byte("world"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.Symlink(filepath.Join(external, "real.txt"), filepath.Join(src, "link.txt")); err != nil { + t.Fatal(err) + } + + dst := t.TempDir() + if err := ResolveSymlinks(src, dst); err != nil { + t.Fatalf("ResolveSymlinks: %v", err) + } + + got, err := os.ReadFile(filepath.Join(dst, "link.txt")) + if err != nil { + t.Fatalf("read link.txt: %v", err) + } + if string(got) != "hello" { + t.Errorf("link.txt content: got %q, want %q", got, "hello") + } + + got, err = os.ReadFile(filepath.Join(dst, "regular.txt")) + if err != nil { + t.Fatalf("read regular.txt: %v", err) + } + if string(got) != "world" { + t.Errorf("regular.txt content: got %q, want %q", got, "world") + } + + // The staged tree must contain no symlinks. 
+ fi, err := os.Lstat(filepath.Join(dst, "link.txt")) + if err != nil { + t.Fatal(err) + } + if fi.Mode()&os.ModeSymlink != 0 { + t.Errorf("staged link.txt is still a symlink") + } +} + +func TestResolveSymlinks_rejectsFileInput(t *testing.T) { + root := t.TempDir() + file := filepath.Join(root, "plain.txt") + if err := os.WriteFile(file, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + + err := ResolveSymlinks(file, t.TempDir()) + if err == nil { + t.Fatal("expected error for file input") + } +} + +func TestResolveSymlinks_symlinkToDir(t *testing.T) { + // external/{a.txt, b.txt} + // src/nested -> ../external + root := t.TempDir() + external := filepath.Join(root, "external") + src := filepath.Join(root, "src") + if err := os.MkdirAll(external, 0o755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(src, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(external, "a.txt"), []byte("A"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(external, "b.txt"), []byte("B"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.Symlink(external, filepath.Join(src, "nested")); err != nil { + t.Fatal(err) + } + + dst := t.TempDir() + if err := ResolveSymlinks(src, dst); err != nil { + t.Fatalf("ResolveSymlinks: %v", err) + } + + for name, want := range map[string]string{"a.txt": "A", "b.txt": "B"} { + got, err := os.ReadFile(filepath.Join(dst, "nested", name)) + if err != nil { + t.Fatalf("read nested/%s: %v", name, err) + } + if string(got) != want { + t.Errorf("nested/%s: got %q, want %q", name, got, want) + } + } +} + +func TestResolveSymlinks_cycle(t *testing.T) { + // src/ + // a/ + // loop -> ../b (symlink) + // b/ + // loop -> ../a (symlink) + // file.txt + root := t.TempDir() + src := filepath.Join(root, "src") + a := filepath.Join(src, "a") + b := filepath.Join(src, "b") + if err := os.MkdirAll(a, 0o755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(b, 0o755); err != nil { + 
t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(b, "file.txt"), []byte("hi"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.Symlink(b, filepath.Join(a, "loop")); err != nil { + t.Fatal(err) + } + if err := os.Symlink(a, filepath.Join(b, "loop")); err != nil { + t.Fatal(err) + } + + dst := t.TempDir() + if err := ResolveSymlinks(src, dst); err != nil { + t.Fatalf("ResolveSymlinks: %v", err) + } + + got, err := os.ReadFile(filepath.Join(dst, "b", "file.txt")) + if err != nil { + t.Fatalf("read b/file.txt: %v", err) + } + if string(got) != "hi" { + t.Errorf("b/file.txt: got %q, want %q", got, "hi") + } +} + +func TestResolveSymlinks_thenTar(t *testing.T) { + // End-to-end: resolve a tree with external symlinks, then tar. + root := t.TempDir() + external := filepath.Join(root, "external") + src := filepath.Join(root, "src") + if err := os.MkdirAll(external, 0o755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(src, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(external, "real.yaml"), []byte("kind: Thing"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.Symlink(filepath.Join(external, "real.yaml"), filepath.Join(src, "manifest.yaml")); err != nil { + t.Fatal(err) + } + + dst := t.TempDir() + if err := ResolveSymlinks(src, dst); err != nil { + t.Fatalf("ResolveSymlinks: %v", err) + } + + var buf bytes.Buffer + if _, err := Tar(dst, &buf); err != nil { + t.Fatalf("Tar: %v", err) + } + + got := readTarEntries(t, &buf) + if got["manifest.yaml"] != "kind: Thing" { + t.Errorf("manifest.yaml: got %q, want %q", got["manifest.yaml"], "kind: Thing") + } +} + +func TestResolveSymlinks_nonexistent(t *testing.T) { + err := ResolveSymlinks("/definitely/not/a/real/path/xyzzy", t.TempDir()) + if err == nil { + t.Fatal("expected error for nonexistent path") + } +} + +func TestResolveSymlinks_rejectsMissingDst(t *testing.T) { + src := t.TempDir() + err := ResolveSymlinks(src, filepath.Join(t.TempDir(), "does-not-exist")) 
+ if err == nil { + t.Fatal("expected error for missing dstDir") + } +} + +func TestResolveSymlinks_rejectsDstFile(t *testing.T) { + src := t.TempDir() + dst := filepath.Join(t.TempDir(), "file") + if err := os.WriteFile(dst, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + err := ResolveSymlinks(src, dst) + if err == nil { + t.Fatal("expected error for dstDir that is a file") + } +} + +func TestResolveSymlinksRoot_resolvesWithinRoot(t *testing.T) { + // root/ + // src/link.txt -> ../external/real.txt + // external/real.txt + root := t.TempDir() + src := filepath.Join(root, "src") + external := filepath.Join(root, "external") + if err := os.MkdirAll(src, 0o755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(external, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(external, "real.txt"), []byte("ok"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.Symlink(filepath.Join(external, "real.txt"), filepath.Join(src, "link.txt")); err != nil { + t.Fatal(err) + } + + dst := t.TempDir() + if err := ResolveSymlinksRoot(root, src, dst); err != nil { + t.Fatalf("ResolveSymlinksRoot: %v", err) + } + + got, err := os.ReadFile(filepath.Join(dst, "link.txt")) + if err != nil { + t.Fatalf("read link.txt: %v", err) + } + if string(got) != "ok" { + t.Errorf("content: got %q, want %q", got, "ok") + } +} + +func TestResolveSymlinksRoot_rejectsEscape(t *testing.T) { + // root/src/escape -> /tmp//target.txt + root := t.TempDir() + src := filepath.Join(root, "src") + if err := os.MkdirAll(src, 0o755); err != nil { + t.Fatal(err) + } + + outside := t.TempDir() // sibling temp dir, outside root + target := filepath.Join(outside, "target.txt") + if err := os.WriteFile(target, []byte("secret"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.Symlink(target, filepath.Join(src, "escape")); err != nil { + t.Fatal(err) + } + + dst := t.TempDir() + err := ResolveSymlinksRoot(root, src, dst) + if err == nil { + t.Fatal("expected error for 
// TestResolveSymlinksRoot_rejectsSrcOutsideRoot verifies that resolution is
// refused when srcDir does not live under rootDir.
func TestResolveSymlinksRoot_rejectsSrcOutsideRoot(t *testing.T) {
	root := t.TempDir()
	otherRoot := t.TempDir()

	err := ResolveSymlinksRoot(root, otherRoot, t.TempDir())
	if err == nil {
		t.Fatal("expected error when srcDir is outside rootDir")
	}
}

// TestResolveSymlinksRoot_cycle verifies that a symlink cycle contained
// entirely inside the root terminates and still materializes the real files.
func TestResolveSymlinksRoot_cycle(t *testing.T) {
	// Cycle within the root should terminate.
	root := t.TempDir()
	src := filepath.Join(root, "src")
	a := filepath.Join(src, "a")
	b := filepath.Join(src, "b")
	if err := os.MkdirAll(a, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.MkdirAll(b, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(filepath.Join(b, "file.txt"), []byte("hi"), 0o644); err != nil {
		t.Fatal(err)
	}
	// a/loop -> b and b/loop -> a form the cycle.
	if err := os.Symlink(b, filepath.Join(a, "loop")); err != nil {
		t.Fatal(err)
	}
	if err := os.Symlink(a, filepath.Join(b, "loop")); err != nil {
		t.Fatal(err)
	}

	dst := t.TempDir()
	if err := ResolveSymlinksRoot(root, src, dst); err != nil {
		t.Fatalf("ResolveSymlinksRoot: %v", err)
	}

	// The regular file must still be reachable at its resolved location.
	got, err := os.ReadFile(filepath.Join(dst, "b", "file.txt"))
	if err != nil {
		t.Fatalf("read b/file.txt: %v", err)
	}
	if string(got) != "hi" {
		t.Errorf("b/file.txt: got %q, want %q", got, "hi")
	}
}

// TestResolveSymlinksRoot_thenTar exercises the intended pipeline this test
// file documents: symlinks pointing inside the root are resolved into
// regular files first, so the subsequent Tar (which skips symlinks) still
// captures their content.
func TestResolveSymlinksRoot_thenTar(t *testing.T) {
	root := t.TempDir()
	src := filepath.Join(root, "src")
	external := filepath.Join(root, "external")
	if err := os.MkdirAll(src, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.MkdirAll(external, 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(filepath.Join(external, "real.yaml"), []byte("kind: Thing"), 0o644); err != nil {
		t.Fatal(err)
	}
	// manifest.yaml points outside src but still inside root.
	if err := os.Symlink(filepath.Join(external, "real.yaml"), filepath.Join(src, "manifest.yaml")); err != nil {
		t.Fatal(err)
	}

	dst := t.TempDir()
	if err := ResolveSymlinksRoot(root, src, dst); err != nil {
		t.Fatalf("ResolveSymlinksRoot: %v", err)
	}

	var buf bytes.Buffer
	if _, err := Tar(dst, &buf); err != nil {
		t.Fatalf("Tar: %v", err)
	}

	// The symlinked manifest must appear in the archive as a regular file.
	got := readTarEntries(t, &buf)
	if got["manifest.yaml"] != "kind: Thing" {
		t.Errorf("manifest.yaml: got %q, want %q", got["manifest.yaml"], "kind: Thing")
	}
}
// Tar writes a tar archive of dir to w and returns the number of bytes
// written.
//
// By default, the archive is gzip-compressed; use WithSkipGzip to write
// a plain tar stream. Use WithFilter to exclude entries by path or
// FileInfo. The directory tree is walked recursively; symlinks and
// other non-regular, non-directory entries are silently skipped.
// Headers are sanitized to produce reproducible archives: uid, gid,
// user and group names, and all timestamps are zeroed.
func Tar(dir string, w io.Writer, opts ...Option) (int64, error) {
	var o tarOpts
	o.applyOpts(opts...)

	absDir, err := filepath.Abs(dir)
	if err != nil {
		return 0, err
	}

	if fi, err := os.Stat(absDir); err != nil {
		return 0, fmt.Errorf("invalid dir path %s: %w", absDir, err)
	} else if !fi.IsDir() {
		return 0, fmt.Errorf("not a directory: %s", absDir)
	}

	// Count every byte that reaches w so the total can be reported even
	// when an error interrupts the walk.
	cw := &countWriter{w: w}

	var gw *gzip.Writer
	var tw *tar.Writer
	if o.skipGzip {
		tw = tar.NewWriter(cw)
	} else {
		gw = gzip.NewWriter(cw)
		tw = tar.NewWriter(gw)
	}

	buf := make([]byte, bufferSize)
	if err := filepath.Walk(absDir, func(p string, fi os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		// Skip symlinks and other non-regular, non-directory entries.
		if m := fi.Mode(); !(m.IsRegular() || m.IsDir()) {
			return nil
		}
		// NOTE(review): a filtered directory is only excluded itself; its
		// children are still walked and included unless individually
		// filtered (no filepath.SkipDir here) — confirm this is intended.
		if o.filter != nil && o.filter(p, fi) {
			return nil
		}

		header, err := tar.FileInfoHeader(fi, p)
		if err != nil {
			return err
		}
		// Store slash-separated paths relative to the archive root; the
		// root itself becomes the "." entry.
		relPath, err := filepath.Rel(absDir, p)
		if err != nil {
			return err
		}
		header.Name = filepath.ToSlash(relPath)

		// Sanitize environment-specific data.
		header.Gid = 0
		header.Uid = 0
		header.Uname = ""
		header.Gname = ""
		header.ModTime = time.Time{}
		header.AccessTime = time.Time{}
		header.ChangeTime = time.Time{}

		if err := tw.WriteHeader(header); err != nil {
			return err
		}

		// Directories carry no data payload.
		if !fi.Mode().IsRegular() {
			return nil
		}

		f, err := os.Open(p)
		if err != nil {
			return err
		}
		_, err = copyBuffer(tw, f, buf)
		// Close errors only surface when the copy itself succeeded.
		if closeErr := f.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
		return err
	}); err != nil {
		// Best-effort close so the writers release resources; the walk
		// error takes precedence over any close error.
		_ = tw.Close()
		if gw != nil {
			_ = gw.Close()
		}
		return cw.n, err
	}

	// Close the tar writer before the gzip writer so the tar footer is
	// flushed through the compressor.
	if err := tw.Close(); err != nil {
		if gw != nil {
			_ = gw.Close()
		}
		return cw.n, err
	}
	if gw != nil {
		if err := gw.Close(); err != nil {
			return cw.n, err
		}
	}

	return cw.n, nil
}

// countWriter wraps an io.Writer and counts the bytes written.
type countWriter struct {
	w io.Writer
	n int64
}

// Write forwards p to the wrapped writer and accumulates the reported
// byte count in n.
func (cw *countWriter) Write(p []byte) (int, error) {
	n, err := cw.w.Write(p)
	cw.n += int64(n)
	return n, err
}
-*/ - -package tar - -// TarOption represents options to be applied to Tar. -type TarOption func(*tarOpts) - -// WithMaxUntarSize sets the limit size for archives being decompressed by Untar. -// When max is equal or less than 0 disables size checks. -func WithMaxUntarSize(max int) TarOption { - return func(t *tarOpts) { - t.maxUntarSize = max - } -} - -// WithSkipSymlinks allows for symlinks to be present in the tarball and skips them when decompressing. -func WithSkipSymlinks() TarOption { - return func(t *tarOpts) { - t.skipSymlinks = true - } -} - -// WithSkipGzip allows for un-taring plain tar files too, that aren't gzipped. -func WithSkipGzip() TarOption { - return func(t *tarOpts) { - t.skipGzip = true - } -} - -func (t *tarOpts) applyOpts(tarOpts ...TarOption) { - for _, clientOpt := range tarOpts { - clientOpt(t) - } -} diff --git a/tar/tar_test.go b/tar/tar_test.go index 213fe379d..a172b7f44 100644 --- a/tar/tar_test.go +++ b/tar/tar_test.go @@ -1,5 +1,5 @@ /* -Copyright 2022 The Flux authors +Copyright 2026 The Flux authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,265 +20,253 @@ import ( "archive/tar" "bytes" "compress/gzip" - "crypto/rand" - "fmt" + "io" "os" "path/filepath" "testing" ) -type untarTestCase struct { - name string - targetDir string - secureTargetDir string - fileName string - content []byte - wantErr string - maxUntarSize int - fileMode int64 -} +func TestTar(t *testing.T) { + srcDir := t.TempDir() -func TestUntar(t *testing.T) { - targetDirOutput := filepath.Join(t.TempDir(), "output") - symlink := filepath.Join(t.TempDir(), "symlink") + // Create test files. 
+ if err := os.MkdirAll(filepath.Join(srcDir, "subdir"), 0o750); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(srcDir, "file.txt"), []byte("hello"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(srcDir, "subdir", "nested.txt"), []byte("world"), 0o644); err != nil { + t.Fatal(err) + } - subdir := filepath.Join(targetDirOutput, "subdir") - err := os.MkdirAll(subdir, 0o750) + var buf bytes.Buffer + n, err := Tar(srcDir, &buf) if err != nil { - t.Fatalf("cannot create subdir: %v", err) + t.Fatalf("Tar() error: %v", err) + } + if n <= 0 { + t.Fatal("Tar() returned zero bytes") + } + if n != int64(buf.Len()) { + t.Fatalf("Tar() returned %d bytes, but buffer has %d", n, buf.Len()) } - err = os.Symlink(subdir, symlink) - if err != nil { - t.Fatalf("cannot create symlink: %v", err) - } - - cases := []untarTestCase{ - { - name: "file at root", - fileName: "file1", - content: geRandomContent(256), - targetDir: targetDirOutput, - secureTargetDir: targetDirOutput, - }, - { - name: "file at subdir root", - fileName: "abc/fileX", - content: geRandomContent(256), - targetDir: targetDirOutput, - secureTargetDir: targetDirOutput, - }, - { - name: "directory traversal parent", - fileName: "../abc/file", - content: geRandomContent(256), - wantErr: `tar contained invalid name error "../abc/file"`, - targetDir: targetDirOutput, - secureTargetDir: targetDirOutput, - }, - { - name: "breach max size", - fileName: "big-file", - content: geRandomContent(256), - maxUntarSize: 255, - wantErr: `tar "big-file" is bigger than max archive size of 255 bytes`, - targetDir: targetDirOutput, - secureTargetDir: targetDirOutput, - }, - { - name: "breach default max untar size", - fileName: "another-big-file", - content: geRandomContent(DefaultMaxUntarSize + 1), - wantErr: `tar "another-big-file" is bigger than max archive size of 104857600 bytes`, - targetDir: targetDirOutput, - secureTargetDir: targetDirOutput, - }, - { - name: "disable max size 
checks", - fileName: "another-big-file", - content: geRandomContent(DefaultMaxUntarSize + 1), - maxUntarSize: -1, - targetDir: targetDirOutput, - secureTargetDir: targetDirOutput, - }, - { - name: "existing subdir", - fileName: "subdir/file1", - content: geRandomContent(256), - targetDir: targetDirOutput, - secureTargetDir: targetDirOutput, - }, - { - name: "relative target dir", - fileName: "file1", - content: geRandomContent(256), - targetDir: "anydir", - secureTargetDir: "./anydir", - }, - { - name: "relative paths can't ascend", - fileName: "file1", - content: geRandomContent(256), - targetDir: "../../../../../../../../tmp/test", - secureTargetDir: "./tmp/test", - }, - { - name: "symlink", - fileName: "any-file1", - content: geRandomContent(256), - targetDir: symlink, - wantErr: fmt.Sprintf(`dir '%s' must be a directory`, symlink), - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - f, err := createTestTar(tt) - if err != nil { - t.Fatalf("creating test tar: %v", err) - } - defer os.Remove(f.Name()) - defer os.RemoveAll(tt.targetDir) + got := readTarEntries(t, &buf) + want := map[string]string{ + ".": "", + "file.txt": "hello", + "subdir": "", + "subdir/nested.txt": "world", + } + if len(got) != len(want) { + t.Fatalf("got %d entries, want %d: %v", len(got), len(want), got) + } + for name, content := range want { + if got[name] != content { + t.Errorf("entry %q: got %q, want %q", name, got[name], content) + } + } +} - opts := make([]TarOption, 0) - if tt.maxUntarSize != 0 { - opts = append(opts, WithMaxUntarSize(tt.maxUntarSize)) - } +func TestTar_roundTrip(t *testing.T) { + srcDir := t.TempDir() - err = Untar(f, tt.targetDir, opts...) 
- var got string - if err != nil { - got = err.Error() - } - if tt.wantErr != got { - t.Errorf("wanted error: '%s' got: '%v'", tt.wantErr, err) - } + if err := os.MkdirAll(filepath.Join(srcDir, "a", "b"), 0o750); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(srcDir, "a", "b", "c.txt"), []byte("data"), 0o644); err != nil { + t.Fatal(err) + } - if tt.wantErr == "" && err != nil { - t.Errorf("unexpected error: %v", err) - } + var buf bytes.Buffer + if _, err := Tar(srcDir, &buf); err != nil { + t.Fatalf("Tar() error: %v", err) + } - // only assess file if no errors were expected - if tt.wantErr == "" { - abs := filepath.Join(tt.secureTargetDir, tt.fileName) - fi, err := os.Stat(abs) - if err != nil { - t.Errorf("stat %q: %v", abs, err) - return - } - - if fi.Size() != int64(len(tt.content)) { - t.Errorf("file size wanted: %d got: %d", len(tt.content), fi.Size()) - } - } + dstDir := t.TempDir() + if err := Untar(&buf, dstDir); err != nil { + t.Fatalf("Untar() error: %v", err) + } - if tt.targetDir != tt.secureTargetDir { - os.RemoveAll(tt.secureTargetDir) - } - }) + got, err := os.ReadFile(filepath.Join(dstDir, "a", "b", "c.txt")) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(got) != "data" { + t.Fatalf("got %q, want %q", got, "data") } } -func TestUntarDirectoryPermissions(t *testing.T) { - testDirName := "test-dir" +func TestTar_sanitizesHeaders(t *testing.T) { + srcDir := t.TempDir() + if err := os.WriteFile(filepath.Join(srcDir, "f.txt"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + if _, err := Tar(srcDir, &buf); err != nil { + t.Fatal(err) + } - f, err := createTestTar(untarTestCase{ - fileName: testDirName + "/", // from tar.Header: a trailing slash makes the entry a TypeDir - fileMode: 0o555, - content: nil, - }) + gr, err := gzip.NewReader(&buf) if err != nil { - t.Fatalf("creating test tar: %v", err) + t.Fatal(err) } + defer gr.Close() + + tr := tar.NewReader(gr) + for { + 
hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + if hdr.Uid != 0 || hdr.Gid != 0 { + t.Errorf("entry %q: uid=%d gid=%d, want 0", hdr.Name, hdr.Uid, hdr.Gid) + } + if hdr.Uname != "" || hdr.Gname != "" { + t.Errorf("entry %q: uname=%q gname=%q, want empty", hdr.Name, hdr.Uname, hdr.Gname) + } + if !hdr.ModTime.IsZero() && hdr.ModTime.Unix() != 0 { + t.Errorf("entry %q: ModTime=%v, want zero or Unix epoch", hdr.Name, hdr.ModTime) + } + } +} - targetDir := t.TempDir() - - if err := Untar(f, targetDir); err != nil { - t.Fatalf("untar: %v", err) +func TestTar_withFilter(t *testing.T) { + srcDir := t.TempDir() + if err := os.WriteFile(filepath.Join(srcDir, "keep.txt"), []byte("keep"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(srcDir, "skip.log"), []byte("skip"), 0o644); err != nil { + t.Fatal(err) } - fullPath := filepath.Join(targetDir, testDirName) - fi, err := os.Lstat(fullPath) - if err != nil { - t.Errorf("stat %q: %v", fullPath, err) + filter := func(p string, fi os.FileInfo) bool { + return filepath.Ext(p) == ".log" } - if !fi.Mode().IsDir() { - t.Fatalf("%q: not a directory", fullPath) + var buf bytes.Buffer + if _, err := Tar(srcDir, &buf, WithFilter(filter)); err != nil { + t.Fatal(err) } - ownerPerm := fi.Mode().Perm() & 0o700 - if ownerPerm != 0o700 { - t.Errorf("the owner must always be able to traverse, read, and write extracted directories") + got := readTarEntries(t, &buf) + if _, ok := got["skip.log"]; ok { + t.Error("filtered file skip.log should not be in archive") + } + if got["keep.txt"] != "keep" { + t.Errorf("keep.txt: got %q, want %q", got["keep.txt"], "keep") } } -func Fuzz_Untar(f *testing.F) { - tf, err := createTestTar(untarTestCase{ - name: "file at root", - fileName: "file1", - content: geRandomContent(256), - }) - if err != nil { - f.Fatalf("cannot create test tar: %v", err) +func TestTar_skipsSymlinks(t *testing.T) { + srcDir := t.TempDir() + target := 
filepath.Join(srcDir, "target.txt") + if err := os.WriteFile(target, []byte("t"), 0o644); err != nil { + t.Fatal(err) } - defer os.Remove(tf.Name()) - - var content []byte - _, err = tf.Read(content) - if err != nil { - f.Fatalf("cannot read test tar: %v", err) + if err := os.Symlink(target, filepath.Join(srcDir, "link.txt")); err != nil { + t.Fatal(err) } - f.Add(content) + var buf bytes.Buffer + if _, err := Tar(srcDir, &buf); err != nil { + t.Fatal(err) + } - f.Fuzz(func(t *testing.T, data []byte) { - _ = Untar(bytes.NewReader(data), t.TempDir()) - }) + got := readTarEntries(t, &buf) + if _, ok := got["link.txt"]; ok { + t.Error("symlink should not be in archive") + } + if got["target.txt"] != "t" { + t.Errorf("target.txt: got %q, want %q", got["target.txt"], "t") + } } -func createTestTar(tt untarTestCase) (*os.File, error) { - f, err := os.CreateTemp("", "flux-untar-*.tar.gz") - if err != nil { - return nil, fmt.Errorf("open file: %w", err) +func TestTar_invalidDir(t *testing.T) { + _, err := Tar("/nonexistent/path", io.Discard) + if err == nil { + t.Fatal("expected error for nonexistent dir") } +} - gzw := gzip.NewWriter(f) - writer := tar.NewWriter(gzw) - - fileMode := tt.fileMode - if fileMode == 0 { - fileMode = 0o777 +func TestTar_skipGzip(t *testing.T) { + srcDir := t.TempDir() + if err := os.WriteFile(filepath.Join(srcDir, "file.txt"), []byte("plain"), 0o644); err != nil { + t.Fatal(err) } - writer.WriteHeader(&tar.Header{ - Name: tt.fileName, - Size: int64(len(tt.content)), - Mode: fileMode, - }) - - writer.Write(tt.content) + var buf bytes.Buffer + if _, err := Tar(srcDir, &buf, WithSkipGzip()); err != nil { + t.Fatalf("Tar() error: %v", err) + } - if err = writer.Close(); err != nil { - return nil, fmt.Errorf("close tar: %v", err) + // Should be a valid plain tar, not gzip. 
+ tr := tar.NewReader(&buf) + found := false + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("tar.Next: %v", err) + } + if hdr.Name == "file.txt" { + found = true + content, _ := io.ReadAll(tr) + if string(content) != "plain" { + t.Errorf("got %q, want %q", content, "plain") + } + } } - if err = gzw.Close(); err != nil { - return nil, fmt.Errorf("close gzip: %v", err) + if !found { + t.Error("file.txt not found in plain tar") } +} - name := f.Name() - if err = f.Close(); err != nil { - return nil, fmt.Errorf("close file: %v", err) +func TestTar_notADirectory(t *testing.T) { + f := filepath.Join(t.TempDir(), "file.txt") + if err := os.WriteFile(f, []byte("x"), 0o644); err != nil { + t.Fatal(err) } - f, err = os.Open(name) - if err != nil { - return nil, fmt.Errorf("reopen file: %v", err) + _, err := Tar(f, io.Discard) + if err == nil { + t.Fatal("expected error for file path") } - return f, nil } -func geRandomContent(len int) []byte { - content := make([]byte, len) - rand.Read(content) - return content +// readTarEntries decompresses a tar.gz and returns a map of entry name to content. +func readTarEntries(t *testing.T, r io.Reader) map[string]string { + t.Helper() + gr, err := gzip.NewReader(r) + if err != nil { + t.Fatalf("gzip.NewReader: %v", err) + } + defer gr.Close() + + entries := make(map[string]string) + tr := tar.NewReader(gr) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("tar.Next: %v", err) + } + var content []byte + if hdr.Typeflag == tar.TypeReg { + content, err = io.ReadAll(tr) + if err != nil { + t.Fatalf("ReadAll %q: %v", hdr.Name, err) + } + } + entries[hdr.Name] = string(content) + } + return entries } diff --git a/tar/untar.go b/tar/untar.go new file mode 100644 index 000000000..614a588ee --- /dev/null +++ b/tar/untar.go @@ -0,0 +1,246 @@ +// Copyright 2017 The Go Authors. All rights reserved. 
const (
	// DefaultMaxUntarSize defines the default (100MB) max amount of bytes that Untar will process.
	DefaultMaxUntarSize = 100 << (10 * 2)

	// UnlimitedUntarSize defines the value which disables untar size checks for maxUntarSize.
	UnlimitedUntarSize = -1

	// bufferSize defines the size of the buffer used when copying the tar file entries.
	bufferSize = 32 * 1024
)

// Untar extracts a tar archive read from r into dir.
//
// By default, r is expected to be gzip-compressed; use WithSkipGzip to
// read a plain tar stream. Extraction is capped at DefaultMaxUntarSize
// bytes; use WithMaxUntarSize to raise, lower, or disable the limit.
// Use WithFilter to skip entries by name or FileInfo during extraction.
// Entries with paths that escape dir are rejected. Symlinks fail
// extraction unless WithSkipSymlinks is set, in which case they are
// silently dropped.
//
// If dir is a relative path, it cannot ascend from the current working
// directory. If dir exists, it must be a directory; otherwise it is
// created.
func Untar(r io.Reader, dir string, inOpts ...Option) (err error) {
	opts := tarOpts{
		maxUntarSize: DefaultMaxUntarSize,
	}
	opts.applyOpts(inOpts...)

	dir = filepath.Clean(dir)
	if !filepath.IsAbs(dir) {
		cwd, err := os.Getwd()
		if err != nil {
			return err
		}

		// SecureJoin keeps relative targets rooted under cwd, so a
		// crafted "../../.." dir cannot ascend out of it.
		dir, err = securejoin.SecureJoin(cwd, dir)
		if err != nil {
			return err
		}
	}

	fi, err := os.Lstat(dir)
	// Dir does not need to exist, as it can later be created.
	if err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("cannot lstat '%s': %w", dir, err)
	}

	// Lstat (not Stat) also rejects dir itself being a symlink.
	if err == nil && !fi.IsDir() {
		return fmt.Errorf("dir '%s' must be a directory", dir)
	}

	madeDir := map[string]bool{}
	var tr *tar.Reader
	if opts.skipGzip {
		tr = tar.NewReader(r)
	} else {
		zr, err := gzip.NewReader(r)
		if err != nil {
			return fmt.Errorf("requires gzip-compressed body: %w", err)
		}

		tr = tar.NewReader(zr)
	}

	processedBytes := 0
	t0 := time.Now()

	// For improved concurrency, this could be optimised by sourcing
	// the buffer from a sync.Pool.
	buf := make([]byte, bufferSize)
	for {
		f, err := tr.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("tar error: %w", err)
		}
		// The size limit uses declared header sizes and is checked before
		// the filter, so filtered entries still count toward the cap.
		processedBytes += int(f.Size)
		if opts.maxUntarSize > UnlimitedUntarSize &&
			processedBytes > opts.maxUntarSize {
			return fmt.Errorf("tar %q is bigger than max archive size of %d bytes", f.Name, opts.maxUntarSize)
		}
		// Reject absolute paths, backslashes, and parent traversal.
		if !validRelPath(f.Name) {
			return fmt.Errorf("tar contained invalid name error %q", f.Name)
		}
		rel := filepath.FromSlash(f.Name)
		abs := filepath.Join(dir, rel)

		fi := f.FileInfo()
		mode := fi.Mode()

		if opts.filter != nil && opts.filter(f.Name, fi) {
			continue
		}

		switch {
		case mode.IsRegular():
			// Make the directory. This is redundant because it should
			// already be made by a directory entry in the tar
			// beforehand. Thus, don't check for errors; the next
			// write will fail with the same error.
			dir := filepath.Dir(abs)
			if !madeDir[dir] {
				if err := os.MkdirAll(filepath.Dir(abs), 0o750); err != nil {
					return err
				}
				madeDir[dir] = true
			}
			if runtime.GOOS == "darwin" && mode&0111 != 0 {
				// The darwin kernel caches binary signatures
				// and SIGKILLs binaries with mismatched
				// signatures. Overwriting a binary with
				// O_TRUNC does not clear the cache, rendering
				// the new copy unusable. Removing the original
				// file first does clear the cache. See #54132.
				err := os.Remove(abs)
				if err != nil && !errors.Is(err, fs.ErrNotExist) {
					return err
				}
			}
			wf, err := os.OpenFile(abs, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode.Perm())
			if err != nil {
				return err
			}

			n, err := copyBuffer(wf, tr, buf)
			if err != nil && err != io.EOF {
				return fmt.Errorf("error copying buffer: %w", err)
			}

			// Surface the close error only when the copy succeeded.
			if closeErr := wf.Close(); closeErr != nil && err == nil {
				err = closeErr
			}
			if err != nil {
				return fmt.Errorf("error writing to %s: %w", abs, err)
			}
			if n != f.Size {
				return fmt.Errorf("only wrote %d bytes to %s; expected %d", n, abs, f.Size)
			}
			modTime := f.ModTime
			if modTime.After(t0) {
				// Ensures that files extracted are not newer than the
				// current system time.
				modTime = t0
			}
			if !modTime.IsZero() {
				if err = os.Chtimes(abs, modTime, modTime); err != nil {
					return fmt.Errorf("error changing file time %s: %w", abs, err)
				}
			}
		case mode.IsDir():
			// Ensure the owner can always traverse, read, and write
			// into extracted directories, regardless of what the tar
			// header claims. This prevents crafted archives from
			// creating directories that block cleanup or future writes.
			dirPerm := mode.Perm() | 0o700
			if err := os.MkdirAll(abs, dirPerm); err != nil {
				return err
			}
			madeDir[abs] = true
		case mode&os.ModeSymlink == os.ModeSymlink:
			if !opts.skipSymlinks {
				return fmt.Errorf("tar file entry %s is a symlink, which is not allowed in this context", f.Name)
			}
		default:
			return fmt.Errorf("tar file entry %s contained unsupported file type %v", f.Name, mode)
		}
	}
	return nil
}
+// +// Original source: +// https://github.com/golang/go/blob/6f445a9db55f65e55c5be29d3c506ecf3be37915/src/io/io.go#L405 +func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { + if buf == nil { + return 0, fmt.Errorf("buf is nil") + } + for { + nr, er := src.Read(buf) + if nr > 0 { + nw, ew := dst.Write(buf[0:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = fmt.Errorf("errInvalidWrite") + } + } + written += int64(nw) + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return written, err +} + +func validRelPath(p string) bool { + if p == "" || strings.Contains(p, `\`) || strings.HasPrefix(p, "/") || strings.Contains(p, "../") { + return false + } + return true +} diff --git a/tar/untar_test.go b/tar/untar_test.go new file mode 100644 index 000000000..8aecfdecf --- /dev/null +++ b/tar/untar_test.go @@ -0,0 +1,451 @@ +/* +Copyright 2022 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package tar + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "fmt" + "io" + "os" + "path" + "path/filepath" + "strings" + "testing" +) + +type untarTestCase struct { + name string + targetDir string + secureTargetDir string + fileName string + content []byte + wantErr string + maxUntarSize int + fileMode int64 +} + +func TestUntar(t *testing.T) { + targetDirOutput := filepath.Join(t.TempDir(), "output") + symlink := filepath.Join(t.TempDir(), "symlink") + + subdir := filepath.Join(targetDirOutput, "subdir") + err := os.MkdirAll(subdir, 0o750) + if err != nil { + t.Fatalf("cannot create subdir: %v", err) + } + + err = os.Symlink(subdir, symlink) + if err != nil { + t.Fatalf("cannot create symlink: %v", err) + } + + cases := []untarTestCase{ + { + name: "file at root", + fileName: "file1", + content: geRandomContent(256), + targetDir: targetDirOutput, + secureTargetDir: targetDirOutput, + }, + { + name: "file at subdir root", + fileName: "abc/fileX", + content: geRandomContent(256), + targetDir: targetDirOutput, + secureTargetDir: targetDirOutput, + }, + { + name: "directory traversal parent", + fileName: "../abc/file", + content: geRandomContent(256), + wantErr: `tar contained invalid name error "../abc/file"`, + targetDir: targetDirOutput, + secureTargetDir: targetDirOutput, + }, + { + name: "breach max size", + fileName: "big-file", + content: geRandomContent(256), + maxUntarSize: 255, + wantErr: `tar "big-file" is bigger than max archive size of 255 bytes`, + targetDir: targetDirOutput, + secureTargetDir: targetDirOutput, + }, + { + name: "breach default max untar size", + fileName: "another-big-file", + content: geRandomContent(1024), + maxUntarSize: 512, + wantErr: `tar "another-big-file" is bigger than max archive size of 512 bytes`, + targetDir: targetDirOutput, + secureTargetDir: targetDirOutput, + }, + { + name: "disable max size checks", + fileName: "another-big-file", + content: geRandomContent(1024), + maxUntarSize: -1, + targetDir: 
targetDirOutput, + secureTargetDir: targetDirOutput, + }, + { + name: "existing subdir", + fileName: "subdir/file1", + content: geRandomContent(256), + targetDir: targetDirOutput, + secureTargetDir: targetDirOutput, + }, + { + name: "relative target dir", + fileName: "file1", + content: geRandomContent(256), + targetDir: "anydir", + secureTargetDir: "./anydir", + }, + { + name: "relative paths can't ascend", + fileName: "file1", + content: geRandomContent(256), + targetDir: "../../../../../../../../tmp/test", + secureTargetDir: "./tmp/test", + }, + { + name: "symlink", + fileName: "any-file1", + content: geRandomContent(256), + targetDir: symlink, + wantErr: fmt.Sprintf(`dir '%s' must be a directory`, symlink), + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + f := createTestTar(t, tt) + defer os.RemoveAll(tt.targetDir) + + opts := make([]Option, 0) + if tt.maxUntarSize != 0 { + opts = append(opts, WithMaxUntarSize(tt.maxUntarSize)) + } + + err = Untar(f, tt.targetDir, opts...) 
+ var got string + if err != nil { + got = err.Error() + } + if tt.wantErr != got { + t.Errorf("wanted error: '%s' got: '%v'", tt.wantErr, err) + } + + if tt.wantErr == "" && err != nil { + t.Errorf("unexpected error: %v", err) + } + + // only assess file if no errors were expected + if tt.wantErr == "" { + abs := filepath.Join(tt.secureTargetDir, tt.fileName) + fi, err := os.Stat(abs) + if err != nil { + t.Errorf("stat %q: %v", abs, err) + return + } + + if fi.Size() != int64(len(tt.content)) { + t.Errorf("file size wanted: %d got: %d", len(tt.content), fi.Size()) + } + } + + if tt.targetDir != tt.secureTargetDir { + os.RemoveAll(tt.secureTargetDir) + } + }) + } +} + +func TestUntarDirectoryPermissions(t *testing.T) { + testDirName := "test-dir" + + f := createTestTar(t, untarTestCase{ + fileName: testDirName + "/", // from tar.Header: a trailing slash makes the entry a TypeDir + fileMode: 0o555, + content: nil, + }) + + targetDir := t.TempDir() + + if err := Untar(f, targetDir); err != nil { + t.Fatalf("untar: %v", err) + } + + fullPath := filepath.Join(targetDir, testDirName) + fi, err := os.Lstat(fullPath) + if err != nil { + t.Errorf("stat %q: %v", fullPath, err) + } + + if !fi.Mode().IsDir() { + t.Fatalf("%q: not a directory", fullPath) + } + + ownerPerm := fi.Mode().Perm() & 0o700 + if ownerPerm != 0o700 { + t.Errorf("the owner must always be able to traverse, read, and write extracted directories") + } +} + +func Fuzz_Untar(f *testing.F) { + tf := createTestTar(f, untarTestCase{ + name: "file at root", + fileName: "file1", + content: geRandomContent(256), + }) + + var content []byte + if _, err := tf.Read(content); err != nil { + f.Fatalf("cannot read test tar: %v", err) + } + + f.Add(content) + + f.Fuzz(func(t *testing.T, data []byte) { + _ = Untar(bytes.NewReader(data), t.TempDir()) + }) +} + +func createTestTar(t testing.TB, tt untarTestCase) *os.File { + t.Helper() + + name := filepath.Join(t.TempDir(), "test.tar.gz") + f, err := os.Create(name) + if 
err != nil { + t.Fatalf("create file: %v", err) + } + + gzw := gzip.NewWriter(f) + writer := tar.NewWriter(gzw) + + fileMode := tt.fileMode + if fileMode == 0 { + fileMode = 0o777 + } + + writer.WriteHeader(&tar.Header{ + Name: tt.fileName, + Size: int64(len(tt.content)), + Mode: fileMode, + }) + + writer.Write(tt.content) + + if err = writer.Close(); err != nil { + t.Fatalf("close tar: %v", err) + } + if err = gzw.Close(); err != nil { + t.Fatalf("close gzip: %v", err) + } + if err = f.Close(); err != nil { + t.Fatalf("close file: %v", err) + } + + f, err = os.Open(name) + if err != nil { + t.Fatalf("reopen file: %v", err) + } + return f +} + +func geRandomContent(n int) []byte { + content := make([]byte, n) + for i := range content { + content[i] = byte(i % 251) + } + return content +} + +func TestSkipSymlinks(t *testing.T) { + tmpDir := t.TempDir() + + symlinkTarget := filepath.Join(tmpDir, "symlink.target") + err := os.WriteFile(symlinkTarget, geRandomContent(256), os.ModePerm) + if err != nil { + t.Fatal(err) + } + + symlink := filepath.Join(tmpDir, "symlink") + err = os.Symlink(symlinkTarget, symlink) + if err != nil { + t.Fatal(err) + } + + tgzFileName := filepath.Join(t.TempDir(), "test.tgz") + var buf bytes.Buffer + err = tgzWithSymlinks(tmpDir, &buf) + if err != nil { + t.Fatal(err) + } + + tgzFile, err := os.OpenFile(tgzFileName, os.O_CREATE|os.O_RDWR, os.ModePerm) + if err != nil { + t.Fatal(err) + } + if _, err := io.Copy(tgzFile, &buf); err != nil { + t.Fatal(err) + } + if err = tgzFile.Close(); err != nil { + t.Fatal(err) + } + + targetDirOutput := filepath.Join(t.TempDir(), "output") + f1, err := os.Open(tgzFileName) + if err != nil { + t.Fatal(err) + } + + err = Untar(f1, targetDirOutput, WithMaxUntarSize(-1)) + if err == nil { + t.Errorf("wanted error: unsupported symlink") + } + + f2, err := os.Open(tgzFileName) + if err != nil { + t.Fatal(err) + } + + err = Untar(f2, targetDirOutput, WithMaxUntarSize(-1), WithSkipSymlinks()) + if err != nil { + 
t.Errorf("unexpected error: %v", err) + } + + if _, err := os.Open(path.Join(targetDirOutput, "symlink.target")); err != nil { + t.Errorf("regular file not found: %v", err) + } +} + +func tgzWithSymlinks(src string, buf io.Writer) error { + absDir, err := filepath.Abs(src) + if err != nil { + return err + } + + zr := gzip.NewWriter(buf) + tw := tar.NewWriter(zr) + if err := filepath.Walk(absDir, func(file string, fi os.FileInfo, err error) error { + if err != nil { + return err + } + + header, err := tar.FileInfoHeader(fi, file) + if err != nil { + return err + } + if err := tw.WriteHeader(header); err != nil { + return err + } + + if fi.Mode().IsRegular() { + f, err := os.Open(file) + if err != nil { + return err + } + if _, err := io.Copy(tw, f); err != nil { + return err + } + return f.Close() + } + + return nil + }); err != nil { + return err + } + if err := tw.Close(); err != nil { + return err + } + if err := zr.Close(); err != nil { + return err + } + return nil +} + +func TestUntar_withFilter(t *testing.T) { + // Build a gzipped tar with two files. 
+ var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + for name, data := range map[string]string{"keep.txt": "keep", "skip.log": "skip"} { + tw.WriteHeader(&tar.Header{Name: name, Size: int64(len(data)), Mode: 0o644}) + tw.Write([]byte(data)) + } + tw.Close() + gw.Close() + + dst := t.TempDir() + filter := func(p string, _ os.FileInfo) bool { + return filepath.Ext(p) == ".log" + } + if err := Untar(&buf, dst, WithMaxUntarSize(-1), WithFilter(filter)); err != nil { + t.Fatalf("Untar: %v", err) + } + + if _, err := os.Stat(filepath.Join(dst, "skip.log")); err == nil { + t.Error("filtered file skip.log should not have been extracted") + } + got, err := os.ReadFile(filepath.Join(dst, "keep.txt")) + if err != nil { + t.Fatalf("read keep.txt: %v", err) + } + if string(got) != "keep" { + t.Errorf("keep.txt: got %q, want %q", string(got), "keep") + } +} + +func TestUntar_withFilterDirectory(t *testing.T) { + // Build a gzipped tar with entries under two directories. + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + tw.WriteHeader(&tar.Header{Name: "skip/", Mode: 0o755, Typeflag: tar.TypeDir}) + s := "secret" + tw.WriteHeader(&tar.Header{Name: "skip/secret.txt", Size: int64(len(s)), Mode: 0o644}) + tw.Write([]byte(s)) + tw.WriteHeader(&tar.Header{Name: "keep/", Mode: 0o755, Typeflag: tar.TypeDir}) + k := "public" + tw.WriteHeader(&tar.Header{Name: "keep/data.txt", Size: int64(len(k)), Mode: 0o644}) + tw.Write([]byte(k)) + tw.Close() + gw.Close() + + dst := t.TempDir() + filter := func(p string, _ os.FileInfo) bool { + return strings.HasPrefix(p, "skip/") + } + if err := Untar(&buf, dst, WithMaxUntarSize(-1), WithFilter(filter)); err != nil { + t.Fatalf("Untar: %v", err) + } + + if _, err := os.Stat(filepath.Join(dst, "skip", "secret.txt")); err == nil { + t.Error("skip/secret.txt should not have been extracted") + } + got, err := os.ReadFile(filepath.Join(dst, "keep", "data.txt")) + if err != nil { + t.Fatalf("read 
keep/data.txt: %v", err) + } + if string(got) != "public" { + t.Errorf("keep/data.txt: got %q, want %q", string(got), "public") + } +}