Skip to content

Commit

Permalink
Merge pull request #100 from kolyshkin/mounted-fast
Browse files Browse the repository at this point in the history
mountinfo: add MountedFast
  • Loading branch information
thaJeztah committed Feb 3, 2022
2 parents 74ec3fe + 5d09d69 commit d01e595
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 54 deletions.
58 changes: 51 additions & 7 deletions mountinfo/mounted_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,34 @@ import (
"golang.org/x/sys/unix"
)

// MountedFast is a method of detecting a mount point without reading
// mountinfo from procfs. A caller can only trust the result if no error
// and sure == true are returned. Otherwise, other methods (e.g. parsing
// /proc/mounts) have to be used. If unsure, use Mounted instead (which
// uses MountedFast, but falls back to parsing mountinfo if needed).
//
// If a non-existent path is specified, an appropriate error is returned.
// In case the caller is not interested in this particular error, it should
// be handled separately using e.g. errors.Is(err, os.ErrNotExist).
//
// This function is only available on Linux. When available (since kernel
// v5.6), openat2(2) syscall is used to reliably detect all mounts. Otherwise,
// the implementation falls back to using stat(2), which can reliably detect
// normal (but not bind) mounts.
func MountedFast(path string) (mounted, sure bool, err error) {
// Root is always mounted.
if path == string(os.PathSeparator) {
return true, true, nil
}

path, err = normalizePath(path)
if err != nil {
return false, false, err
}
mounted, sure, err = mountedFast(path)
return
}

