diff --git a/libcontainer/container_linux.go b/libcontainer/container_linux.go index 1c76782f5..9c79260a9 100644 --- a/libcontainer/container_linux.go +++ b/libcontainer/container_linux.go @@ -300,7 +300,13 @@ func (c *linuxContainer) Start(process *Process) error { if c.config.RootfsUidShiftType == sh.Shiftfs || c.config.BindMntUidShiftType == sh.Shiftfs { - if err := c.setupShiftfsMarks(); err != nil { + + mounts, err := mount.GetMounts() + if err != nil { + return fmt.Errorf("failed to read mountinfo: %s", err) + } + + if err := c.setupShiftfsMarks(mounts); err != nil { return err } } @@ -776,7 +782,12 @@ func (c *linuxContainer) Destroy() error { } else { // If sysbox-mgr is not present (i.e., unit testing), then we teardown // shiftfs marks here. - if err2 := c.teardownShiftfsMarkLocal(); err == nil { + mounts, err := mount.GetMounts() + if err != nil { + return fmt.Errorf("failed to read mountinfo: %s", err) + } + + if err2 := c.teardownShiftfsMarkLocal(mounts); err == nil { err = err2 } } @@ -2558,7 +2569,7 @@ func (c *linuxContainer) procSeccompInit(pid int, fd int32) error { } // sysbox-runc: sets up the shiftfs marks for the container -func (c *linuxContainer) setupShiftfsMarks() error { +func (c *linuxContainer) setupShiftfsMarks(mi []*mount.Info) error { config := c.config shiftfsMounts := []configs.ShiftfsMount{} @@ -2588,11 +2599,6 @@ func (c *linuxContainer) setupShiftfsMarks() error { // orig file (i.e., the source of the bind mount). if !m.BindSrcInfo.IsDir { - mi, err := mount.GetMounts() - if err != nil { - return fmt.Errorf("failed to read mountinfo: %s", err) - } - isBindMnt, origSrc, err := fileIsBindMount(mi, m.Source) if err != nil { return fmt.Errorf("failed to check if %s is a bind-mount: %s", m.Source, err) @@ -2686,15 +2692,15 @@ func (c *linuxContainer) setupShiftfsMarks() error { } else { config.ShiftfsMounts = shiftfsMounts - return c.setupShiftfsMarkLocal() + return c.setupShiftfsMarkLocal(mi) } } // Setup shiftfs marks; meant for testing only -func (c *linuxContainer) setupShiftfsMarkLocal() error { +func (c *linuxContainer) setupShiftfsMarkLocal(mi []*mount.Info) error { for _, m := range c.config.ShiftfsMounts { - mounted, err := mount.MountedWithFs(m.Source, "shiftfs") + mounted, err := mount.MountedWithFs(m.Source, "shiftfs", mi) if err != nil { return newSystemErrorWithCausef(err, "checking for shiftfs mount at %s", m.Source) } @@ -2709,10 +2715,10 @@ func (c *linuxContainer) setupShiftfsMarkLocal() error { } // Teardown shiftfs marks; meant for testing only -func (c *linuxContainer) teardownShiftfsMarkLocal() error { +func (c *linuxContainer) teardownShiftfsMarkLocal(mi []*mount.Info) error { for _, m := range c.config.ShiftfsMounts { - mounted, err := mount.MountedWithFs(m.Source, "shiftfs") + mounted, err := mount.MountedWithFs(m.Source, "shiftfs", mi) if err != nil { return newSystemErrorWithCausef(err, "checking for shiftfs mount at %s", m.Source) } @@ -2748,8 +2754,9 @@ func (c *linuxContainer) rootfsCloningRequired() (bool, error) { } rootfs := c.config.Rootfs + mounts, err := mount.GetMounts() - mi, err := mount.GetMountAtPid(uint32(os.Getpid()), rootfs) + mi, err := mount.GetMountAt(rootfs, mounts) if err == nil && mi.Fstype == "overlay" && !strings.Contains(mi.Opts, "metacopy=on") { return true, nil } diff --git a/libcontainer/mount/mount.go b/libcontainer/mount/mount.go index 077abed2d..7fc65e9da 100644 --- a/libcontainer/mount/mount.go +++ b/libcontainer/mount/mount.go @@ -23,29 +23,13 @@ func FindMount(mountpoint string, mounts []*Info) bool { return false } -// Mounted looks at /proc/self/mountinfo to determine if the specified -// mountpoint has been mounted -func Mounted(mountpoint string) (bool, error) { - mounts, err := parseMountTable() - if err != nil { - return false, err - } - - isMounted := FindMount(mountpoint, mounts) - return isMounted, nil -} - // MountedWithFs looks at /proc/self/mountinfo to determine if the specified // mountpoint has been mounted with the given filesystem type. -func MountedWithFs(mountpoint string, fs string) (bool, error) { - entries, err := parseMountTable() - if err != nil { - return false, err - } +func MountedWithFs(mountpoint string, fs string, mounts []*Info) (bool, error) { // Search the table for the mountpoint - for _, e := range entries { - if e.Mountpoint == mountpoint && e.Fstype == fs { + for _, m := range mounts { + if m.Mountpoint == mountpoint && m.Fstype == fs { return true, nil } } @@ -53,31 +37,12 @@ func MountedWithFs(mountpoint string, fs string) (bool, error) { } // GetMountAt returns information about the given mountpoint. -func GetMountAt(mountpoint string) (*Info, error) { - entries, err := parseMountTable() - if err != nil { - return nil, err - } - // Search the table for the given mountpoint - for _, e := range entries { - if e.Mountpoint == mountpoint { - return e, nil - } - } - return nil, fmt.Errorf("%s is not a mountpoint", mountpoint) -} - -// GetMountAtPid returns information about the given mountpoint and pid. -func GetMountAtPid(pid uint32, mountpoint string) (*Info, error) { - entries, err := parseMountTableForPid(pid) - if err != nil { - return nil, err - } +func GetMountAt(mountpoint string, mounts []*Info) (*Info, error) { - // Search the table for the given mountpoint. - for _, e := range entries { - if e.Mountpoint == mountpoint { - return e, nil + // Search the table for the given mountpoint + for _, m := range mounts { + if m.Mountpoint == mountpoint { + return m, nil } } return nil, fmt.Errorf("%s is not a mountpoint", mountpoint) diff --git a/libcontainer/mount/mount_test.go b/libcontainer/mount/mount_test.go index 453f1c69c..e4ef6d38c 100644 --- a/libcontainer/mount/mount_test.go +++ b/libcontainer/mount/mount_test.go @@ -1,9 +1,6 @@ package mount import ( - "io/ioutil" - "log" - "os" "testing" ) @@ -26,52 +23,39 @@ func TestGetMounts(t *testing.T) { } } -func TestMounted(t *testing.T) { - ok, err := Mounted("/proc") - if err != nil || !ok { - t.Fatalf("Mounted() failed: %v, %v", ok, err) - } - ok, err = Mounted("/sys") - if err != nil || !ok { - t.Fatalf("Mounted() failed: %v, %v", ok, err) - } - - // negative testing - dir, err := ioutil.TempDir("", "TestMounted") +func TestMountedWithFs(t *testing.T) { + allMounts, err := GetMounts() if err != nil { - log.Fatal(err) - } - defer os.RemoveAll(dir) - - ok, err = Mounted("dir") - if err != nil || ok { - t.Fatalf("Mounted() failed: %v, %v", ok, err) + t.Fatalf("GetMounts() failed: %v", err) } -} -func TestMountedWithFs(t *testing.T) { - ok, err := MountedWithFs("/proc", "proc") + ok, err := MountedWithFs("/proc", "proc", allMounts) if err != nil || !ok { t.Fatalf("MountedWithFs() failed: %v, %v", ok, err) } - ok, err = MountedWithFs("/sys", "sysfs") + ok, err = MountedWithFs("/sys", "sysfs", allMounts) if err != nil || !ok { t.Fatalf("MountedWithFs() failed: %v, %v", ok, err) } // negative testing - ok, err = MountedWithFs("/proc", "sysfs") + ok, err = MountedWithFs("/proc", "sysfs", allMounts) if err != nil || ok { t.Fatalf("MountedWithFs() failed: %v, %v", ok, err) } - ok, err = MountedWithFs("/sys", "procfs") + ok, err = MountedWithFs("/sys", "procfs", allMounts) if err != nil || ok { t.Fatalf("MountedWithFs() failed: %v, %v", ok, err) } } func TestGetMountAt(t *testing.T) { - m, err := GetMountAt("/proc") + allMounts, err := GetMounts() + if err != nil { + t.Fatalf("GetMounts() failed: %v", err) + } + + m, err := GetMountAt("/proc", allMounts) if err != nil { t.Fatalf("GetMountAt() failed: %v", err) } @@ -81,7 +65,7 @@ func TestGetMountAt(t *testing.T) { } } - m, err = GetMountAt("/sys") + m, err = GetMountAt("/sys", allMounts) if err != nil { t.Fatalf("GetMountAt() failed: %v", err) } diff --git a/libsysbox/shiftfs/shiftfs.go b/libsysbox/shiftfs/shiftfs.go index eb6f0e712..36a6849e8 100644 --- a/libsysbox/shiftfs/shiftfs.go +++ b/libsysbox/shiftfs/shiftfs.go @@ -56,10 +56,11 @@ func Unmount(path string) error { // Returns a boolean indicating if the given path has a shiftfs mount // on it (mark or actual mount). -func Mounted(path string) (bool, error) { +func Mounted(path string, mounts []*mount.Info) (bool, error) { realPath, err := filepath.EvalSymlinks(path) if err != nil { return false, err } - return mount.MountedWithFs(realPath, "shiftfs") + + return mount.MountedWithFs(realPath, "shiftfs", mounts) } diff --git a/libsysbox/syscont/utils.go b/libsysbox/syscont/utils.go index 156044100..a80db129a 100644 --- a/libsysbox/syscont/utils.go +++ b/libsysbox/syscont/utils.go @@ -18,7 +18,6 @@ package syscont import ( "fmt" - "os" "sort" "strings" @@ -177,7 +176,12 @@ func rootfsCloningRequired(rootfs string) (bool, error) { // snapshots work properly. Once the rootfs is cloned, we then setup the // container using this cloned rootfs. - mi, err := mount.GetMountAtPid(uint32(os.Getpid()), rootfs) + mounts, err := mount.GetMounts() + if err != nil { + return false, err + } + + mi, err := mount.GetMountAt(rootfs, mounts) if err == nil && mi.Fstype == "overlay" && !strings.Contains(mi.Opts, "metacopy=on") { return true, nil }