Skip to content

Commit

Permalink
Change IoVecBuffer[Mut] len to u32
Browse files Browse the repository at this point in the history
This commit changes the iovec len primitive to match descriptor chain's
(u32). This removes some ugly casting and potential overflow problems,
and allows us to upcast when needed in a non-lossy manor.

Signed-off-by: Brandon Pike <bpike@amazon.com>
  • Loading branch information
brandonpike authored and root committed Apr 24, 2024
1 parent fd40204 commit 2c51bdd
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 34 deletions.
34 changes: 20 additions & 14 deletions src/vmm/src/devices/virtio/iovec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub enum IoVecError {
WriteOnlyDescriptor,
/// Tried to create an 'IoVecMut` from a read-only descriptor chain
ReadOnlyDescriptor,
/// Tried to create an `IoVec` or `IoVecMut` from a descriptor chain that was too large
OverflowedDescriptor,
/// Guest memory error: {0}
GuestMemory(#[from] GuestMemoryError),
}
Expand All @@ -40,14 +42,14 @@ pub struct IoVecBuffer {
// container of the memory regions included in this IO vector
vecs: IoVecVec,
// Total length of the IoVecBuffer
len: usize,
len: u32,
}

impl IoVecBuffer {
/// Create an `IoVecBuffer` from a `DescriptorChain`
pub fn from_descriptor_chain(head: DescriptorChain) -> Result<Self, IoVecError> {
let mut vecs = IoVecVec::new();
let mut len = 0usize;
let mut len = 0u32;

let mut next_descriptor = Some(head);
while let Some(desc) = next_descriptor {
Expand All @@ -68,7 +70,9 @@ impl IoVecBuffer {
iov_base,
iov_len: desc.len as size_t,
});
len += desc.len as usize;
len = len
.checked_add(desc.len)
.ok_or(IoVecError::OverflowedDescriptor)?;

next_descriptor = desc.next_descriptor();
}
Expand All @@ -77,7 +81,7 @@ impl IoVecBuffer {
}

/// Get the total length of the memory regions covered by this `IoVecBuffer`
pub(crate) fn len(&self) -> usize {
pub(crate) fn len(&self) -> u32 {
self.len
}

Expand Down Expand Up @@ -106,7 +110,7 @@ impl IoVecBuffer {
mut buf: &mut [u8],
offset: usize,
) -> Result<(), VolatileMemoryError> {
if offset < self.len() {
if offset < self.len() as usize {
let expected = buf.len();
let bytes_read = self.read_volatile_at(&mut buf, offset, expected)?;

Expand Down Expand Up @@ -188,14 +192,14 @@ pub struct IoVecBufferMut {
// container of the memory regions included in this IO vector
vecs: IoVecVec,
// Total length of the IoVecBufferMut
len: usize,
len: u32,
}

impl IoVecBufferMut {
/// Create an `IoVecBufferMut` from a `DescriptorChain`
pub fn from_descriptor_chain(head: DescriptorChain) -> Result<Self, IoVecError> {
let mut vecs = IoVecVec::new();
let mut len = 0usize;
let mut len = 0u32;

for desc in head {
if !desc.is_write_only() {
Expand All @@ -217,14 +221,16 @@ impl IoVecBufferMut {
iov_base,
iov_len: desc.len as size_t,
});
len += desc.len as usize;
len = len
.checked_add(desc.len)
.ok_or(IoVecError::OverflowedDescriptor)?;
}

Ok(Self { vecs, len })
}

/// Get the total length of the memory regions covered by this `IoVecBuffer`
pub(crate) fn len(&self) -> usize {
pub(crate) fn len(&self) -> u32 {
self.len
}

Expand All @@ -244,7 +250,7 @@ impl IoVecBufferMut {
mut buf: &[u8],
offset: usize,
) -> Result<(), VolatileMemoryError> {
if offset < self.len() {
if offset < self.len() as usize {
let expected = buf.len();
let bytes_written = self.write_volatile_at(&mut buf, offset, expected)?;

Expand Down Expand Up @@ -335,18 +341,18 @@ mod tests {
iov_len: buf.len(),
}]
.into(),
len: buf.len(),
len: buf.len().try_into().unwrap(),
}
}
}

impl<'a> From<Vec<&'a [u8]>> for IoVecBuffer {
fn from(buffer: Vec<&'a [u8]>) -> Self {
let mut len = 0;
let mut len = 0_u32;
let vecs = buffer
.into_iter()
.map(|slice| {
len += slice.len();
len += TryInto::<u32>::try_into(slice.len()).unwrap();
iovec {
iov_base: slice.as_ptr() as *mut c_void,
iov_len: slice.len(),
Expand All @@ -366,7 +372,7 @@ mod tests {
iov_len: buf.len(),
}]
.into(),
len: buf.len(),
len: buf.len().try_into().unwrap(),
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/vmm/src/devices/virtio/net/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ impl Net {

if let Some(ns) = mmds_ns {
if ns.is_mmds_frame(headers) {
let mut frame = vec![0u8; frame_iovec.len() - vnet_hdr_len()];
let mut frame = vec![0u8; frame_iovec.len() as usize - vnet_hdr_len()];
// Ok to unwrap here, because we are passing a buffer that has the exact size
// of the `IoVecBuffer` minus the VNET headers.
frame_iovec
Expand All @@ -472,7 +472,7 @@ impl Net {
METRICS.mmds.rx_accepted.inc();

// MMDS frames are not accounted by the rate limiter.
Self::rate_limiter_replenish_op(rate_limiter, frame_iovec.len() as u64);
Self::rate_limiter_replenish_op(rate_limiter, u64::from(frame_iovec.len()));

// MMDS consumed the frame.
return Ok(true);
Expand All @@ -493,7 +493,7 @@ impl Net {
let _metric = net_metrics.tap_write_agg.record_latency_metrics();
match Self::write_tap(tap, frame_iovec) {
Ok(_) => {
let len = frame_iovec.len() as u64;
let len = u64::from(frame_iovec.len());
net_metrics.tx_bytes_count.add(len);
net_metrics.tx_packets_count.inc();
net_metrics.tx_count.inc();
Expand Down Expand Up @@ -609,7 +609,7 @@ impl Net {
};

// We only handle frames that are up to MAX_BUFFER_SIZE
if buffer.len() > MAX_BUFFER_SIZE {
if buffer.len() as usize > MAX_BUFFER_SIZE {
error!("net: received too big frame from driver");
self.metrics.tx_malformed_frames.inc();
tx_queue
Expand All @@ -618,7 +618,7 @@ impl Net {
continue;
}

if !Self::rate_limiter_consume_op(&mut self.tx_rate_limiter, buffer.len() as u64) {
if !Self::rate_limiter_consume_op(&mut self.tx_rate_limiter, u64::from(buffer.len())) {
tx_queue.undo_pop();
self.metrics.tx_rate_limiter_throttled.inc();
break;
Expand Down
2 changes: 1 addition & 1 deletion src/vmm/src/devices/virtio/net/tap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ pub mod tests {

tap.write_iovec(&scattered).unwrap();

let mut read_buf = vec![0u8; scattered.len()];
let mut read_buf = vec![0u8; scattered.len().try_into().unwrap()];
assert!(tap_traffic_simulator.pop_rx_packet(&mut read_buf));
assert_eq!(
&read_buf[..PAYLOAD_SIZE - VNET_HDR_SIZE],
Expand Down
6 changes: 3 additions & 3 deletions src/vmm/src/devices/virtio/rng/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ impl Entropy {
return Ok(0);
}

let mut rand_bytes = vec![0; iovec.len()];
let mut rand_bytes = vec![0; iovec.len() as usize];
rand::fill(&mut rand_bytes).map_err(|err| {
METRICS.host_rng_fails.inc();
err
})?;

// It is ok to unwrap here. We are writing `iovec.len()` bytes at offset 0.
iovec.write_all_volatile_at(&rand_bytes, 0).unwrap();
Ok(iovec.len().try_into().unwrap())
Ok(iovec.len())
}

fn process_entropy_queue(&mut self) {
Expand All @@ -142,7 +142,7 @@ impl Entropy {
// Check for available rate limiting budget.
// If not enough budget is available, leave the request descriptor in the queue
// to handle once we do have budget.
if !Self::rate_limit_request(&mut self.rate_limiter, iovec.len() as u64) {
if !Self::rate_limit_request(&mut self.rate_limiter, u64::from(iovec.len())) {
debug!("entropy: throttling entropy queue");
METRICS.entropy_rate_limiter_throttled.inc();
self.queues[RNG_QUEUE].undo_pop();
Expand Down
5 changes: 4 additions & 1 deletion src/vmm/src/devices/virtio/vsock/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ mod defs {
#[rustfmt::skip]
pub enum VsockError {
/** The total length of the descriptor chain ({0}) is too short to hold a packet of length {1} + header */
DescChainTooShortForPacket(usize, u32),
DescChainTooShortForPacket(u32, u32),
/// Empty queue
EmptyQueue,
/// EventFd error: {0}
Expand All @@ -122,6 +122,8 @@ pub enum VsockError {
/** The total length of the descriptor chain ({0}) is less than the number of bytes required\
to hold a vsock packet header.*/
DescChainTooShortForHeader(usize),
/// The descriptor chain length was greater than the max ([u32::MAX])
DescChainOverflow,
/// The vsock header `len` field holds an invalid value: {0}
InvalidPktLen(u32),
/// A data fetch was attempted when no data was available.
Expand All @@ -144,6 +146,7 @@ impl From<IoVecError> for VsockError {
IoVecError::WriteOnlyDescriptor => VsockError::UnreadableDescriptor,
IoVecError::ReadOnlyDescriptor => VsockError::UnwritableDescriptor,
IoVecError::GuestMemory(err) => VsockError::GuestMemoryMmap(err),
IoVecError::OverflowedDescriptor => VsockError::DescChainOverflow,

Check warning on line 149 in src/vmm/src/devices/virtio/vsock/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/vmm/src/devices/virtio/vsock/mod.rs#L149

Added line #L149 was not covered by tests
}
}
}
Expand Down
19 changes: 9 additions & 10 deletions src/vmm/src/devices/virtio/vsock/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl VsockPacket {
return Err(VsockError::InvalidPktLen(hdr.len));
}

if (hdr.len as usize) > buffer.len() - VSOCK_PKT_HDR_SIZE as usize {
if (hdr.len) > buffer.len() - VSOCK_PKT_HDR_SIZE {
return Err(VsockError::DescChainTooShortForPacket(
buffer.len(),
hdr.len,
Expand All @@ -160,8 +160,8 @@ impl VsockPacket {
pub fn from_rx_virtq_head(chain: DescriptorChain) -> Result<Self, VsockError> {
let buffer = IoVecBufferMut::from_descriptor_chain(chain)?;

if buffer.len() < VSOCK_PKT_HDR_SIZE as usize {
return Err(VsockError::DescChainTooShortForHeader(buffer.len()));
if buffer.len() < VSOCK_PKT_HDR_SIZE {
return Err(VsockError::DescChainTooShortForHeader(buffer.len() as usize));
}

Ok(Self {
Expand Down Expand Up @@ -212,7 +212,7 @@ impl VsockPacket {
VsockPacketBuffer::Tx(ref iovec_buf) => iovec_buf.len(),
VsockPacketBuffer::Rx(ref iovec_buf) => iovec_buf.len(),
};
chain_length - VSOCK_PKT_HDR_SIZE as usize
(chain_length - VSOCK_PKT_HDR_SIZE) as usize
}

pub fn read_at_offset_from<T: ReadVolatile + Debug>(
Expand All @@ -225,8 +225,7 @@ impl VsockPacket {
VsockPacketBuffer::Tx(_) => Err(VsockError::UnwritableDescriptor),
VsockPacketBuffer::Rx(ref mut buffer) => {
if count
> buffer
.len()
> (buffer.len() as usize)
.saturating_sub(VSOCK_PKT_HDR_SIZE as usize)
.saturating_sub(offset)
{
Expand All @@ -249,8 +248,7 @@ impl VsockPacket {
match self.buffer {
VsockPacketBuffer::Tx(ref buffer) => {
if count
> buffer
.len()
> (buffer.len() as usize)
.saturating_sub(VSOCK_PKT_HDR_SIZE as usize)
.saturating_sub(offset)
{
Expand Down Expand Up @@ -427,9 +425,10 @@ mod tests {
.unwrap(),
)
.unwrap();

assert_eq!(
pkt.buf_size(),
handler_ctx.guest_txvq.dtable[1].len.get() as usize
TryInto::<u32>::try_into(pkt.buf_size()).unwrap(),
handler_ctx.guest_txvq.dtable[1].len.get()
);
}

Expand Down

0 comments on commit 2c51bdd

Please sign in to comment.