Skip to content

Commit

Permalink
Fix *decoding* of cmsgs and add ScmCredentials.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-schievink committed Jul 27, 2018
1 parent 237ec7b commit 9f0af44
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 90 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -10,6 +10,8 @@ This project adheres to [Semantic Versioning](http://semver.org/).
([#922](https://github.com/nix-rust/nix/pull/922))
- Support the `SO_PEERCRED` socket option and the `UnixCredentials` type on all Linux and Android targets.
([#921](https://github.com/nix-rust/nix/pull/921))
- Added support for `SCM_CREDENTIALS`, allowing to send process credentials over Unix sockets.
([#923](https://github.com/nix-rust/nix/pull/923))

### Changed

Expand Down
214 changes: 124 additions & 90 deletions src/sys/socket/mod.rs
Expand Up @@ -205,6 +205,18 @@ cfg_if! {
}
impl Eq for UnixCredentials {}

impl From<libc::ucred> for UnixCredentials {
fn from(cred: libc::ucred) -> Self {
UnixCredentials(cred)
}
}

impl Into<libc::ucred> for UnixCredentials {
fn into(self) -> libc::ucred {
self.0
}
}

impl fmt::Debug for UnixCredentials {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("UnixCredentials")
Expand Down Expand Up @@ -359,7 +371,7 @@ impl<T> CmsgSpace<T> {
}
}

#[allow(missing_debug_implementations)]
#[derive(Debug)]
pub struct RecvMsg<'a> {
// The number of bytes received.
pub bytes: usize,
Expand All @@ -374,15 +386,14 @@ impl<'a> RecvMsg<'a> {
pub fn cmsgs(&self) -> CmsgIterator {
CmsgIterator {
buf: self.cmsg_buffer,
next: 0
}
}
}

#[allow(missing_debug_implementations)]
#[derive(Debug)]
pub struct CmsgIterator<'a> {
/// Control message buffer to decode from. Must adhere to cmsg alignment.
buf: &'a [u8],
next: usize,
}

impl<'a> Iterator for CmsgIterator<'a> {
Expand All @@ -392,53 +403,27 @@ impl<'a> Iterator for CmsgIterator<'a> {
// although we handle the invariants in slightly different places to
// get a better iterator interface.
fn next(&mut self) -> Option<ControlMessage<'a>> {
let sizeof_cmsghdr = mem::size_of::<cmsghdr>();
if self.buf.len() < sizeof_cmsghdr {
if self.buf.len() == 0 {
// The iterator assumes that `self.buf` always contains exactly the
// bytes we need, so we're at the end when the buffer is empty.
return None;
}
let cmsg: &'a cmsghdr = unsafe { &*(self.buf.as_ptr() as *const cmsghdr) };

// This check is only in the glibc implementation of CMSG_NXTHDR
// (although it claims the kernel header checks this), but such
// a structure is clearly invalid, either way.
let cmsg_len = cmsg.cmsg_len as usize;
if cmsg_len < sizeof_cmsghdr {
return None;
}
let len = cmsg_len - sizeof_cmsghdr;
let aligned_cmsg_len = if self.next == 0 {
// CMSG_FIRSTHDR
cmsg_len
} else {
// CMSG_NXTHDR
cmsg_align(cmsg_len)
// Safe if: `self.buf` is `cmsghdr`-aligned.
let cmsg: &'a cmsghdr = unsafe {
&*(self.buf[..mem::size_of::<cmsghdr>()].as_ptr() as *const cmsghdr)
};

let cmsg_len = cmsg.cmsg_len as usize;

// Advance our internal pointer.
if aligned_cmsg_len > self.buf.len() {
return None;
}
let cmsg_data = &self.buf[cmsg_align(sizeof_cmsghdr)..cmsg_len];
self.buf = &self.buf[aligned_cmsg_len..];
self.next += 1;

match (cmsg.cmsg_level, cmsg.cmsg_type) {
(libc::SOL_SOCKET, libc::SCM_RIGHTS) => unsafe {
Some(ControlMessage::ScmRights(
slice::from_raw_parts(cmsg_data.as_ptr() as *const _,
cmsg_data.len() / mem::size_of::<RawFd>())))
},
(libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => unsafe {
Some(ControlMessage::ScmTimestamp(
&*(cmsg_data.as_ptr() as *const _)))
},
(_, _) => unsafe {
Some(ControlMessage::Unknown(UnknownCmsg(
cmsg,
slice::from_raw_parts(
cmsg_data.as_ptr() as *const _,
len))))
}
let cmsg_data = &self.buf[cmsg_align(mem::size_of::<cmsghdr>())..cmsg_len];
self.buf = &self.buf[cmsg_align(cmsg_len)..];

// Safe if: `cmsg_data` contains the expected (amount of) content. This
// is verified by the kernel.
unsafe {
Some(ControlMessage::decode_from(cmsg, cmsg_data))
}
}
}
Expand All @@ -459,6 +444,20 @@ pub enum ControlMessage<'a> {
/// or fail with `EINVAL`. Instead, you can put all fds to be passed into a single `ScmRights`
/// message.
ScmRights(&'a [RawFd]),
/// A message of type `SCM_CREDENTIALS`, containing the pid, uid and gid of
/// a process connected to the socket.
///
/// This is similar to the socket option `SO_PEERCRED`, but requires a
/// process to explicitly send its credentials. A process running as root is
/// allowed to specify any credentials, while credentials sent by other
/// processes are verified by the kernel.
///
/// For further information, please refer to the
/// [`unix(7)`](http://man7.org/linux/man-pages/man7/unix.7.html) man page.
// FIXME: When `#[repr(transparent)]` is stable, use it on `UnixCredentials`
// and put that in here instead of a raw ucred.
#[cfg(any(target_os = "android", target_os = "linux"))]
ScmCredentials(&'a libc::ucred),
/// A message of type `SCM_TIMESTAMP`, containing the time the
/// packet was received by the kernel.
///
Expand Down Expand Up @@ -527,6 +526,7 @@ pub enum ControlMessage<'a> {
/// nix::unistd::close(in_socket).unwrap();
/// ```
ScmTimestamp(&'a TimeVal),
/// Catch-all variant for unimplemented cmsg types.
#[doc(hidden)]
Unknown(UnknownCmsg<'a>),
}
Expand Down Expand Up @@ -558,6 +558,10 @@ impl<'a> ControlMessage<'a> {
ControlMessage::ScmRights(fds) => {
mem::size_of_val(fds)
},
#[cfg(any(target_os = "android", target_os = "linux"))]
ControlMessage::ScmCredentials(creds) => {
mem::size_of_val(creds)
}
ControlMessage::ScmTimestamp(t) => {
mem::size_of_val(t)
},
Expand All @@ -567,57 +571,87 @@ impl<'a> ControlMessage<'a> {
}
}

/// Returns the value to put into the `cmsg_type` field of the header.
fn cmsg_type(&self) -> libc::c_int {
match *self {
ControlMessage::ScmRights(_) => libc::SCM_RIGHTS,
#[cfg(any(target_os = "android", target_os = "linux"))]
ControlMessage::ScmCredentials(_) => libc::SCM_CREDENTIALS,
ControlMessage::ScmTimestamp(_) => libc::SCM_TIMESTAMP,
ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_type,
}
}

// Unsafe: start and end of buffer must be cmsg_align'd. Updates
// the provided slice; panics if the buffer is too small.
unsafe fn encode_into(&self, buf: &mut [u8]) {
match *self {
ControlMessage::ScmRights(fds) => {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: libc::SCM_RIGHTS,
..mem::uninitialized()
};
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(fds, buf);

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::ScmTimestamp(t) => {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: libc::SCM_TIMESTAMP,
..mem::uninitialized()
};
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(t, buf);

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
},
ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => {
let buf = copy_bytes(orig_cmsg, buf);
let final_buf = if let ControlMessage::Unknown(ref cmsg) = *self {
let &UnknownCmsg(orig_cmsg, bytes) = cmsg;

let buf = copy_bytes(orig_cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) -
mem::size_of_val(&orig_cmsg);
let buf = pad_bytes(padlen, buf);
let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) -
mem::size_of_val(&orig_cmsg);
let buf = pad_bytes(padlen, buf);

let buf = copy_bytes(bytes, buf);
copy_bytes(bytes, buf)
} else {
let cmsg = cmsghdr {
cmsg_len: self.len() as _,
cmsg_level: libc::SOL_SOCKET,
cmsg_type: self.cmsg_type(),
..mem::zeroed() // zero out platform-dependent padding fields
};
let buf = copy_bytes(&cmsg, buf);

let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
mem::size_of_val(&cmsg);
let buf = pad_bytes(padlen, buf);

match *self {
ControlMessage::ScmRights(fds) => {
copy_bytes(fds, buf)
},
#[cfg(any(target_os = "android", target_os = "linux"))]
ControlMessage::ScmCredentials(creds) => {
copy_bytes(creds, buf)
}
ControlMessage::ScmTimestamp(t) => {
copy_bytes(t, buf)
},
ControlMessage::Unknown(_) => unreachable!(),
}
};

let padlen = self.space() - self.len();
pad_bytes(padlen, buf);
let padlen = self.space() - self.len();
pad_bytes(padlen, final_buf);
}

/// Decodes a `ControlMessage` from raw bytes.
///
/// This is only safe to call if the data is correct for the message type
/// specified in the header. Normally, the kernel ensures that this is the
/// case. "Correct" in this case includes correct length, alignment and
/// actual content.
unsafe fn decode_from(header: &'a cmsghdr, data: &'a [u8]) -> ControlMessage<'a> {
match (header.cmsg_level, header.cmsg_type) {
(libc::SOL_SOCKET, libc::SCM_RIGHTS) => {
ControlMessage::ScmRights(
slice::from_raw_parts(data.as_ptr() as *const _,
data.len() / mem::size_of::<RawFd>()))
},
#[cfg(any(target_os = "android", target_os = "linux"))]
(libc::SOL_SOCKET, libc::SCM_CREDENTIALS) => {
ControlMessage::ScmCredentials(
&*(data.as_ptr() as *const _)
)
}
(libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => {
ControlMessage::ScmTimestamp(
&*(data.as_ptr() as *const _))
},
(_, _) => {
ControlMessage::Unknown(UnknownCmsg(header, data))
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/sys/socket/sockopt.rs
Expand Up @@ -255,6 +255,8 @@ sockopt_impl!(Both, BindAny, libc::SOL_SOCKET, libc::SO_BINDANY, bool);
sockopt_impl!(Both, BindAny, libc::IPPROTO_IP, libc::IP_BINDANY, bool);
#[cfg(target_os = "linux")]
sockopt_impl!(Both, Mark, libc::SOL_SOCKET, libc::SO_MARK, u32);
#[cfg(any(target_os = "android", target_os = "linux"))]
sockopt_impl!(Both, PassCred, libc::SOL_SOCKET, libc::SO_PASSCRED, bool);

/*
*
Expand Down
15 changes: 15 additions & 0 deletions src/unistd.rs
Expand Up @@ -48,6 +48,11 @@ impl Uid {
pub fn is_root(&self) -> bool {
*self == ROOT
}

/// Get the raw `uid_t` wrapped by `self`.
pub fn as_raw(&self) -> uid_t {
self.0
}
}

impl From<Uid> for uid_t {
Expand Down Expand Up @@ -87,6 +92,11 @@ impl Gid {
pub fn effective() -> Self {
getegid()
}

/// Get the raw `gid_t` wrapped by `self`.
pub fn as_raw(&self) -> gid_t {
self.0
}
}

impl From<Gid> for gid_t {
Expand Down Expand Up @@ -123,6 +133,11 @@ impl Pid {
pub fn parent() -> Self {
getppid()
}

/// Get the raw `pid_t` wrapped by `self`.
pub fn as_raw(&self) -> pid_t {
self.0
}
}

impl From<Pid> for pid_t {
Expand Down

0 comments on commit 9f0af44

Please sign in to comment.