Skip to content

Commit

Permalink
daemon, polkit: pid_t is signed
Browse files Browse the repository at this point in the history
We were using uint32 for pids in daemon and polkit, when they're
actually signed. This would be mostly transparent to snapd, but could
lead to spurious denials from polkit in some situations.
  • Loading branch information
chipaca authored and mvo5 committed Jan 29, 2019
1 parent 8b00c6f commit a819ae7
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 31 deletions.
6 changes: 3 additions & 3 deletions daemon/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5451,7 +5451,7 @@ func (s *postCreateUserSuite) SetUpTest(c *check.C) {
s.apiBaseSuite.SetUpTest(c)

s.daemon(c)
postCreateUserUcrednetGet = func(string) (uint32, uint32, string, error) {
postCreateUserUcrednetGet = func(string) (int32, uint32, string, error) {
return 100, 0, dirs.SnapdSocket, nil
}
s.mockUserHome = c.MkDir()
Expand Down Expand Up @@ -5858,7 +5858,7 @@ func (s *postCreateUserSuite) TestPostCreateUserFromAssertionAllKnownClassicErro

s.makeSystemUsers(c, []map[string]interface{}{goodUser})

postCreateUserUcrednetGet = func(string) (uint32, uint32, string, error) {
postCreateUserUcrednetGet = func(string) (int32, uint32, string, error) {
return 100, 0, dirs.SnapdSocket, nil
}
defer func() {
Expand Down Expand Up @@ -6596,7 +6596,7 @@ func (s *apiSuite) TestSnapctlGetNoUID(c *check.C) {
func (s *apiSuite) TestSnapctlForbiddenError(c *check.C) {
_ = s.daemon(c)

runSnapctlUcrednetGet = func(string) (uint32, uint32, string, error) {
runSnapctlUcrednetGet = func(string) (int32, uint32, string, error) {
return 100, 9999, dirs.SnapSocket, nil
}
defer func() { runSnapctlUcrednetGet = ucrednetGet }()
Expand Down
4 changes: 2 additions & 2 deletions daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type daemonSuite struct {

var _ = check.Suite(&daemonSuite{})

func (s *daemonSuite) checkAuthorization(pid uint32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) {
func (s *daemonSuite) checkAuthorization(pid int32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) {
s.lastPolkitFlags = flags
return s.authorized, s.err
}
Expand Down Expand Up @@ -372,7 +372,7 @@ func (s *daemonSuite) TestPolkitAccessForGet(c *check.C) {

// for UserOK commands, polkit is not consulted
cmd.UserOK = true
polkitCheckAuthorization = func(pid uint32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) {
polkitCheckAuthorization = func(pid int32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) {
panic("polkit.CheckAuthorization called")
}
c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
Expand Down
48 changes: 30 additions & 18 deletions daemon/ucrednet.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,23 @@ import (
var errNoID = errors.New("no pid/uid found")

const (
ucrednetNoProcess = uint32(0)
ucrednetNoProcess = int32(0)
ucrednetNobody = uint32((1 << 32) - 1)
)

func ucrednetGet(remoteAddr string) (pid uint32, uid uint32, socket string, err error) {
func ucrednetGet(remoteAddr string) (pid int32, uid uint32, socket string, err error) {
pid = ucrednetNoProcess
uid = ucrednetNobody
for _, token := range strings.Split(remoteAddr, ";") {
var v uint64
if strings.HasPrefix(token, "pid=") {
if v, err = strconv.ParseUint(token[4:], 10, 32); err == nil {
pid = uint32(v)
var v int64
if v, err = strconv.ParseInt(token[4:], 10, 32); err == nil {
pid = int32(v)
} else {
break
}
} else if strings.HasPrefix(token, "uid=") {
var v uint64
if v, err = strconv.ParseUint(token[4:], 10, 32); err == nil {
uid = uint32(v)
} else {
Expand All @@ -65,26 +66,35 @@ func ucrednetGet(remoteAddr string) (pid uint32, uid uint32, socket string, err
return pid, uid, socket, err
}

type ucrednet struct {
pid int32
uid uint32
socket string
}

func (un *ucrednet) String() string {
if un == nil {
return "pid=;uid=;socket=;"
}
return fmt.Sprintf("pid=%d;uid=%d;socket=%s;", un.pid, un.uid, un.socket)
}

type ucrednetAddr struct {
net.Addr
pid string
uid string
socket string
*ucrednet
}

func (wa *ucrednetAddr) String() string {
return fmt.Sprintf("pid=%s;uid=%s;socket=%s;%s", wa.pid, wa.uid, wa.socket, wa.Addr)
return wa.ucrednet.String()
}

type ucrednetConn struct {
net.Conn
pid string
uid string
socket string
*ucrednet
}

func (wc *ucrednetConn) RemoteAddr() net.Addr {
return &ucrednetAddr{wc.Conn.RemoteAddr(), wc.pid, wc.uid, wc.socket}
return &ucrednetAddr{wc.Conn.RemoteAddr(), wc.ucrednet}
}

type ucrednetListener struct{ net.Listener }
Expand All @@ -97,7 +107,7 @@ func (wl *ucrednetListener) Accept() (net.Conn, error) {
return nil, err
}

var pid, uid, socket string
var unet *ucrednet
if ucon, ok := con.(*net.UnixConn); ok {
f, err := ucon.File()
if err != nil {
Expand All @@ -111,10 +121,12 @@ func (wl *ucrednetListener) Accept() (net.Conn, error) {
return nil, err
}

pid = strconv.FormatUint(uint64(ucred.Pid), 10)
uid = strconv.FormatUint(uint64(ucred.Uid), 10)
socket = ucon.LocalAddr().String()
unet = &ucrednet{
pid: ucred.Pid,
uid: ucred.Uid,
socket: ucon.LocalAddr().String(),
}
}

return &ucrednetConn{con, pid, uid, socket}, err
return &ucrednetConn{con, unet}, nil
}
8 changes: 4 additions & 4 deletions daemon/ucrednet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (s *ucrednetSuite) TestAcceptConnRemoteAddrString(c *check.C) {
remoteAddr := conn.RemoteAddr().String()
c.Check(remoteAddr, check.Matches, "pid=100;uid=42;.*")
pid, uid, _, err := ucrednetGet(remoteAddr)
c.Check(pid, check.Equals, uint32(100))
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, uint32(42))
c.Check(err, check.IsNil)
}
Expand Down Expand Up @@ -146,14 +146,14 @@ func (s *ucrednetSuite) TestUcredErrors(c *check.C) {
func (s *ucrednetSuite) TestGetNoUid(c *check.C) {
pid, uid, _, err := ucrednetGet("pid=100;uid=;")
c.Check(err, check.Equals, errNoID)
c.Check(pid, check.Equals, uint32(100))
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, ucrednetNobody)
}

func (s *ucrednetSuite) TestGetBadUid(c *check.C) {
pid, uid, _, err := ucrednetGet("pid=100;uid=hello;")
c.Check(err, check.NotNil)
c.Check(pid, check.Equals, uint32(100))
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, ucrednetNobody)
}

Expand All @@ -174,7 +174,7 @@ func (s *ucrednetSuite) TestGetNothing(c *check.C) {
func (s *ucrednetSuite) TestGet(c *check.C) {
pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket")
c.Check(err, check.IsNil)
c.Check(pid, check.Equals, uint32(100))
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, uint32(42))
c.Check(socket, check.Equals, "/run/snap.socket")
}
2 changes: 1 addition & 1 deletion polkit/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func checkAuthorization(subject authSubject, actionId string, details map[string

// CheckAuthorization queries polkit to determine whether a process is
// authorized to perform an action.
func CheckAuthorization(pid uint32, uid uint32, actionId string, details map[string]string, flags CheckFlags) (bool, error) {
func CheckAuthorization(pid int32, uid uint32, actionId string, details map[string]string, flags CheckFlags) (bool, error) {
subject := authSubject{
Kind: "unix-process",
Details: make(map[string]dbus.Variant),
Expand Down
2 changes: 1 addition & 1 deletion polkit/pid_start_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
)

// getStartTimeForPid determines the start time for a given process ID
func getStartTimeForPid(pid uint32) (uint64, error) {
func getStartTimeForPid(pid int32) (uint64, error) {
filename := fmt.Sprintf("/proc/%d/stat", pid)
return getStartTimeForProcStatFile(filename)
}
Expand Down
4 changes: 2 additions & 2 deletions polkit/pid_start_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ var _ = check.Suite(&polkitSuite{})
func (s *polkitSuite) TestGetStartTime(c *check.C) {
pid := os.Getpid()

startTime, err := getStartTimeForPid(uint32(pid))
startTime, err := getStartTimeForPid(int32(pid))
c.Assert(err, check.IsNil)
c.Check(startTime, check.Not(check.Equals), uint64(0))
}
Expand All @@ -54,7 +54,7 @@ func (s *polkitSuite) TestGetStartTimeBadPid(c *check.C) {
pid += 1
}

startTime, err := getStartTimeForPid(uint32(pid))
startTime, err := getStartTimeForPid(int32(pid))
c.Assert(err, check.ErrorMatches, "open .*: no such file or directory")
c.Check(startTime, check.Equals, uint64(0))
}
Expand Down

0 comments on commit a819ae7

Please sign in to comment.