// mountedByOpenat2 is a method of detecting a mount that works for all kinds
// of mounts (incl. bind mounts), but requires a recent (v5.6+) linux kernel.
func mountedByOpenat2(path string) (bool, error) {
Expand Down Expand Up @@ -34,24 +62,40 @@ func mountedByOpenat2(path string) (bool, error) {
return false, &os.PathError{Op: "openat2", Path: path, Err: err}
}

func mounted(path string) (bool, error) {
path, err := normalizePath(path)
if err != nil {
return false, err
// mountedFast is similar to MountedFast, except it expects a normalized path.
func mountedFast(path string) (mounted, sure bool, err error) {
// Root is always mounted.
if path == string(os.PathSeparator) {
return true, true, nil
}

// Try a fast path, using openat2() with RESOLVE_NO_XDEV.
mounted, err := mountedByOpenat2(path)
mounted, err = mountedByOpenat2(path)
if err == nil {
return mounted, nil
return mounted, true, nil
}

// Another fast path: compare st.st_dev fields.
mounted, err = mountedByStat(path)
// This does not work for bind mounts, so false negative
// is possible, therefore only trust if return is true.
if mounted && err == nil {
return true, true, nil
}

return
}

func mounted(path string) (bool, error) {
path, err := normalizePath(path)
if err != nil {
return false, err
}
mounted, sure, err := mountedFast(path)
if sure && err == nil {
return mounted, nil
}

// Fallback to parsing mountinfo
// Fallback to parsing mountinfo.
return mountedByMountinfo(path)
}
158 changes: 114 additions & 44 deletions mountinfo/mounted_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"net"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"

"golang.org/x/sys/unix"
Expand All @@ -31,7 +34,7 @@ func tMount(t *testing.T, src, dst, fstype string, flags uintptr, options string
return nil
}

var testMounts = []struct {
type testMount struct {
desc string
isNotExist bool
isMount bool
Expand All @@ -45,7 +48,9 @@ var testMounts = []struct {
// simplicity (no need to check for errors or call t.Cleanup()), and
// it may call t.Fatal, but practically we don't expect it.
prepare func(t *testing.T) (string, error)
}{
}

var testMounts = []testMount{
{
desc: "non-existent path",
isNotExist: true,
Expand Down Expand Up @@ -275,9 +280,68 @@ func tryOpenat2() error {
return err
}

func testMountedFast(t *testing.T, path string, tc *testMount, openat2Supported bool) {
mounted, sure, err := MountedFast(path)
if err != nil {
// Got an error; is it expected?
if !(tc.isNotExist && errors.Is(err, os.ErrNotExist)) {
t.Errorf("MountedFast: unexpected error: %v", err)
}

// In case of an error, sure and mounted must be false.
if sure {
t.Error("MountedFast: expected sure to be false on error")
}
if mounted {
t.Error("MountedFast: expected mounted to be false on error")
}

// No more checks.
return
}

if openat2Supported {
if mounted != tc.isMount {
t.Errorf("MountedFast: expected mounted to be %v, got %v", tc.isMount, mounted)
}

// No more checks.
return
}

if tc.isBind {
// For bind mounts, in case openat2 is not supported,
// sure and mounted must be false.
if sure {
t.Error("MountedFast: expected sure to be false for a bind mount")
}
if mounted {
t.Error("MountedFast: expected mounted to be false for a bind mount")
}
} else {
if mounted != tc.isMount {
t.Errorf("MountFast: expected mounted to be %v, got %v", tc.isMount, mounted)
}
if tc.isMount && !sure {
t.Error("MountFast: expected sure to be true for normal mount")
}
if !tc.isMount && sure {
t.Error("MountFast: expected sure to be false for non-mount")
}
}
}

func TestMountedBy(t *testing.T) {
openat2Supported := tryOpenat2() == nil
checked := false
openat2Supported := false

// List of individual implementations to check.
toCheck := []func(string) (bool, error){mountedByMountinfo, mountedByStat}
if tryOpenat2() == nil {
openat2Supported = true
toCheck = append(toCheck, mountedByOpenat2)
}

for _, tc := range testMounts {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
Expand All @@ -286,25 +350,26 @@ func TestMountedBy(t *testing.T) {
t.Fatalf("prepare: %v", err)
}

exp := tc.isMount

// Check the public Mounted() function as a whole.
mounted, err := Mounted(m)
if err == nil {
if mounted != exp {
t.Errorf("Mounted: expected %v, got %v", exp, mounted)
if mounted != tc.isMount {
t.Errorf("Mounted: expected %v, got %v", tc.isMount, mounted)
}
} else {
// Got an error; is it expected?
if !(tc.isNotExist && errors.Is(err, os.ErrNotExist)) {
t.Errorf("Mounted: unexpected error: %v", err)
}
// Check false is returned in error case.
if mounted != false {
t.Errorf("Mounted: expected false on error, got %v", mounted)
if mounted {
t.Error("Mounted: expected false on error")
}
}

// Check the public MountedFast() function as a whole.
testMountedFast(t, m, &tc, openat2Supported)

// Check individual mountedBy* implementations.

// All mountedBy* functions should be called with normalized paths.
Expand All @@ -316,45 +381,29 @@ func TestMountedBy(t *testing.T) {
t.Fatalf("normalizePath: %v", err)
}

mounted, err = mountedByMountinfo(m)
if err != nil {
t.Errorf("mountedByMountinfo error: %v", err)
// Check false is returned in error case.
if mounted != false {
t.Errorf("MountedByMountinfo: expected false on error, got %v", mounted)
}
} else if mounted != exp {
t.Errorf("mountedByMountinfo: expected %v, got %v", exp, mounted)
}
checked = true

mounted, err = mountedByStat(m)
if err != nil {
t.Errorf("mountedByStat error: %v", err)
// Check false is returned in error case.
if mounted != false {
t.Errorf("MountedByStat: expected false on error, got %v", mounted)
}
} else if mounted != exp && !tc.isBind { // mountedByStat can not detect bind mounts
t.Errorf("mountedByStat: expected %v, got %v", exp, mounted)
}

if !openat2Supported {
return
}
mounted, err = mountedByOpenat2(m)
if err != nil {
t.Errorf("mountedByOpenat2 error: %v", err)
// Check false is returned in error case.
if mounted != false {
t.Errorf("MountedByOpenat2: expected false on error, got %v", mounted)
for _, fn := range toCheck {
// Figure out function name.
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()

mounted, err = fn(m)
if err != nil {
t.Errorf("%s: %v", name, err)
// Check false is returned in error case.
if mounted {
t.Errorf("%s: expected false on error", name)
}
} else if mounted != tc.isMount {
if tc.isBind && strings.HasSuffix(name, "mountedByStat") {
// mountedByStat can not detect bind mounts.
} else {
t.Errorf("%s: expected %v, got %v", name, tc.isMount, mounted)
}
}
} else if mounted != exp {
t.Errorf("mountedByOpenat2: expected %v, got %v", exp, mounted)
checked = true
}
})

}

if !checked {
t.Skip("no mounts to check")
}
Expand Down Expand Up @@ -385,3 +434,24 @@ func TestMountedByOpenat2VsMountinfo(t *testing.T) {
}
}
}

// TestMountedRoot checks that Mounted* functions always return true for root
// directory (since / is always mounted).
func TestMountedRoot(t *testing.T) {
for _, path := range []string{
"/",
"/../../",
"/tmp/..",
strings.Repeat("../", unix.PathMax/3), // Hope $CWD is not too deep down.
} {
mounted, err := Mounted(path)
if err != nil || !mounted {
t.Errorf("Mounted(%q): expected true, <nil>; got %v, %v", path, mounted, err)
}

mounted, sure, err := MountedFast(path)
if err != nil || !mounted || !sure {
t.Errorf("MountedFast(%q): expected true, true, <nil>; got %v, %v, %v", path, mounted, sure, err)
}
}
}
6 changes: 3 additions & 3 deletions mountinfo/mountinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ func GetMounts(f FilterFunc) ([]*Info, error) {
// Mounted determines if a specified path is a mount point. In case of any
// error, false (and an error) is returned.
//
// The non-existent path returns an error. If a caller is not interested
// in this particular error, it should handle it separately using e.g.
// errors.Is(err, os.ErrNotExist).
// If a non-existent path is specified, an appropriate error is returned.
// In case the caller is not interested in this particular error, it should
// be handled separately using e.g. errors.Is(err, os.ErrNotExist).
func Mounted(path string) (bool, error) {
// root is always mounted
if path == string(os.PathSeparator) {
Expand Down

0 comments on commit d01e595

Please sign in to comment.