From e0a78ae754571e160ba3dea1381b73c6e9ffbd66 Mon Sep 17 00:00:00 2001 From: Brian Goff Date: Wed, 9 Aug 2023 18:36:06 +0000 Subject: [PATCH] Add type to iterate directory This prevents reading from a directory from allocating an unbounded slice at the cost of potentially having to read more than once. It also prevents unnecessary allocation in cases where an error occurs handling the file type. Signed-off-by: Brian Goff --- fs/copy.go | 32 ++++++--- fs/dir.go | 53 +++++++++++++++ fs/dir_test.go | 181 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 257 insertions(+), 9 deletions(-) create mode 100644 fs/dir.go create mode 100644 fs/dir_test.go diff --git a/fs/copy.go b/fs/copy.go index af3abdd4..a3fef3c0 100644 --- a/fs/copy.go +++ b/fs/copy.go @@ -103,11 +103,6 @@ func copyDirectory(dst, src string, inodes map[uint64]string, o *copyDirOpts) er } } - entries, err := os.ReadDir(src) - if err != nil { - return fmt.Errorf("failed to read %s: %w", src, err) - } - if err := copyFileInfo(stat, src, dst); err != nil { return fmt.Errorf("failed to copy file info for %s: %w", dst, err) } @@ -116,7 +111,15 @@ func copyDirectory(dst, src string, inodes map[uint64]string, o *copyDirOpts) er return fmt.Errorf("failed to copy xattrs: %w", err) } - for _, entry := range entries { + f, err := os.Open(src) + if err != nil { + return err + } + defer f.Close() + + dr := &dirReader{f: f} + + handleEntry := func(entry os.DirEntry) error { source := filepath.Join(src, entry.Name()) target := filepath.Join(dst, entry.Name()) @@ -130,7 +133,7 @@ func copyDirectory(dst, src string, inodes map[uint64]string, o *copyDirOpts) er if err := copyDirectory(target, source, inodes, o); err != nil { return err } - continue + return nil case (fileInfo.Mode() & os.ModeType) == 0: link, err := getLinkSource(target, fileInfo, inodes) if err != nil { @@ -159,7 +162,7 @@ func copyDirectory(dst, src string, inodes map[uint64]string, o *copyDirOpts) er } default: logrus.Warnf("unsupported 
mode: %s: %s", source, fileInfo.Mode()) - continue + return nil } if err := copyFileInfo(fileInfo, source, target); err != nil { @@ -169,9 +172,20 @@ func copyDirectory(dst, src string, inodes map[uint64]string, o *copyDirOpts) er if err := copyXAttrs(target, source, o.xex, o.xeh); err != nil { return fmt.Errorf("failed to copy xattrs: %w", err) } + return nil } - return nil + for { + entry := dr.Next() + if entry == nil { + break + } + + if err := handleEntry(entry); err != nil { + return err + } + } + return dr.Err() } // CopyFile copies the source file to the target. diff --git a/fs/dir.go b/fs/dir.go new file mode 100644 index 00000000..6c7e32e9 --- /dev/null +++ b/fs/dir.go @@ -0,0 +1,53 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package fs + +import ( + "io" + "os" +) + +type dirReader struct { + buf []os.DirEntry + f *os.File + err error +} + +func (r *dirReader) Next() os.DirEntry { + if len(r.buf) == 0 { + infos, err := r.f.ReadDir(32) + if err != nil { + if err != io.EOF { + r.err = err + } + return nil + } + r.buf = infos + } + + if len(r.buf) == 0 { + return nil + } + out := r.buf[0] + r.buf[0] = nil + r.buf = r.buf[1:] + return out +} + +func (r *dirReader) Err() error { + return r.err +} diff --git a/fs/dir_test.go b/fs/dir_test.go new file mode 100644 index 00000000..ebe0480e --- /dev/null +++ b/fs/dir_test.go @@ -0,0 +1,181 @@ +/* + Copyright The containerd Authors. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package fs + +import ( + "bytes" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestDirReader(t *testing.T) { + t.Run("empty dir", func(t *testing.T) { + t.Parallel() + + dr := newTestDirReader(t, nil) + if dr.Next() != nil { + t.Fatal("expected nil dir entry for empty dir") + } + // validate that another call will still be nil and not panic + if dr.Next() != nil { + t.Fatal("expected nil dir entry for empty dir") + } + }) + + t.Run("populated dir", func(t *testing.T) { + t.Parallel() + + content := map[string]*testFile{ + "foo": newTestFile([]byte("hello"), 0644), + "bar/baz": newTestFile([]byte("world"), 0600), + "bar": newTestFile(nil, os.ModeDir|0710), + } + found := make(map[string]bool, len(content)) + shouldSkip := map[string]bool{ + "bar/baz": true, + } + dr := newTestDirReader(t, content) + + check := func(entry os.DirEntry) { + tf := content[entry.Name()] + if tf == nil { + t.Errorf("got unknown entry: %s", entry) + return + } + + fi, err := entry.Info() + if err != nil { + t.Error() + return + } + + // Windows file permissions are not accurately represented in mode like this and will show 0666 (files) and 0777 (dirs) + // As such, do not try to compare mode equality + if runtime.GOOS != "windows" { + if fi.Mode() != tf.mode { + t.Errorf("%s: file modes do not match, expected: %s, got: %s", fi.Name(), tf.mode, fi.Mode()) + } + } else { + if (fi.Mode().IsRegular() != tf.mode.IsRegular()) || 
(fi.Mode().IsDir() != tf.mode.IsDir()) { + t.Error("%s: file modes does not match") + } + } + + if fi.Mode().IsRegular() { + dt, err := os.ReadFile(filepath.Join(dr.f.Name(), entry.Name())) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(tf.dt, dt) { + t.Errorf("expected %q, got: %q", string(tf.dt), string(dt)) + } + } + } + + for { + entry := dr.Next() + if entry == nil { + break + } + found[entry.Name()] = true + check(entry) + } + + if err := dr.Err(); err != nil { + t.Fatal(err) + } + + if len(found) != len(content)-len(shouldSkip) { + t.Fatalf("exected files [%s], got: [%s]", mapToStringer(content), mapToStringer(found)) + } + for k := range shouldSkip { + if found[k] { + t.Errorf("expected dir reader to skip %s", k) + } + } + }) +} + +type stringerFunc func() string + +func (f stringerFunc) String() string { + return f() +} + +func mapToStringer[T any](in map[string]T) stringerFunc { + return func() string { + out := make([]string, 0, len(in)) + for k := range in { + out = append(out, k) + } + return strings.Join(out, ",") + } +} + +type testFile struct { + dt []byte + mode os.FileMode +} + +func newTestFile(dt []byte, mode os.FileMode) *testFile { + return &testFile{ + dt: dt, + mode: mode, + } +} + +func newTestDirReader(t *testing.T, content map[string]*testFile) *dirReader { + p := t.TempDir() + + for cp, info := range content { + fp := filepath.Join(p, cp) + + switch { + case info.mode.IsRegular(): + if err := os.MkdirAll(filepath.Dir(fp), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(fp, info.dt, info.mode.Perm()); err != nil { + t.Fatal(err) + } + if err := os.Chmod(fp, info.mode.Perm()); err != nil { + t.Fatal(err) + } + case info.mode.IsDir(): + if err := os.MkdirAll(fp, info.mode); err != nil { + t.Fatal(err) + } + // make sure the dir has the right perms in case it was created earlier while writing a file + if err := os.Chmod(fp, info.mode.Perm()); err != nil { + t.Fatal(err) + } + default: + t.Fatal("unexpected 
file mode") + } + } + + f, err := os.Open(p) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { f.Close() }) + return &dirReader{f: f} +}