Skip to content

Commit

Permalink
contenthash: unify "follow" and trailing-symlink handling for Checksum
Browse files Browse the repository at this point in the history
This patch is part of a series which fixes the symlink resolution
semantics within BuildKit.

Previously, the concept of the follow flag had different meanings in
various parts of the checksum codepath. FollowLinks is effectively
O_NOFOLLOW, but the implementation in getFollowLinks was actually more
like RESOLVE_NO_SYMLINKS. This was masked by the fact that
checksumFollow would implement the O_NOFOLLOW behaviour (incorrectly),
but checksumFollow would call checksumNoFollow (which would follow
symlinks in path components by setting follow=true for getFollowLinks).

It is much easier to simply remove these layers of indirection and unify
the meaning of FollowLinks across all of the code. This means that the
old follow flag is no longer needed.

This also means that we can now remove the incorrect symlink resolution
logic in (*cacheContext).checksumFollow() and move the followTrailing
logic to (*cacheContext).checksum(), as well as removing
getFollowParentLinks(). Since this removes some redundant re-checksum
loops, we need to add followTrailing logic to scanPath() so that final
symlink components result in the correct directory being scanned
properly.

The only user of (*cacheContext).checksum(follow=false) was
(*cacheContext).includedPaths() which appeared to be simply using this
as an optimisation (since the path being walked already had its parent
path resolved). Having two easily-confused boolean flags for an
optimisation that is probably not necessary (getFollowLinks already does
a fast check to see if the original path is in the cache) seemed
unnecessary, so just keep followTrailing.

Signed-off-by: Aleksa Sarai <cyphar@cyphar.com>
  • Loading branch information
