diff --git a/vm/devices/net/net_mana/src/lib.rs b/vm/devices/net/net_mana/src/lib.rs index 1361280f1a..7a03b7a49e 100644 --- a/vm/devices/net/net_mana/src/lib.rs +++ b/vm/devices/net/net_mana/src/lib.rs @@ -4,6 +4,8 @@ #![forbid(unsafe_code)] #![expect(missing_docs)] +mod test; + use anyhow::Context as _; use async_trait::async_trait; use futures::FutureExt; @@ -85,6 +87,12 @@ const SPLIT_HEADER_BOUNCE_PAGE_LIMIT: u32 = 4; const RX_BOUNCE_BUFFER_PAGE_LIMIT: u32 = 64; const TX_BOUNCE_BUFFER_PAGE_LIMIT: u32 = 64; +#[cfg(test)] +#[derive(Debug, Default, Clone, Copy)] +pub struct ManaTestConfiguration { + pub allow_lso_pkt_with_one_sge: bool, +} + pub struct ManaEndpoint { spawner: Box, vport: Arc>, @@ -93,6 +101,8 @@ pub struct ManaEndpoint { receive_update: mesh::Receiver, queue_tracker: Arc<(AtomicUsize, SlimEvent)>, bounce_buffer: bool, + #[cfg(test)] + test_configuration: ManaTestConfiguration, } struct QueueResources { @@ -126,8 +136,15 @@ impl ManaEndpoint { GuestDmaMode::DirectDma => false, GuestDmaMode::BounceBuffer => true, }, + #[cfg(test)] + test_configuration: ManaTestConfiguration::default(), } } + + #[cfg(test)] + fn set_test_configuration(&mut self, config: ManaTestConfiguration) { + self.test_configuration = config; + } } fn inspect_mana_stats(stats: &ManaQueryStatisticsResponse, req: inspect::Request<'_>) { @@ -362,6 +379,8 @@ impl ManaEndpoint { tx_max: tx_max as usize, force_tx_header_bounce: false, stats: QueueStats::default(), + #[cfg(test)] + test_configuration: self.test_configuration, }; self.queue_tracker.0.fetch_add(1, Ordering::AcqRel); queue.rx_avail(initial_rx); @@ -585,6 +604,9 @@ pub struct ManaQueue { force_tx_header_bounce: bool, stats: QueueStats, + + #[cfg(test)] + test_configuration: ManaTestConfiguration, } impl Drop for ManaQueue { @@ -1156,9 +1178,9 @@ impl ManaQueue { }; builder.push_sge(sge); } else { - let mut header_len = head.len; - let (header_segment_count, partial_bytes) = if meta.flags.offload_tcp_segmentation() { - header_len = (meta.l2_len as u16 + meta.l3_len + meta.l4_len as u16) as u32; + let (segments, segment_offset) = if meta.flags.offload_tcp_segmentation() { + // For LSO, GDMA requires that SGE0 should only contain the header. 
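+ // The header length is computed from the L2, L3, and L4 lengths in the packet metadata and must fit within a single page.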
+ let header_len = (meta.l2_len as u16 + meta.l3_len + meta.l4_len as u16) as u32; if header_len > PAGE_SIZE32 { tracelimit::error_ratelimited!( header_len, @@ -1170,108 +1192,105 @@ impl ManaQueue { builder.set_client_oob_in_sgl(header_len as u8); builder.set_gd_client_unit_data(meta.max_tcp_segment_size); - let mut partial_bytes = 0; - if header_len > head.len || self.force_tx_header_bounce { - let mut header_bytes_remaining = header_len; - let mut hdr_idx = 0; - while hdr_idx < segments.len() { - if header_bytes_remaining <= segments[hdr_idx].len { - if segments[hdr_idx].len > header_bytes_remaining { - partial_bytes = header_bytes_remaining; + let (head_iova, used_segments, used_segments_len) = + if header_len > head.len || self.force_tx_header_bounce { + let mut copy = match bounce_buffer.allocate(header_len) { + Ok(buf) => buf, + Err(err) => { + tracelimit::error_ratelimited!( + err = &err as &dyn std::error::Error, + header_len, + "Failed to bounce buffer split header" + ); + // Drop the packet + return Ok(None); } - header_bytes_remaining = 0; - break; + }; + + let mut data = copy.as_slice(); + let mut used_segments = 0; + let mut used_segments_len = 0; + for segment in segments { + let (this, rest) = data.split_at(data.len().min(segment.len as usize)); + self.guest_memory.read_to_atomic(segment.gpa, this)?; + data = rest; + if this.len() < segment.len as usize { + break; + } + used_segments += 1; + used_segments_len += segment.len; } - header_bytes_remaining -= segments[hdr_idx].len; - hdr_idx += 1; - } - if header_bytes_remaining > 0 { - tracelimit::error_ratelimited!( - header_len, - missing_header_bytes = header_bytes_remaining, - "Invalid split header" - ); - // Drop the packet - return Ok(None); - } - ((hdr_idx + 1), partial_bytes) - } else { - if head.len > header_len { - partial_bytes = header_len; - } - (1, partial_bytes) - } - } else { - (1, 0) - }; + if !data.is_empty() { + tracelimit::error_ratelimited!( + header_len, + missing_header_bytes = data.len(), + "Invalid split header" + ); + // Drop the packet + return Ok(None); + }; + let ContiguousBufferInUse { gpa, .. } = copy.reserve(); + (gpa, used_segments, used_segments_len) + } else if header_len < head.len { + (self.guest_memory.iova(head.gpa).unwrap(), 0, 0) + } else { + (self.guest_memory.iova(head.gpa).unwrap(), 1, header_len) + }; - let mut last_segment_bounced = false; - // The header needs to be contiguous. - let head_iova = if header_len > head.len || self.force_tx_header_bounce { - let mut copy = match bounce_buffer.allocate(header_len) { - Ok(buf) => buf, - Err(err) => { - tracelimit::error_ratelimited!( - err = &err as &dyn std::error::Error, - header_len, - "Failed to bounce buffer split header" - ); - // Drop the packet - return Ok(None); - } - }; - let mut next = copy.as_slice(); - for hdr_seg in &segments[..header_segment_count] { - let len = std::cmp::min(next.len(), hdr_seg.len as usize); - self.guest_memory - .read_to_atomic(hdr_seg.gpa, &next[..len])?; - next = &next[len..]; + // Drop the LSO packet if it only has a header segment. + // In production builds, this check always runs. + // For tests, use test hooks to bypass this check for allowing code coverage. + #[cfg(not(test))] + let check_lso_segment_count = true; + #[cfg(test)] + let check_lso_segment_count = !self.test_configuration.allow_lso_pkt_with_one_sge; + if check_lso_segment_count && used_segments == segments.len() { + return Ok(None); } - last_segment_bounced = true; - let ContiguousBufferInUse { gpa, .. 
} = copy.reserve(); - gpa + + // With LSO, GDMA requires that the first segment contain only + // the header and not exceed 256 bytes. Otherwise, it treats + // the WQE as "corrupt", disables the queue, and returns a GDMA error. + builder.push_sge(Sge { + address: head_iova, + mem_key: self.mem_key, + size: header_len, + }); + (&segments[used_segments..], header_len - used_segments_len) } else { - self.guest_memory.iova(head.gpa).unwrap() + // Just send the segments as they are. + (segments, 0) }; // Hardware limit for short oob is 31. Max WQE size is 512 bytes. // Hardware limit for long oob is 30. let hardware_segment_limit = if short_format { 31 } else { 30 }; - let mut sge = Sge { - address: head_iova, - mem_key: self.mem_key, - size: header_len, - }; - if partial_bytes > 0 { - last_segment_bounced = false; - let shared_seg = &segments[header_segment_count - 1]; - builder.push_sge(sge); - sge = Sge { - address: self - .guest_memory - .iova(shared_seg.gpa) - .unwrap() - .wrapping_add(partial_bytes as u64), - mem_key: self.mem_key, - size: shared_seg.len - partial_bytes, - }; - } - - let segment_count = - builder.sge_count() + 1 + meta.segment_count - header_segment_count as u8; + let segment_count = builder.sge_count() + segments.len() as u8; if segment_count <= hardware_segment_limit { - builder.push_sge(sge); - for tail in &segments[header_segment_count..] { + let mut segment_offset = segment_offset; + for tail in segments { builder.push_sge(Sge { - address: self.guest_memory.iova(tail.gpa).unwrap(), + address: self + .guest_memory + .iova(tail.gpa.wrapping_add(segment_offset.into())) + .unwrap(), mem_key: self.mem_key, - size: tail.len, + size: tail.len.wrapping_sub(segment_offset), }); + segment_offset = 0; } } else { + let gpa0 = segments[0].gpa.wrapping_add(segment_offset.into()); + let mut sge = Sge { + address: self.guest_memory.iova(gpa0).unwrap(), + mem_key: self.mem_key, + size: segments[0].len.wrapping_sub(segment_offset), + }; + + let mut last_segment_bounced = false; let mut segment_count = segment_count; - for tail_idx in header_segment_count..segments.len() { - let tail = &segments[tail_idx]; + let mut last_segment_gpa = gpa0; + for tail in &segments[1..] { // Try to coalesce segments together if there are more than the hardware allows. // TODO: Could use more expensive techniques such as // copying portions of segments to fill an entire // segment. @@ -1288,7 +1307,6 @@ impl ManaQueue { // There is enough room to coalesce the current // segment with the previous. The previous segment // is not yet bounced, so bounce it now.
- let last_segment_gpa = segments[tail_idx - 1].gpa; let mut copy = bounce_buffer.allocate(sge.size).unwrap(); self.guest_memory .read_to_atomic(last_segment_gpa, copy.as_slice())?; @@ -1329,6 +1347,7 @@ impl ManaQueue { mem_key: self.mem_key, size: tail.len, }; + last_segment_gpa = tail.gpa; } builder.push_sge(sge); self.stats.tx_packets_coalesced.increment(); @@ -1336,6 +1355,7 @@ impl ManaQueue { assert!(builder.sge_count() <= hardware_segment_limit); } + let wqe_len = builder .finish() .expect("caller ensured enough space for a max sized WQE"); @@ -1506,319 +1526,3 @@ impl Inspect for ContiguousBufferManager { .counter("failed_allocations", self.failed_allocations); } } - -#[cfg(test)] -mod tests { - use crate::GuestDmaMode; - use crate::ManaEndpoint; - use crate::QueueStats; - use chipset_device::mmio::ExternallyManagedMmioIntercepts; - use gdma::VportConfig; - use gdma_defs::bnic::ManaQueryDeviceCfgResp; - use mana_driver::mana::ManaDevice; - use net_backend::Endpoint; - use net_backend::QueueConfig; - use net_backend::RxId; - use net_backend::TxId; - use net_backend::TxSegment; - use net_backend::loopback::LoopbackEndpoint; - use pal_async::DefaultDriver; - use pal_async::async_test; - use pci_core::msi::MsiInterruptSet; - use std::future::poll_fn; - use test_with_tracing::test; - use user_driver_emulated_mock::DeviceTestMemory; - use user_driver_emulated_mock::EmulatedDevice; - use vmcore::vm_task::SingleDriverBackend; - use vmcore::vm_task::VmTaskDriverSource; - - /// Constructs a mana emulator backed by the loopback endpoint, then hooks a - /// mana driver up to it, puts the net_mana endpoint on top of that, and - /// ensures that packets can be sent and received. - #[async_test] - async fn test_endpoint_direct_dma(driver: DefaultDriver) { - send_test_packet(driver, GuestDmaMode::DirectDma, 1138, 1).await; - } - - #[async_test] - async fn test_endpoint_bounce_buffer(driver: DefaultDriver) { - send_test_packet(driver, GuestDmaMode::BounceBuffer, 1138, 1).await; - } - - #[async_test] - async fn test_segment_coalescing(driver: DefaultDriver) { - // 34 segments of 60 bytes each == 2040 - send_test_packet(driver, GuestDmaMode::DirectDma, 2040, 34).await; - } - - #[async_test] - async fn test_segment_coalescing_many(driver: DefaultDriver) { - // 128 segments of 16 bytes each == 2048 - send_test_packet(driver, GuestDmaMode::DirectDma, 2048, 128).await; - } - - async fn send_test_packet( - driver: DefaultDriver, - dma_mode: GuestDmaMode, - packet_len: usize, - num_segments: usize, - ) { - let tx_id = 1; - let tx_metadata = net_backend::TxMetadata { - id: TxId(tx_id), - segment_count: num_segments as u8, - len: packet_len as u32, - ..Default::default() - }; - let expected_num_received_packets = 1; - let (data_to_send, tx_segments) = - build_tx_segments(packet_len, num_segments, tx_metadata.clone()); - - test_endpoint( - driver, - dma_mode, - packet_len, - tx_segments, - data_to_send, - expected_num_received_packets, - ) - .await; - } - - fn build_tx_segments( - packet_len: usize, - num_segments: usize, - tx_metadata: net_backend::TxMetadata, - ) -> (Vec, Vec) { - let data_to_send = (0..packet_len).map(|v| v as u8).collect::>(); - - let mut tx_segments = Vec::new(); - let segment_len = packet_len / num_segments; - assert_eq!(packet_len % num_segments, 0); - assert_eq!(data_to_send.len(), packet_len); - - tx_segments.push(TxSegment { - ty: net_backend::TxSegmentType::Head(tx_metadata.clone()), - gpa: 0, - len: segment_len as u32, - }); - - for j in 0..(num_segments - 1) { - let gpa = (j + 
1) * segment_len; - tx_segments.push(TxSegment { - ty: net_backend::TxSegmentType::Tail, - gpa: gpa as u64, - len: segment_len as u32, - }); - } - - assert_eq!(tx_segments.len(), num_segments); - (data_to_send, tx_segments) - } - - async fn test_endpoint( - driver: DefaultDriver, - dma_mode: GuestDmaMode, - packet_len: usize, - tx_segments: Vec, - data_to_send: Vec, - expected_num_received_packets: usize, - ) -> QueueStats { - let tx_id = 1; - let pages = 256; // 1MB - let allow_dma = dma_mode == GuestDmaMode::DirectDma; - let mem: DeviceTestMemory = DeviceTestMemory::new(pages * 2, allow_dma, "test_endpoint"); - let payload_mem = mem.payload_mem(); - - let mut msi_set = MsiInterruptSet::new(); - let device = gdma::GdmaDevice::new( - &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())), - mem.guest_memory(), - &mut msi_set, - vec![VportConfig { - mac_address: [1, 2, 3, 4, 5, 6].into(), - endpoint: Box::new(LoopbackEndpoint::new()), - }], - &mut ExternallyManagedMmioIntercepts, - ); - let device = EmulatedDevice::new(device, msi_set, mem.dma_client()); - let dev_config = ManaQueryDeviceCfgResp { - pf_cap_flags1: 0.into(), - pf_cap_flags2: 0, - pf_cap_flags3: 0, - pf_cap_flags4: 0, - max_num_vports: 1, - reserved: 0, - max_num_eqs: 64, - }; - let thing = ManaDevice::new(&driver, device, 1, 1).await.unwrap(); - let vport = thing.new_vport(0, None, &dev_config).await.unwrap(); - let mut endpoint = ManaEndpoint::new(driver.clone(), vport, dma_mode).await; - let mut queues = Vec::new(); - let pool = net_backend::tests::Bufs::new(payload_mem.clone()); - endpoint - .get_queues( - vec![QueueConfig { - pool: Box::new(pool), - initial_rx: &(1..128).map(RxId).collect::>(), - driver: Box::new(driver.clone()), - }], - None, - &mut queues, - ) - .await - .unwrap(); - - payload_mem.write_at(0, &data_to_send).unwrap(); - - queues[0].tx_avail(tx_segments.as_slice()).unwrap(); - - // Poll for completion - let mut rx_packets = [RxId(0); 2]; - let mut rx_packets_n = 0; - let mut tx_done = [TxId(0); 2]; - let mut tx_done_n = 0; - while rx_packets_n == 0 { - poll_fn(|cx| queues[0].poll_ready(cx)).await; - rx_packets_n += queues[0].rx_poll(&mut rx_packets[rx_packets_n..]).unwrap(); - // GDMA Errors generate a TryReturn error, ignored here. - tx_done_n += queues[0].tx_poll(&mut tx_done[tx_done_n..]).unwrap_or(0); - if expected_num_received_packets == 0 { - break; - } - } - assert_eq!(rx_packets_n, expected_num_received_packets); - - if expected_num_received_packets == 0 { - // If no packets were received, exit. 
- let stats = get_queue_stats(queues[0].queue_stats()); - drop(queues); - endpoint.stop().await; - return stats; - } - - // Check tx - assert_eq!(tx_done_n, 1); - assert_eq!(tx_done[0].0, tx_id); - - // Check rx - assert_eq!(rx_packets[0].0, 1); - let rx_id = rx_packets[0]; - - let mut received_data = vec![0; packet_len]; - payload_mem - .read_at(2048 * rx_id.0 as u64, &mut received_data) - .unwrap(); - assert_eq!(received_data.len(), packet_len); - assert_eq!(&received_data[..], data_to_send, "{:?}", rx_id); - - let stats = get_queue_stats(queues[0].queue_stats()); - drop(queues); - endpoint.stop().await; - stats - } - - #[async_test] - async fn test_vport_with_query_filter_state(driver: DefaultDriver) { - let pages = 512; // 2MB - let mem = DeviceTestMemory::new(pages, false, "test_vport_with_query_filter_state"); - let mut msi_set = MsiInterruptSet::new(); - let device = gdma::GdmaDevice::new( - &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())), - mem.guest_memory(), - &mut msi_set, - vec![VportConfig { - mac_address: [1, 2, 3, 4, 5, 6].into(), - endpoint: Box::new(LoopbackEndpoint::new()), - }], - &mut ExternallyManagedMmioIntercepts, - ); - let dma_client = mem.dma_client(); - let device = EmulatedDevice::new(device, msi_set, dma_client); - let cap_flags1 = gdma_defs::bnic::BasicNicDriverFlags::new().with_query_filter_state(1); - let dev_config = ManaQueryDeviceCfgResp { - pf_cap_flags1: cap_flags1, - pf_cap_flags2: 0, - pf_cap_flags3: 0, - pf_cap_flags4: 0, - max_num_vports: 1, - reserved: 0, - max_num_eqs: 64, - }; - let thing = ManaDevice::new(&driver, device, 1, 1).await.unwrap(); - let _ = thing.new_vport(0, None, &dev_config).await.unwrap(); - } - - #[async_test] - async fn test_valid_packet(driver: DefaultDriver) { - let tx_id = 1; - let expected_num_received_packets = 1; - let num_segments = 1; - let packet_len = 1138; - let metadata = net_backend::TxMetadata { - id: TxId(tx_id), - segment_count: num_segments as u8, - len: packet_len as u32, - ..Default::default() - }; - - let (data_to_send, tx_segments) = build_tx_segments(packet_len, num_segments, metadata); - - let stats = test_endpoint( - driver, - GuestDmaMode::DirectDma, - packet_len, - tx_segments, - data_to_send, - expected_num_received_packets, - ) - .await; - - assert_eq!(stats.tx_packets.get(), 1, "tx_packets increase"); - assert_eq!(stats.rx_packets.get(), 1, "rx_packets increase"); - assert_eq!(stats.tx_errors.get(), 0, "tx_errors remain the same"); - assert_eq!(stats.rx_errors.get(), 0, "rx_errors remain the same"); - } - - #[async_test] - async fn test_tx_error_handling(driver: DefaultDriver) { - let tx_id = 1; - let expected_num_received_packets = 0; - let segment_count = 1; - let packet_len = 1138; - // LSO Enabled, but sending insufficient number of segments. 
- let metadata = net_backend::TxMetadata { - id: TxId(tx_id), - segment_count: segment_count as u8, - len: packet_len as u32, - flags: net_backend::TxFlags::new().with_offload_tcp_segmentation(true), - ..Default::default() - }; - - let (data_to_send, tx_segments) = build_tx_segments(packet_len, segment_count, metadata); - - let stats = test_endpoint( - driver, - GuestDmaMode::DirectDma, - packet_len, - tx_segments, - data_to_send, - expected_num_received_packets, - ) - .await; - - assert_eq!(stats.tx_errors.get(), 1, "tx_errors increase"); - assert_eq!(stats.tx_packets.get(), 0, "tx_packets stay the same"); - } - - fn get_queue_stats(queue_stats: Option<&dyn net_backend::BackendQueueStats>) -> QueueStats { - let queue_stats = queue_stats.unwrap(); - QueueStats { - rx_errors: queue_stats.rx_errors(), - tx_errors: queue_stats.tx_errors(), - rx_packets: queue_stats.rx_packets(), - tx_packets: queue_stats.tx_packets(), - ..Default::default() - } - } -} diff --git a/vm/devices/net/net_mana/src/test.rs b/vm/devices/net/net_mana/src/test.rs new file mode 100644 index 0000000000..0dfdbd65ea --- /dev/null +++ b/vm/devices/net/net_mana/src/test.rs @@ -0,0 +1,627 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#![cfg(test)] + +use crate::GuestDmaMode; +use crate::ManaEndpoint; +use crate::ManaTestConfiguration; +use crate::QueueStats; +use chipset_device::mmio::ExternallyManagedMmioIntercepts; +use gdma::VportConfig; +use gdma_defs::bnic::ManaQueryDeviceCfgResp; +use inspect_counters::Counter; +use mana_driver::mana::ManaDevice; +use mesh::CancelContext; +use mesh::CancelReason; +use net_backend::Endpoint; +use net_backend::QueueConfig; +use net_backend::RxId; +use net_backend::TxId; +use net_backend::TxSegment; +use net_backend::loopback::LoopbackEndpoint; +use pal_async::DefaultDriver; +use pal_async::async_test; +use pci_core::msi::MsiInterruptSet; +use std::future::poll_fn; +use std::time::Duration; +use test_with_tracing::test; +use user_driver_emulated_mock::DeviceTestMemory; +use user_driver_emulated_mock::EmulatedDevice; +use vmcore::vm_task::SingleDriverBackend; +use vmcore::vm_task::VmTaskDriverSource; + +const IPV4_HEADER_LENGTH: usize = 54; +const MAX_GDMA_SGE_PER_TX_PACKET: usize = 31; + +/// Constructs a mana emulator backed by the loopback endpoint, then hooks a +/// mana driver up to it, puts the net_mana endpoint on top of that, and +/// ensures that packets can be sent and received. +#[async_test] +async fn test_endpoint_direct_dma(driver: DefaultDriver) { + // 1 segment of 1138 bytes + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + 1138, + 1, + false, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; + + // 10 segments of 113 bytes each == 1130 + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + 1130, + 10, + false, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; +} + +#[async_test] +async fn test_endpoint_bounce_buffer(driver: DefaultDriver) { + // 1 segment of 1138 bytes + send_test_packet( + driver, + GuestDmaMode::BounceBuffer, + 1138, + 1, + false, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; +} + +#[async_test] +async fn test_segment_coalescing(driver: DefaultDriver) { + // 34 segments of 60 bytes each == 2040 + send_test_packet( + driver, + GuestDmaMode::DirectDma, + 2040, + 34, + false, // LSO? 
+ None, // Test config + None, // Default expected stats + ) + .await; +} + +#[async_test] +async fn test_segment_coalescing_many(driver: DefaultDriver) { + // 128 segments of 16 bytes each == 2048 + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + 2048, + 128, + false, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; +} + +#[async_test] +async fn test_packet_header_gt_head(driver: DefaultDriver) { + let num_segments = 32; + let packet_len = num_segments * (IPV4_HEADER_LENGTH - 10); + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + false, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; +} + +#[async_test] +async fn test_lso_header_eq_head(driver: DefaultDriver) { + // For the header (i.e. protocol) length to be equal to the head segment, make + // the segment length equal to the protocol header length. + let segment_len = IPV4_HEADER_LENGTH; + let num_segments = MAX_GDMA_SGE_PER_TX_PACKET - 10; + let packet_len = num_segments * segment_len; + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; + + // Caolescing test + let num_segments = MAX_GDMA_SGE_PER_TX_PACKET + 1; + let packet_len = num_segments * segment_len; + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; +} + +#[async_test] +async fn test_lso_header_lt_head(driver: DefaultDriver) { + // For the header (i.e. protocol) length to be less than the head segment, make + // the segment length greater than the protocol header length to force the header + // to fit in the first segment. + let segment_len = IPV4_HEADER_LENGTH + 6; + let num_segments = MAX_GDMA_SGE_PER_TX_PACKET - 10; + let packet_len = num_segments * segment_len; + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; + + // Coalescing test + let num_segments = MAX_GDMA_SGE_PER_TX_PACKET + 1; + let packet_len = num_segments * segment_len; + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; +} + +#[async_test] +async fn test_lso_header_gt_head(driver: DefaultDriver) { + // For the header (i.e. protocol) length to be greater than the head segment, make + // the segment length smaller than the protocol header length to force the header + // to not fit in the first segment. + let segment_len = IPV4_HEADER_LENGTH - 5; + let num_segments = MAX_GDMA_SGE_PER_TX_PACKET - 10; + let packet_len = num_segments * segment_len; + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; + + // Coalescing test + let num_segments = MAX_GDMA_SGE_PER_TX_PACKET + 1; + let packet_len = num_segments * segment_len; + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; +} + +#[async_test] +async fn test_lso_split_header(driver: DefaultDriver) { + // Invalid split header with header missing bytes (packet should get dropped). 
+ // Keep the total packet length less than the protocol header length. + let segment_len = 1; + let num_segments = IPV4_HEADER_LENGTH - 10; + let packet_len = num_segments * segment_len; + let expected_stats = Some(QueueStats { + tx_packets: Counter::new(), + rx_packets: Counter::new(), + tx_errors: Counter::new(), + rx_errors: Counter::new(), + ..Default::default() + }); + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + expected_stats, + ) + .await; + + // Excessive splitting of the header, but keep the total packet length + // the same as the protocol header length. The header should get coalesced + // correctly back to one segment. With LSO, packet with one segment is + // invalid and the expected result is that the packet gets dropped. + let segment_len = 1; + let num_segments = IPV4_HEADER_LENGTH; + let packet_len = num_segments * segment_len; + let expected_stats = Some(QueueStats { + tx_packets: Counter::new(), + rx_packets: Counter::new(), + tx_errors: Counter::new(), + rx_errors: Counter::new(), + ..Default::default() + }); + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + expected_stats, + ) + .await; + + // Excessive splitting of the header, but total segment will be more than + // one after coalescing. The packet should be accepted. + let segment_len = 1; + let num_segments = IPV4_HEADER_LENGTH + 10; + let packet_len = num_segments * segment_len; + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; + + // Split headers such that the last header has both header and payload bytes. + // i.e. The header should not evenly split into segments. + let segment_len = 5; + assert!(!IPV4_HEADER_LENGTH.is_multiple_of(segment_len)); + let num_segments = IPV4_HEADER_LENGTH + 10; + let packet_len = num_segments * segment_len; + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + None, // Default expected stats + ) + .await; +} + +#[async_test] +async fn test_lso_segment_coalescing_only_header(driver: DefaultDriver) { + let segment_len = IPV4_HEADER_LENGTH; + let num_segments = 1; + let packet_len = num_segments * segment_len; + // An LSO packet without any payload is considered bad packet and should be dropped. + let expected_stats = Some(QueueStats { + tx_packets: Counter::new(), + rx_packets: Counter::new(), + tx_errors: Counter::new(), + rx_errors: Counter::new(), + ..Default::default() + }); + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? + None, // Test config + expected_stats, + ) + .await; + + // Allow LSO with only header segment for test coverage and check that it + // results in error stats incremented. + let mut expected_stats = Some(QueueStats { + tx_packets: Counter::new(), + rx_packets: Counter::new(), + tx_errors: Counter::new(), + rx_errors: Counter::new(), + ..Default::default() + }); + + expected_stats.as_mut().unwrap().tx_errors.add(1); + let test_config = Some(ManaTestConfiguration { + allow_lso_pkt_with_one_sge: true, + }); + send_test_packet( + driver.clone(), + GuestDmaMode::DirectDma, + packet_len, + num_segments, + true, // LSO? 
+ test_config, + expected_stats, + ) + .await; +} + +#[async_test] +async fn test_vport_with_query_filter_state(driver: DefaultDriver) { + let pages = 512; // 2MB + let mem = DeviceTestMemory::new(pages, false, "test_vport_with_query_filter_state"); + let mut msi_set = MsiInterruptSet::new(); + let device = gdma::GdmaDevice::new( + &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())), + mem.guest_memory(), + &mut msi_set, + vec![VportConfig { + mac_address: [1, 2, 3, 4, 5, 6].into(), + endpoint: Box::new(LoopbackEndpoint::new()), + }], + &mut ExternallyManagedMmioIntercepts, + ); + let dma_client = mem.dma_client(); + let device = EmulatedDevice::new(device, msi_set, dma_client); + let cap_flags1 = gdma_defs::bnic::BasicNicDriverFlags::new().with_query_filter_state(1); + let dev_config = ManaQueryDeviceCfgResp { + pf_cap_flags1: cap_flags1, + pf_cap_flags2: 0, + pf_cap_flags3: 0, + pf_cap_flags4: 0, + max_num_vports: 1, + reserved: 0, + max_num_eqs: 64, + }; + let thing = ManaDevice::new(&driver, device, 1, 1).await.unwrap(); + let _ = thing.new_vport(0, None, &dev_config).await.unwrap(); +} + +async fn send_test_packet( + driver: DefaultDriver, + dma_mode: GuestDmaMode, + packet_len: usize, + num_segments: usize, + enable_lso: bool, + test_config: Option, + expected_stats: Option, +) { + let (data_to_send, tx_segments) = build_tx_segments(packet_len, num_segments, enable_lso); + + let test_config = test_config.unwrap_or_default(); + let expected_stats = expected_stats.unwrap_or_else(|| { + let mut tx_packets = Counter::new(); + tx_packets.add(1); + let mut rx_packets = Counter::new(); + rx_packets.add(1); + QueueStats { + tx_packets, + rx_packets, + tx_errors: Counter::new(), + rx_errors: Counter::new(), + ..Default::default() + } + }); + + let stats = test_endpoint( + driver, + dma_mode, + packet_len, + tx_segments, + data_to_send, + expected_stats.rx_packets.get() as usize, + test_config, + ) + .await; + + assert_eq!( + stats.tx_packets.get(), + expected_stats.tx_packets.get(), + "tx_packets mismatch" + ); + assert_eq!( + stats.rx_packets.get(), + expected_stats.rx_packets.get(), + "rx_packets mismatch" + ); + assert_eq!( + stats.tx_errors.get(), + expected_stats.tx_errors.get(), + "tx_errors mismatch" + ); + assert_eq!( + stats.rx_errors.get(), + expected_stats.rx_errors.get(), + "rx_errors mismatch" + ); +} + +fn build_tx_segments( + packet_len: usize, + num_segments: usize, + enable_lso: bool, +) -> (Vec, Vec) { + // Packet length must be divisible by number of segments. 
+ assert_eq!(packet_len % num_segments, 0); + let data_to_send = (0..packet_len).map(|v| v as u8).collect::>(); + let tx_id = 1; + let mut tx_segments = Vec::new(); + let segment_len = packet_len / num_segments; + let mut tx_metadata = net_backend::TxMetadata { + id: TxId(tx_id), + segment_count: num_segments as u8, + len: packet_len as u32, + l2_len: 14, // Ethernet header + l3_len: 20, // IPv4 header + l4_len: 20, // TCP header + max_tcp_segment_size: 1460, // Typical MSS for Ethernet + ..Default::default() + }; + + tx_metadata.flags.set_offload_tcp_segmentation(enable_lso); + + assert_eq!( + tx_metadata.l2_len as usize + tx_metadata.l3_len as usize + tx_metadata.l4_len as usize, + IPV4_HEADER_LENGTH + ); + assert_eq!(packet_len % num_segments, 0); + assert_eq!(data_to_send.len(), packet_len); + + tx_segments.push(TxSegment { + ty: net_backend::TxSegmentType::Head(tx_metadata.clone()), + gpa: 0, + len: segment_len as u32, + }); + + for j in 0..(num_segments - 1) { + let gpa = (j + 1) * segment_len; + tx_segments.push(TxSegment { + ty: net_backend::TxSegmentType::Tail, + gpa: gpa as u64, + len: segment_len as u32, + }); + } + + assert_eq!(tx_segments.len(), num_segments); + (data_to_send, tx_segments) +} + +async fn test_endpoint( + driver: DefaultDriver, + dma_mode: GuestDmaMode, + packet_len: usize, + tx_segments: Vec, + data_to_send: Vec, + expected_num_received_packets: usize, + test_configuration: ManaTestConfiguration, +) -> QueueStats { + let tx_id = 1; + let pages = 256; // 1MB + let allow_dma = dma_mode == GuestDmaMode::DirectDma; + let mem: DeviceTestMemory = DeviceTestMemory::new(pages * 2, allow_dma, "test_endpoint"); + let payload_mem = mem.payload_mem(); + + let mut msi_set = MsiInterruptSet::new(); + let device = gdma::GdmaDevice::new( + &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())), + mem.guest_memory(), + &mut msi_set, + vec![VportConfig { + mac_address: [1, 2, 3, 4, 5, 6].into(), + endpoint: Box::new(LoopbackEndpoint::new()), + }], + &mut ExternallyManagedMmioIntercepts, + ); + let device = EmulatedDevice::new(device, msi_set, mem.dma_client()); + let dev_config = ManaQueryDeviceCfgResp { + pf_cap_flags1: 0.into(), + pf_cap_flags2: 0, + pf_cap_flags3: 0, + pf_cap_flags4: 0, + max_num_vports: 1, + reserved: 0, + max_num_eqs: 64, + }; + let thing = ManaDevice::new(&driver, device, 1, 1).await.unwrap(); + let vport = thing.new_vport(0, None, &dev_config).await.unwrap(); + let mut endpoint = ManaEndpoint::new(driver.clone(), vport, dma_mode).await; + endpoint.set_test_configuration(test_configuration); + let mut queues = Vec::new(); + let pool = net_backend::tests::Bufs::new(payload_mem.clone()); + endpoint + .get_queues( + vec![QueueConfig { + pool: Box::new(pool), + initial_rx: &(1..128).map(RxId).collect::>(), + driver: Box::new(driver.clone()), + }], + None, + &mut queues, + ) + .await + .unwrap(); + + payload_mem.write_at(0, &data_to_send).unwrap(); + + queues[0].tx_avail(tx_segments.as_slice()).unwrap(); + + // Poll for completion + let mut rx_packets = [RxId(0); 2]; + let mut rx_packets_n = 0; + let mut tx_done = [TxId(0); 2]; + let mut tx_done_n = 0; + while rx_packets_n == 0 { + let mut context = CancelContext::new().with_timeout(Duration::from_secs(1)); + match context + .until_cancelled(poll_fn(|cx| queues[0].poll_ready(cx))) + .await + { + Err(CancelReason::DeadlineExceeded) => break, + Err(e) => { + tracing::error!(error = ?e, "Failed to poll queue ready"); + break; + } + _ => {} + } + rx_packets_n += queues[0].rx_poll(&mut 
rx_packets[rx_packets_n..]).unwrap(); + // GDMA errors generate a TryReturn error, which is ignored here. + tx_done_n += queues[0].tx_poll(&mut tx_done[tx_done_n..]).unwrap_or(0); + if expected_num_received_packets == 0 { + break; + } + } + assert_eq!(rx_packets_n, expected_num_received_packets); + + if expected_num_received_packets == 0 { + // If no packets were received, return early. + let stats = get_queue_stats(queues[0].queue_stats()); + drop(queues); + endpoint.stop().await; + return stats; + } + + // Check tx + assert_eq!(tx_done_n, 1); + assert_eq!(tx_done[0].0, tx_id); + + // Check rx + assert_eq!(rx_packets[0].0, 1); + let rx_id = rx_packets[0]; + + let mut received_data = vec![0; packet_len]; + payload_mem + .read_at(2048 * rx_id.0 as u64, &mut received_data) + .unwrap(); + assert_eq!(received_data.len(), packet_len); + assert_eq!(&received_data[..], data_to_send, "{:?}", rx_id); + + let stats = get_queue_stats(queues[0].queue_stats()); + drop(queues); + endpoint.stop().await; + stats +} + +fn get_queue_stats(queue_stats: Option<&dyn net_backend::BackendQueueStats>) -> QueueStats { + let queue_stats = queue_stats.unwrap(); + QueueStats { + rx_errors: queue_stats.rx_errors(), + tx_errors: queue_stats.tx_errors(), + rx_packets: queue_stats.rx_packets(), + tx_packets: queue_stats.tx_packets(), + ..Default::default() + } +}