diff --git a/go.mod b/go.mod index 6a36ba4..4135585 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,4 @@ go 1.16 require ( github.com/blang/semver/v4 v4.0.0 github.com/stretchr/testify v1.3.0 - golang.org/x/sys v0.0.0-20191022100944-742c48ecaeb7 ) diff --git a/go.sum b/go.sum index a3ad36b..67a62ab 100644 --- a/go.sum +++ b/go.sum @@ -7,5 +7,3 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -golang.org/x/sys v0.0.0-20191022100944-742c48ecaeb7 h1:HmbHVPwrPEKPGLAcHSrMe6+hqSUlvZU0rab6x5EXfGU= -golang.org/x/sys v0.0.0-20191022100944-742c48ecaeb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/validators/cgroup_validator_linux.go b/validators/cgroup_validator_linux.go index 8337b75..3945b39 100644 --- a/validators/cgroup_validator_linux.go +++ b/validators/cgroup_validator_linux.go @@ -27,8 +27,6 @@ import ( "os" "path/filepath" "strings" - - "golang.org/x/sys/unix" ) var _ Validator = &CgroupsValidator{} @@ -44,16 +42,17 @@ func (c *CgroupsValidator) Name() string { } const ( - cgroupsConfigPrefix = "CGROUPS_" - mountsFilePath = "/proc/mounts" + cgroupsConfigPrefix = "CGROUPS_" + mountsFilePath = "/proc/mounts" + defaultUnifiedMountPoint = "/sys/fs/cgroup" ) // getUnifiedMountpoint checks if the default mount point is available. // If not, it parses the mounts file to find a valid cgroup mount point. -func getUnifiedMountpoint(path string) (string, error) { +func getUnifiedMountpoint(path string) (string, bool, error) { f, err := os.Open(path) if err != nil { - return "", err + return "", false, err } defer f.Close() scanner := bufio.NewScanner(f) @@ -66,10 +65,19 @@ func getUnifiedMountpoint(path string) (string, error) { // Example fields: `cgroup2 /sys/fs/cgroup cgroup2 rw,seclabel,nosuid,nodev,noexec,relatime 0 0`. fields := strings.Fields(line) if len(fields) >= 3 { + // If default unified mount point is available, return it directly. + if fields[1] == defaultUnifiedMountPoint { + if fields[2] == "tmpfs" { + // if `/sys/fs/cgroup/memory` is a dir, this means it uses cgroups v1 + info, err := os.Stat(filepath.Join(defaultUnifiedMountPoint, "memory")) + return defaultUnifiedMountPoint, os.IsNotExist(err) || !info.IsDir(), nil + } + return defaultUnifiedMountPoint, fields[2] == "cgroup2", nil + } switch fields[2] { case "cgroup2": // Return the first cgroups v2 mount point directly. - return fields[1], nil + return fields[1], true, nil case "cgroup": // Set the first cgroups v1 mount point only, // and continue the loop to find if there is a cgroups v2 mount point. @@ -81,29 +89,22 @@ func getUnifiedMountpoint(path string) (string, error) { } // Return cgroups v1 mount point if no cgroups v2 mount point is found. if len(cgroupV1MountPoint) != 0 { - return cgroupV1MountPoint, nil + return cgroupV1MountPoint, false, nil } - return "", fmt.Errorf("cannot get a cgroupfs mount point from %q", path) + return "", false, fmt.Errorf("cannot get a cgroupfs mount point from %q", path) } // Validate is part of the system.Validator interface. func (c *CgroupsValidator) Validate(spec SysSpec) (warns, errs []error) { - // Get the subsystems from /sys/fs/cgroup/cgroup.controllers when cgroups v2 is used. - // /proc/cgroups is meaningless for v2 - // https://github.com/torvalds/linux/blob/v5.3/Documentation/admin-guide/cgroup-v2.rst#deprecated-v1-core-features - var st unix.Statfs_t - unifiedMountpoint, err := getUnifiedMountpoint(mountsFilePath) + unifiedMountpoint, isCgroupsV2, err := getUnifiedMountpoint(mountsFilePath) if err != nil { return nil, []error{fmt.Errorf("cannot get a cgroup mount point: %w", err)} } - if err := unix.Statfs(unifiedMountpoint, &st); err != nil { - return nil, []error{fmt.Errorf("cannot statfs the cgroupv2 root: %w", err)} - } var requiredCgroupSpec []string var optionalCgroupSpec []string var subsystems []string var warn error - if st.Type == unix.CGROUP2_SUPER_MAGIC { + if isCgroupsV2 { subsystems, err, warn = c.getCgroupV2Subsystems(unifiedMountpoint) if err != nil { return nil, []error{fmt.Errorf("failed to get cgroups v2 subsystems: %w", err)} diff --git a/validators/cgroup_validator_test.go b/validators/cgroup_validator_test.go index bfc7733..7168fb3 100644 --- a/validators/cgroup_validator_test.go +++ b/validators/cgroup_validator_test.go @@ -21,6 +21,7 @@ package system import ( "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -96,15 +97,24 @@ func TestValidateCgroupSubsystem(t *testing.T) { } func TestGetUnifiedMountpoint(t *testing.T) { + c, err := os.Open(filepath.Join(defaultUnifiedMountPoint, "cgroup.controllers")) + if err == nil { + defer c.Close() + } tests := map[string]struct { - mountsFileContent string - expectedErr bool - expectedPath string + mountsFileContent string + expectedErr bool + expectedPath string + expectedIsCgroupsV2 bool + // when /sys/fs/cgroup is mounted as tmpfs, + // the cgroup version check depends on checking local dir: `/sys/fs/cgroup/memory` + skipIsCgroupsV2Check bool }{ "cgroups v2": { - mountsFileContent: "cgroup2 /sys/fs/cgroup cgroup2 rw,seclabel,nosuid,nodev,noexec,relatime 0 0", - expectedErr: false, - expectedPath: "/sys/fs/cgroup", + mountsFileContent: "cgroup2 /sys/fs/cgroup cgroup2 rw,seclabel,nosuid,nodev,noexec,relatime 0 0", + expectedErr: false, + expectedPath: "/sys/fs/cgroup", + expectedIsCgroupsV2: true, }, "cgroups v1": { mountsFileContent: "cgroup /sys/fs/cgroup cgroup rw,seclabel,nosuid,nodev,noexec,relatime 0 0", @@ -126,8 +136,9 @@ sysfs /sys sysfs rw,seclabel,nosuid,nodev,noexec,relatime 0 0`, mountsFileContent: `cgroup /sys/fs/cgroup/cpuset cgroup rw,nosuid,nodev,noexec,relatime,cpuset cgroup /sys/fs/cgroup/memory cgroup rw,nosuid,nodev,noexec,relatime,memory cgroup2 /sys/fs/cgroup/unified cgroup2 rw,seclabel,nosuid,nodev,noexec,relatime`, - expectedErr: false, - expectedPath: "/sys/fs/cgroup/unified", + expectedErr: false, + expectedPath: "/sys/fs/cgroup/unified", + expectedIsCgroupsV2: true, }, "cgroups v1 only with multiple subsystems": { mountsFileContent: `cgroup /sys/fs/cgroup/cpuset cgroup rw,nosuid,nodev,noexec,relatime,cpuset @@ -140,6 +151,25 @@ cgroup /sys/fs/cgroup/memory cgroup rw,nosuid,nodev,noexec,relatime,memory`, expectedErr: true, expectedPath: "", }, + "cgroups using tmpfs, v1 and v2": { + mountsFileContent: `tmpfs /run tmpfs rw,nosuid,nodev,size=803108k,nr_inodes=819200,mode=755 0 0 +tmpfs /sys/fs/cgroup tmpfs ro,nosuid,nodev,noexec,size=4096k,nr_inodes=1024,mode=755 0 0 +cgroup2 /sys/fs/cgroup/unified cgroup2 rw,nosuid,nodev,noexec,relatime,nsdelegate 0 0 +cgroup /sys/fs/cgroup/systemd cgroup rw,nosuid,nodev,noexec,relatime,xattr,name=systemd 0 0`, + expectedErr: false, + expectedPath: "/sys/fs/cgroup", + skipIsCgroupsV2Check: true, + }, + "cgroups using tmpfs, v1": { + mountsFileContent: `tmpfs /sys/fs/cgroup tmpfs ro,seclabel,nosuid,nodev,noexec,mode=755 0 0 +cgroup /sys/fs/cgroup/systemd cgroup rw,seclabel,nosuid,nodev,noexec,relatime,xattr,release_agent=/usr/lib/systemd/systemd-cgroups-agent,name=systemd 0 0 +cgroup /sys/fs/cgroup/net_cls,net_prio cgroup rw,seclabel,nosuid,nodev,noexec,relatime,net_cls,net_prio 0 0 +cgroup /sys/fs/cgroup/blkio cgroup rw,seclabel,nosuid,nodev,noexec,relatime,blkio 0 0 +cgroup /sys/fs/cgroup/memory cgroup rw,seclabel,nosuid,nodev,noexec,relatime,memory 0 0`, + expectedErr: false, + expectedPath: "/sys/fs/cgroup", + skipIsCgroupsV2Check: true, + }, } for desc, test := range tests { @@ -152,7 +182,7 @@ cgroup /sys/fs/cgroup/memory cgroup rw,nosuid,nodev,noexec,relatime,memory`, assert.NoError(t, err, "Unexpected error writing to temp file") tmpFile.Close() - path, err := getUnifiedMountpoint(tmpFile.Name()) + path, isCgroupsV2, err := getUnifiedMountpoint(tmpFile.Name()) if test.expectedErr { assert.Error(t, err, "Expected error but got none") @@ -161,6 +191,9 @@ cgroup /sys/fs/cgroup/memory cgroup rw,nosuid,nodev,noexec,relatime,memory`, } assert.Equal(t, test.expectedPath, path, "Expected cgroup path mismatch") + if !test.skipIsCgroupsV2Check { + assert.Equal(t, test.expectedIsCgroupsV2, isCgroupsV2, "Expected cgroup version mismatch") + } }) } }