cyphar committed May 3, 2024
1 parent fa77668 commit ce7521c
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 104 deletions.
168 changes: 65 additions & 103 deletions cache/contenthash/checksum.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ func (cc *cacheContext) Checksum(ctx context.Context, mountable cache.Mountable,
defer m.clean()

if !opts.Wildcard && len(opts.IncludePatterns) == 0 && len(opts.ExcludePatterns) == 0 {
return cc.checksumFollow(ctx, m, p, opts.FollowLinks)
return cc.lazyChecksum(ctx, m, p, opts.FollowLinks)
}

includedPaths, err := cc.includedPaths(ctx, m, p, opts)
Expand All @@ -418,7 +418,7 @@ func (cc *cacheContext) Checksum(ctx context.Context, mountable cache.Mountable,
if opts.FollowLinks {
for i, w := range includedPaths {
if w.record.Type == CacheRecordTypeSymlink {
dgst, err := cc.checksumFollow(ctx, m, w.path, opts.FollowLinks)
dgst, err := cc.lazyChecksum(ctx, m, w.path, opts.FollowLinks)
if err != nil {
return "", err
}
Expand All @@ -445,30 +445,6 @@ func (cc *cacheContext) Checksum(ctx context.Context, mountable cache.Mountable,
return digester.Digest(), nil
}

func (cc *cacheContext) checksumFollow(ctx context.Context, m *mount, p string, follow bool) (digest.Digest, error) {
const maxSymlinkLimit = 255
i := 0
for {
if i > maxSymlinkLimit {
return "", errors.Errorf("too many symlinks: %s", p)
}
cr, err := cc.checksumNoFollow(ctx, m, p)
if err != nil {
return "", err
}
if cr.Type == CacheRecordTypeSymlink && follow {
link := cr.Linkname
if !path.IsAbs(cr.Linkname) {
link = path.Join(path.Dir(p), link)
}
i++
p = link
} else {
return cr.Digest, nil
}
}
}

func (cc *cacheContext) includedPaths(ctx context.Context, m *mount, p string, opts ChecksumOpts) ([]*includedPath, error) {
cc.mu.Lock()
defer cc.mu.Unlock()
Expand All @@ -478,12 +454,12 @@ func (cc *cacheContext) includedPaths(ctx context.Context, m *mount, p string, o
}

root := cc.tree.Root()
scan, err := cc.needsScan(root, "")
scan, err := cc.needsScan(root, "", false)
if err != nil {
return nil, err
}
if scan {
if err := cc.scanPath(ctx, m, ""); err != nil {
if err := cc.scanPath(ctx, m, "", false); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -542,7 +518,7 @@ func (cc *cacheContext) includedPaths(ctx context.Context, m *mount, p string, o
// involves a symlink. That will match fsutil behavior of
// calling functions such as stat and walk.
var cr *CacheRecord
k, cr, err = getFollowParentLinks(root, k, true)
k, cr, err = getFollowLinks(root, k, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -753,11 +729,7 @@ func wildcardPrefix(root *iradix.Node, p string) (string, []byte, bool, error) {

// Only resolve the final symlink component if there are components in the
// wildcard segment.
resolveFn := getFollowParentLinks
if d2 != "" {
resolveFn = getFollowLinks
}
k, cr, err := resolveFn(root, convertPathToKey([]byte(d1)), true)
k, cr, err := getFollowLinks(root, convertPathToKey([]byte(d1)), d2 != "")
if err != nil {
return "", k, false, err
}
Expand Down Expand Up @@ -796,19 +768,22 @@ func containsWildcards(name string) bool {
return false
}

func (cc *cacheContext) checksumNoFollow(ctx context.Context, m *mount, p string) (*CacheRecord, error) {
func (cc *cacheContext) lazyChecksum(ctx context.Context, m *mount, p string, followTrailing bool) (digest.Digest, error) {
p = keyPath(p)
k := convertPathToKey([]byte(p))

// Try to look up the path directly without doing a scan.
cc.mu.RLock()
if cc.txn == nil {
root := cc.tree.Root()
cc.mu.RUnlock()
v, ok := root.Get(convertPathToKey([]byte(p)))
if ok {
cr := v.(*CacheRecord)
if cr.Digest != "" {
return cr, nil
}

_, cr, err := getFollowLinks(root, k, followTrailing)
if err != nil {
return "", err
}
if cr != nil && cr.Digest != "" {
return cr.Digest, nil
}
} else {
cc.mu.RUnlock()
Expand All @@ -828,7 +803,11 @@ func (cc *cacheContext) checksumNoFollow(ctx context.Context, m *mount, p string
}
}()

return cc.lazyChecksum(ctx, m, p)
cr, err := cc.scanChecksum(ctx, m, p, followTrailing)
if err != nil {
return "", err
}
return cr.Digest, nil
}

func (cc *cacheContext) commitActiveTransaction() {
Expand All @@ -847,21 +826,21 @@ func (cc *cacheContext) commitActiveTransaction() {
cc.txn = nil
}

func (cc *cacheContext) lazyChecksum(ctx context.Context, m *mount, p string) (*CacheRecord, error) {
func (cc *cacheContext) scanChecksum(ctx context.Context, m *mount, p string, followTrailing bool) (*CacheRecord, error) {
root := cc.tree.Root()
scan, err := cc.needsScan(root, p)
scan, err := cc.needsScan(root, p, followTrailing)
if err != nil {
return nil, err
}
if scan {
if err := cc.scanPath(ctx, m, p); err != nil {
if err := cc.scanPath(ctx, m, p, followTrailing); err != nil {
return nil, err
}
}
k := convertPathToKey([]byte(p))
txn := cc.tree.Txn()
root = txn.Root()
cr, updated, err := cc.checksum(ctx, root, txn, m, k, true)
cr, updated, err := cc.checksum(ctx, root, txn, m, k, followTrailing)
if err != nil {
return nil, err
}
Expand All @@ -870,9 +849,9 @@ func (cc *cacheContext) lazyChecksum(ctx context.Context, m *mount, p string) (*
return cr, err
}

func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *iradix.Txn, m *mount, k []byte, follow bool) (*CacheRecord, bool, error) {
func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *iradix.Txn, m *mount, k []byte, followTrailing bool) (*CacheRecord, bool, error) {
origk := k
k, cr, err := getFollowParentLinks(root, k, follow)
k, cr, err := getFollowLinks(root, k, followTrailing)
if err != nil {
return nil, false, err
}
Expand All @@ -898,7 +877,9 @@ func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *ir
}
h.Write(bytes.TrimPrefix(subk, k))

subcr, _, err := cc.checksum(ctx, root, txn, m, subk, true)
// We do not follow trailing links when checksumming a directory's
// contents.
subcr, _, err := cc.checksum(ctx, root, txn, m, subk, false)
if err != nil {
return nil, false, err
}
Expand Down Expand Up @@ -949,13 +930,13 @@ func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *ir

// needsScan returns false if path is in the tree or a parent path is in tree
// and subpath is missing.
func (cc *cacheContext) needsScan(root *iradix.Node, path string) (bool, error) {
func (cc *cacheContext) needsScan(root *iradix.Node, path string, followTrailing bool) (bool, error) {
var (
lastGoodPath string
hasParentInTree bool
)
k := convertPathToKey([]byte(path))
_, cr, err := getFollowLinksCallback(root, k, true, func(subpath string, cr *CacheRecord) error {
_, cr, err := getFollowLinksCallback(root, k, followTrailing, func(subpath string, cr *CacheRecord) error {
if cr != nil {
// If the path is not a symlink, then for now we have a parent in
// the tree. Otherwise, we reset hasParentInTree because we
Expand All @@ -981,8 +962,8 @@ func (cc *cacheContext) needsScan(root *iradix.Node, path string) (bool, error)
return cr == nil && !hasParentInTree, nil
}

func (cc *cacheContext) scanPath(ctx context.Context, m *mount, p string) (retErr error) {
d := path.Dir(path.Join("/", p))
func (cc *cacheContext) scanPath(ctx context.Context, m *mount, p string, followTrailing bool) (retErr error) {
p = path.Join("/", p)

mp, err := m.mount(ctx)
if err != nil {
Expand All @@ -992,7 +973,7 @@ func (cc *cacheContext) scanPath(ctx context.Context, m *mount, p string) (retEr
n := cc.tree.Root()
txn := cc.tree.Txn()

parentPath, err := rootPath(mp, filepath.FromSlash(d), func(p, link string) error {
resolvedPath, err := rootPath(mp, filepath.FromSlash(p), followTrailing, func(p, link string) error {
cr := &CacheRecord{
Type: CacheRecordTypeSymlink,
Linkname: filepath.ToSlash(link),
Expand All @@ -1006,7 +987,14 @@ func (cc *cacheContext) scanPath(ctx context.Context, m *mount, p string) (retEr
return err
}

err = filepath.Walk(parentPath, func(itemPath string, fi os.FileInfo, err error) error {
// Scan the parent directory of the path we resolved, unless we're at the
// root (in which case we scan the root).
scanPath := filepath.Dir(resolvedPath)
if !strings.HasPrefix(filepath.ToSlash(scanPath)+"/", filepath.ToSlash(mp)+"/") {
scanPath = resolvedPath
}

err = filepath.Walk(scanPath, func(itemPath string, fi os.FileInfo, err error) error {
if err != nil {
// If the root doesn't exist, ignore the error.
if errors.Is(err, os.ErrNotExist) {
Expand Down Expand Up @@ -1055,48 +1043,33 @@ func (cc *cacheContext) scanPath(ctx context.Context, m *mount, p string) (retEr
return nil
}

// getFollowParentLinks is effectively O_PATH|O_NOFOLLOW, where the final
// component is looked up without doing any symlink resolution (if it is a
// symlink).
func getFollowParentLinks(root *iradix.Node, k []byte, follow bool) ([]byte, *CacheRecord, error) {
v, ok := root.Get(k)
if ok {
return k, v.(*CacheRecord), nil
}
if !follow || len(k) == 0 {
return k, nil, nil
}

// Only fully evaluate the parent path.
dir, file := splitKey(k)
dir, _, err := getFollowLinks(root, dir, follow)
if err != nil {
return nil, nil, err
}

// Do a direct lookup of the final component.
k = append(dir, file...)
v, ok = root.Get(k)
if ok {
return k, v.(*CacheRecord), nil
}
return k, nil, nil
}

// followLinksCallback is called after we try to resolve each element. If the
// path was not found, cr is nil.
type followLinksCallback func(path string, cr *CacheRecord) error

func getFollowLinks(root *iradix.Node, k []byte, follow bool) ([]byte, *CacheRecord, error) {
return getFollowLinksCallback(root, k, follow, nil)
// getFollowLinks is shorthand for getFollowLinksCallback(..., nil).
func getFollowLinks(root *iradix.Node, k []byte, followTrailing bool) ([]byte, *CacheRecord, error) {
return getFollowLinksCallback(root, k, followTrailing, nil)
}

func getFollowLinksCallback(root *iradix.Node, k []byte, follow bool, cb followLinksCallback) ([]byte, *CacheRecord, error) {
// getFollowLinksCallback looks up the requested key, fully resolving any
// symlink components encountered. The implementation is heavily based on
// <https://github.com/cyphar/filepath-securejoin>.
//
// followTrailing indicates whether the *final component* of the path should be
// resolved (effectively O_PATH|O_NOFOLLOW). Note that (in contrast to some
// Linux APIs), followTrailing is obeyed even if the key has a trailing slash
// (though paths like "foo/link/." will cause the link to be resolved).
//
// The callback cb is called after each cache lookup done by
// getFollowLinksCallback, except for the first lookup where the verbatim key
// is looked up in the cache.
func getFollowLinksCallback(root *iradix.Node, k []byte, followTrailing bool, cb followLinksCallback) ([]byte, *CacheRecord, error) {
v, ok := root.Get(k)
if ok && v.(*CacheRecord).Type != CacheRecordTypeSymlink {
if ok && (!followTrailing || v.(*CacheRecord).Type != CacheRecordTypeSymlink) {
return k, v.(*CacheRecord), nil
}
if !follow || len(k) == 0 {
if len(k) == 0 {
return k, nil, nil
}

Expand Down Expand Up @@ -1146,6 +1119,10 @@ func getFollowLinksCallback(root *iradix.Node, k []byte, follow bool, cb followL
currentPath = nextPath
continue
}
if !followTrailing && remainingPath == "" {
currentPath = nextPath
break
}

linksWalked++
if linksWalked > maxSymlinkLimit {
Expand Down Expand Up @@ -1237,18 +1214,3 @@ func convertPathToKey(p []byte) []byte {
func convertKeyToPath(p []byte) []byte {
return bytes.Replace([]byte(p), []byte{0}, []byte("/"), -1)
}

func splitKey(k []byte) ([]byte, []byte) {
foundBytes := false
i := len(k) - 1
for {
if i <= 0 || foundBytes && k[i] == 0 {
break
}
if k[i] != 0 {
foundBytes = true
}
i--
}
return append([]byte{}, k[:i]...), k[i:]
}
11 changes: 10 additions & 1 deletion cache/contenthash/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type onSymlinkFunc func(string, string) error
// the root directory. This is a slightly modified version of SecureJoin from
// github.com/cyphar/filepath-securejoin, with a callback which we call after
// each symlink resolution.
func rootPath(root, unsafePath string, cb onSymlinkFunc) (string, error) {
func rootPath(root, unsafePath string, followTrailing bool, cb onSymlinkFunc) (string, error) {
if unsafePath == "" {
return root, nil
}
Expand All @@ -41,6 +41,9 @@ func rootPath(root, unsafePath string, cb onSymlinkFunc) (string, error) {
unsafePath = unsafePath[len(v):]
}

// Remove any unneccesary trailing slashes.
unsafePath = strings.TrimSuffix(unsafePath, string(filepath.Separator))

// Get the next path component.
var part string
if i := strings.IndexRune(unsafePath, filepath.Separator); i == -1 {
Expand Down Expand Up @@ -71,6 +74,11 @@ func rootPath(root, unsafePath string, cb onSymlinkFunc) (string, error) {
currentPath = nextPath
continue
}
// Don't resolve the final component with !followTrailing.
if !followTrailing && unsafePath == "" {
currentPath = nextPath
break
}

// It's a symlink, so get its contents and expand it by prepending it
// to the yet-unparsed path.
Expand All @@ -88,6 +96,7 @@ func rootPath(root, unsafePath string, cb onSymlinkFunc) (string, error) {
return "", err
}
}

unsafePath = dest + string(filepath.Separator) + unsafePath
// Absolute symlinks reset any work we've already done.
if filepath.IsAbs(dest) {
Expand Down

0 comments on commit ce7521c

Please sign in to comment.