diff --git a/cmd/fscrypt/commands.go b/cmd/fscrypt/commands.go index 2f23a0fc..2733890f 100644 --- a/cmd/fscrypt/commands.go +++ b/cmd/fscrypt/commands.go @@ -353,7 +353,7 @@ func purgeAction(c *cli.Context) error { } if dropCachesFlag.Value { - if !util.IsUserRoot() { + if util.CurrentUserID() != 0 { return newExitError(c, ErrDropCachesPerm) } } diff --git a/cmd/fscrypt/errors.go b/cmd/fscrypt/errors.go index 81a67985..72f89436 100644 --- a/cmd/fscrypt/errors.go +++ b/cmd/fscrypt/errors.go @@ -54,7 +54,6 @@ var ( ErrSpecifyKeyFile = errors.New("no key file specified") ErrKeyFileLength = errors.Errorf("key file must be %d bytes", metadata.InternalKeyLen) ErrAllLoadsFailed = errors.New("could not load any protectors") - ErrMustBeRoot = errors.New("this command must be run as root") ErrPolicyUnlocked = errors.New("this file or directory is already unlocked") ErrBadOwners = errors.New("you do not own this directory") ErrNotEmptyDir = errors.New("not an empty directory") diff --git a/cmd/fscrypt/setup.go b/cmd/fscrypt/setup.go index 72dfbdb0..ac32484f 100644 --- a/cmd/fscrypt/setup.go +++ b/cmd/fscrypt/setup.go @@ -31,8 +31,8 @@ import ( // createGlobalConfig creates (or overwrites) the global config file func createGlobalConfig(w io.Writer, path string) error { - if !util.IsUserRoot() { - return ErrMustBeRoot + if err := util.CheckIfRoot(); err != nil { + return err } // Ask to create or replace the config file diff --git a/security/keyring.go b/security/keyring.go index ab656319..7ce163e0 100644 --- a/security/keyring.go +++ b/security/keyring.go @@ -114,7 +114,7 @@ func UserKeyringID(target *user.User, checkSession bool) (int, error) { return 0, errors.Wrap(ErrAccessUserKeyring, err.Error()) } - if !util.IsUserRoot() { + if util.CurrentUserID() != 0 { // Make sure the returned keyring will be accessible by checking // that it is in the session keyring. if checkSession && !isUserKeyringInSession(uid) { diff --git a/util/errors.go b/util/errors.go index fada687e..f0b94037 100644 --- a/util/errors.go +++ b/util/errors.go @@ -29,6 +29,8 @@ import ( ) var ( + // ErrNotRoot indicates the action is restricted to the superuser. + ErrNotRoot = errors.New("only root can perform this action") // ErrSkipIntegration indicates integration tests shouldn't be run. ErrSkipIntegration = errors.New("skipping integration test") ) diff --git a/util/users.go b/util/users.go index 92affa88..49abd32d 100644 --- a/util/users.go +++ b/util/users.go @@ -48,3 +48,11 @@ func GetUser(uid int) *user.User { func CurrentUser() *user.User { return GetUser(CurrentUserID()) } + +// CheckIfRoot returns ErrNotRoot if the current user is not the root user. +func CheckIfRoot() error { + if id := CurrentUserID(); id != 0 { + return ErrNotRoot + } + return nil +} diff --git a/util/util.go b/util/util.go index ed78519a..df24a99d 100644 --- a/util/util.go +++ b/util/util.go @@ -117,8 +117,3 @@ func AtoiOrPanic(input string) int { } return i } - -// IsUserRoot checks if the effective user is root. -func IsUserRoot() bool { - return CurrentUserID() == 0 -}