Skip to content

Commit

Permalink
Merge pull request #6443 from chipaca/pids-are-signed
Browse files Browse the repository at this point in the history
daemon, polkit: pid_t is signed
  • Loading branch information
chipaca committed Jan 29, 2019
2 parents e9e9c05 + a819ae7 commit 6a05658
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 31 deletions.
6 changes: 3 additions & 3 deletions daemon/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5508,7 +5508,7 @@ func (s *postCreateUserSuite) SetUpTest(c *check.C) {
s.apiBaseSuite.SetUpTest(c)

s.daemon(c)
postCreateUserUcrednetGet = func(string) (uint32, uint32, string, error) {
postCreateUserUcrednetGet = func(string) (int32, uint32, string, error) {
return 100, 0, dirs.SnapdSocket, nil
}
s.mockUserHome = c.MkDir()
Expand Down Expand Up @@ -5915,7 +5915,7 @@ func (s *postCreateUserSuite) TestPostCreateUserFromAssertionAllKnownClassicErro

s.makeSystemUsers(c, []map[string]interface{}{goodUser})

postCreateUserUcrednetGet = func(string) (uint32, uint32, string, error) {
postCreateUserUcrednetGet = func(string) (int32, uint32, string, error) {
return 100, 0, dirs.SnapdSocket, nil
}
defer func() {
Expand Down Expand Up @@ -6653,7 +6653,7 @@ func (s *apiSuite) TestSnapctlGetNoUID(c *check.C) {
func (s *apiSuite) TestSnapctlForbiddenError(c *check.C) {
_ = s.daemon(c)

runSnapctlUcrednetGet = func(string) (uint32, uint32, string, error) {
runSnapctlUcrednetGet = func(string) (int32, uint32, string, error) {
return 100, 9999, dirs.SnapSocket, nil
}
defer func() { runSnapctlUcrednetGet = ucrednetGet }()
Expand Down
4 changes: 2 additions & 2 deletions daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type daemonSuite struct {

var _ = check.Suite(&daemonSuite{})

func (s *daemonSuite) checkAuthorization(pid uint32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) {
func (s *daemonSuite) checkAuthorization(pid int32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) {
s.lastPolkitFlags = flags
return s.authorized, s.err
}
Expand Down Expand Up @@ -372,7 +372,7 @@ func (s *daemonSuite) TestPolkitAccessForGet(c *check.C) {

// for UserOK commands, polkit is not consulted
cmd.UserOK = true
polkitCheckAuthorization = func(pid uint32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) {
polkitCheckAuthorization = func(pid int32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) {
panic("polkit.CheckAuthorization called")
}
c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
Expand Down
48 changes: 30 additions & 18 deletions daemon/ucrednet.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,23 @@ import (
var errNoID = errors.New("no pid/uid found")

const (
ucrednetNoProcess = uint32(0)
ucrednetNoProcess = int32(0)
ucrednetNobody = uint32((1 << 32) - 1)
)

func ucrednetGet(remoteAddr string) (pid uint32, uid uint32, socket string, err error) {
func ucrednetGet(remoteAddr string) (pid int32, uid uint32, socket string, err error) {
pid = ucrednetNoProcess
uid = ucrednetNobody
for _, token := range strings.Split(remoteAddr, ";") {
var v uint64
if strings.HasPrefix(token, "pid=") {
if v, err = strconv.ParseUint(token[4:], 10, 32); err == nil {
pid = uint32(v)
var v int64
if v, err = strconv.ParseInt(token[4:], 10, 32); err == nil {
pid = int32(v)
} else {
break
}
} else if strings.HasPrefix(token, "uid=") {
var v uint64
if v, err = strconv.ParseUint(token[4:], 10, 32); err == nil {
uid = uint32(v)
} else {
Expand All @@ -65,26 +66,35 @@ func ucrednetGet(remoteAddr string) (pid uint32, uid uint32, socket string, err
return pid, uid, socket, err
}

type ucrednet struct {
pid int32
uid uint32
socket string
}

func (un *ucrednet) String() string {
if un == nil {
return "pid=;uid=;socket=;"
}
return fmt.Sprintf("pid=%d;uid=%d;socket=%s;", un.pid, un.uid, un.socket)
}

type ucrednetAddr struct {
net.Addr
pid string
uid string
socket string
*ucrednet
}

func (wa *ucrednetAddr) String() string {
return fmt.Sprintf("pid=%s;uid=%s;socket=%s;%s", wa.pid, wa.uid, wa.socket, wa.Addr)
return wa.ucrednet.String()
}

type ucrednetConn struct {
net.Conn
pid string
uid string
socket string
*ucrednet
}

func (wc *ucrednetConn) RemoteAddr() net.Addr {
return &ucrednetAddr{wc.Conn.RemoteAddr(), wc.pid, wc.uid, wc.socket}
return &ucrednetAddr{wc.Conn.RemoteAddr(), wc.ucrednet}
}

type ucrednetListener struct{ net.Listener }
Expand All @@ -97,7 +107,7 @@ func (wl *ucrednetListener) Accept() (net.Conn, error) {
return nil, err
}

var pid, uid, socket string
var unet *ucrednet
if ucon, ok := con.(*net.UnixConn); ok {
f, err := ucon.File()
if err != nil {
Expand All @@ -111,10 +121,12 @@ func (wl *ucrednetListener) Accept() (net.Conn, error) {
return nil, err
}

pid = strconv.FormatUint(uint64(ucred.Pid), 10)
uid = strconv.FormatUint(uint64(ucred.Uid), 10)
socket = ucon.LocalAddr().String()
unet = &ucrednet{
pid: ucred.Pid,
uid: ucred.Uid,
socket: ucon.LocalAddr().String(),
}
}

return &ucrednetConn{con, pid, uid, socket}, err
return &ucrednetConn{con, unet}, nil
}
8 changes: 4 additions & 4 deletions daemon/ucrednet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (s *ucrednetSuite) TestAcceptConnRemoteAddrString(c *check.C) {
remoteAddr := conn.RemoteAddr().String()
c.Check(remoteAddr, check.Matches, "pid=100;uid=42;.*")
pid, uid, _, err := ucrednetGet(remoteAddr)
c.Check(pid, check.Equals, uint32(100))
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, uint32(42))
c.Check(err, check.IsNil)
}
Expand Down Expand Up @@ -146,14 +146,14 @@ func (s *ucrednetSuite) TestUcredErrors(c *check.C) {
func (s *ucrednetSuite) TestGetNoUid(c *check.C) {
pid, uid, _, err := ucrednetGet("pid=100;uid=;")
c.Check(err, check.Equals, errNoID)
c.Check(pid, check.Equals, uint32(100))
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, ucrednetNobody)
}

func (s *ucrednetSuite) TestGetBadUid(c *check.C) {
pid, uid, _, err := ucrednetGet("pid=100;uid=hello;")
c.Check(err, check.NotNil)
c.Check(pid, check.Equals, uint32(100))
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, ucrednetNobody)
}

Expand All @@ -174,7 +174,7 @@ func (s *ucrednetSuite) TestGetNothing(c *check.C) {
func (s *ucrednetSuite) TestGet(c *check.C) {
pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket")
c.Check(err, check.IsNil)
c.Check(pid, check.Equals, uint32(100))
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, uint32(42))
c.Check(socket, check.Equals, "/run/snap.socket")
}
2 changes: 1 addition & 1 deletion polkit/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func checkAuthorization(subject authSubject, actionId string, details map[string

// CheckAuthorization queries polkit to determine whether a process is
// authorized to perform an action.
func CheckAuthorization(pid uint32, uid uint32, actionId string, details map[string]string, flags CheckFlags) (bool, error) {
func CheckAuthorization(pid int32, uid uint32, actionId string, details map[string]string, flags CheckFlags) (bool, error) {
subject := authSubject{
Kind: "unix-process",
Details: make(map[string]dbus.Variant),
Expand Down
2 changes: 1 addition & 1 deletion polkit/pid_start_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
)

// getStartTimeForPid determines the start time for a given process ID
func getStartTimeForPid(pid uint32) (uint64, error) {
func getStartTimeForPid(pid int32) (uint64, error) {
filename := fmt.Sprintf("/proc/%d/stat", pid)
return getStartTimeForProcStatFile(filename)
}
Expand Down
4 changes: 2 additions & 2 deletions polkit/pid_start_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ var _ = check.Suite(&polkitSuite{})
func (s *polkitSuite) TestGetStartTime(c *check.C) {
pid := os.Getpid()

startTime, err := getStartTimeForPid(uint32(pid))
startTime, err := getStartTimeForPid(int32(pid))
c.Assert(err, check.IsNil)
c.Check(startTime, check.Not(check.Equals), uint64(0))
}
Expand All @@ -54,7 +54,7 @@ func (s *polkitSuite) TestGetStartTimeBadPid(c *check.C) {
pid += 1
}

startTime, err := getStartTimeForPid(uint32(pid))
startTime, err := getStartTimeForPid(int32(pid))
c.Assert(err, check.ErrorMatches, "open .*: no such file or directory")
c.Check(startTime, check.Equals, uint64(0))
}
Expand Down

0 comments on commit 6a05658

Please sign in to comment.