Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 156 additions & 4 deletions vm/devices/storage/ide/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod drive;
mod protocol;

use crate::drive::save_restore::DriveSaveRestore;
use crate::PAGE_SIZE64;
use crate::protocol::BusMasterReg;
use crate::protocol::DeviceControlReg;
use crate::protocol::IdeCommand;
Expand All @@ -27,6 +28,7 @@ use disk_backend::Disk;
use drive::DiskDrive;
use drive::DriveRegister;
use guestmem::GuestMemory;
use guestmem::ranges::PagedRange;
use ide_resources::IdePath;
use inspect::Inspect;
use inspect::InspectMut;
Expand All @@ -49,6 +51,7 @@ use thiserror::Error;
use vmcore::device_state::ChangeDeviceState;
use vmcore::line_interrupt::LineInterrupt;
use zerocopy::IntoBytes;
use guestmem::ranges::PagedRange;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll remove this redundant import in the next diff.


open_enum! {
pub enum IdeIoPort: u16 {
Expand Down Expand Up @@ -696,28 +699,40 @@ impl Channel {
write.deferred.complete();
}

fn gpa_to_gpn(gpa: u64) -> u64 {
gpa / PAGE_SIZE64
}

fn perform_dma_memory_phase(&mut self) {
tracing::trace!("perform_dma_memory_phase");
let Some(drive) = &mut self.drives[self.state.current_drive_idx] else {
tracing::trace!("returning from perform_dma_memory_phase");
return;
};

if self.bus_master_state.dma_error {
tracing::trace!("DMA error");
if drive.handle_read_dma_descriptor_error() {
self.bus_master_state.dma_error = false;
}
return;
}

let mut dma_avail = match drive.dma_request() {
tracing::trace!(dmaState = ?self.bus_master_state.dma_state, dmaType = ?self.bus_master_state.dma_io_type(), "DMA state");
let (dma_type, mut dma_avail) = match drive.dma_request() {
Some((dma_type, avail)) if *dma_type == self.bus_master_state.dma_io_type() => {
avail as u32
(Some(*dma_type), avail as u32)
}
_ => {
// No active, appropriate DMA buffer.
tracing::trace!("Invalid dma : returning from perform_dma_memory_phase");
return;
}
};

tracing::trace!(dmaType = ?dma_type, dmaAvail = ?dma_avail, busMasterState = ?self.bus_master_state.dma_state, "DMA TYPE here");
let Some(dma) = &mut self.bus_master_state.dma_state else {
tracing::trace!("No active DMA state");
return;
};

Expand All @@ -738,6 +753,9 @@ impl Channel {
.wrapping_add(8 * (dma.descriptor_idx as u32))
.into();

tracing::trace!(gm = ?self.guest_memory, "guest_memory");

tracing::trace!(desc_addr = ?descriptor_addr, "desc_addr");
let cur_desc_table_entry = match self
.guest_memory
.read_plain::<protocol::BusMasterDmaDesc>(descriptor_addr)
Expand All @@ -764,8 +782,66 @@ impl Channel {
dma.transfer_bytes_left = 0x10000;
}

dma.transfer_base_addr = cur_desc_table_entry.mem_physical_base.into();
// Check that the every page starting from the base address is within
// the guest's physical address space.
// This is a sanity check, the guest should not be able to program the DMA
// controller with an invalid page access.

let end_gpa = cur_desc_table_entry
.mem_physical_base
.checked_add(dma.transfer_bytes_left);

if let Some(end_gpa) = end_gpa {
let start_gpn = Self::gpa_to_gpn(cur_desc_table_entry.mem_physical_base.into());
let end_gpn = Self::gpa_to_gpn(end_gpa.into());
tracing::trace!(startGpa = ?cur_desc_table_entry.mem_physical_base, endGpa = ?end_gpa, "start and end GPAs");
tracing::trace!(startGpn = ?start_gpn, endGpn = ?end_gpn, "start and end GPNs");
let gpns: Vec<u64> = (start_gpn..end_gpn).collect();

tracing::trace!(paged_range = ?PagedRange::new(0, gpns.len() * PAGE_SIZE64 as usize , &gpns), "PagedRange of GPNs");
tracing::trace!(gpns_vector = ?gpns, gpns_len = ?gpns.len(), "PagedRange values");
let paged_range = PagedRange::new(0, gpns.len() * PAGE_SIZE64 as usize, &gpns).unwrap();
let r = match dma_type.unwrap() {
DmaType::Read => {
self.guest_memory
.probe_gpn_readable_range(&paged_range)
},
DmaType::Write => {
self.guest_memory
.probe_gpn_writable_range(&paged_range)
},
};
if let Err(err) = r {
// If there is an error and there is no other IO in parallel,
// we need to stop the current DMA transfer and set the error bit
// in the Bus Master Status register.
self.bus_master_state.dma_state = None;
if !drive.handle_read_dma_descriptor_error() {
self.bus_master_state.dma_error = true;
}

tracelimit::error_ratelimited!(
error = ?err,
"dma base address out-of-range error"
);
return;
}
} else {
// If there is an error and there is no other IO in parallel,
// we need to stop the current DMA transfer and set the error bit
// in the Bus Master Status register.
self.bus_master_state.dma_state = None;
if !drive.handle_read_dma_descriptor_error() {
self.bus_master_state.dma_error = true;
}

tracelimit::error_ratelimited!(
"dma base address out-of-range error"
);
return;
}

dma.transfer_base_addr = cur_desc_table_entry.mem_physical_base.into();
dma.transfer_complete = (cur_desc_table_entry.end_of_table & 0x80) != 0;

// Increment to the next descriptor.
Expand All @@ -780,6 +856,7 @@ impl Channel {

assert!(bytes_to_transfer != 0);

tracing::trace!(bytes_to_transfer = ?bytes_to_transfer, "bytes to transfer");
drive.dma_transfer(
&self.guest_memory,
dma.transfer_base_addr,
Expand All @@ -795,7 +872,7 @@ impl Channel {
drive.set_prd_exhausted();
drive.dma_advance_buffer(dma_avail as usize);
}
tracing::trace!("dma transfer is complete");
tracing::trace!(dma_avail = ?dma_avail, "dma transfer is complete");
self.bus_master_state.dma_state = None;
break;
}
Expand Down Expand Up @@ -1172,12 +1249,16 @@ impl Channel {
};

let status = self.drive_status(drive_index);

tracing::trace!(?status, "post driveaccess");

let completed = match self.drive_type(drive_index) {
DriveType::Hard => !(status.bsy() || status.drq()),
DriveType::Optical => status.drdy(),
};
if completed {
// The command is done.
tracing::trace!(completed, "post_drive_access: completed");
let write = self.enlightened_write.take().unwrap();
match write {
EnlightenedWrite::Hard(write) => {
Expand Down Expand Up @@ -1251,6 +1332,9 @@ impl Channel {
// Save this for restoring in the enlightened path.
self.state.shadow_adapter_control_reg = data;
let v = DeviceControlReg::from_bits_truncate(data);

tracing::trace!(Reset = ?v.reset(), ?v, "Reset set");

if v.reset() && (self.drives[0].is_some() || self.drives[1].is_some()) {
self.state = ChannelState::default();
}
Expand Down Expand Up @@ -2351,6 +2435,74 @@ mod tests {
assert_eq!(buffer, file_contents.as_bytes()[..buffer.len()]);
}

#[async_test]
async fn enlightened_cmd_test_invalid_dma_base() {
/*
This is a negative test case where the DMA base address is invalid.
The test sets the DMA base address to an out-of-bounds memory
address of the guest range and expects the device to not read any data.
*/
const SECTOR_COUNT: u16 = 8;
const BYTE_COUNT: u16 = SECTOR_COUNT * protocol::HARD_DRIVE_SECTOR_BYTES as u16;

let test_guest_mem = GuestMemory::allocate(16384);

let table_gpa = 0x1000;
let data_gpa = 0x4000; // Invalid out-of-bounds memory address
test_guest_mem
.write_plain(
table_gpa,
&BusMasterDmaDesc {
mem_physical_base: data_gpa,
byte_count: BYTE_COUNT / 2,
unused: 0,
end_of_table: 0x80,
},
)
.unwrap();

let data_buffer = table_gpa as u32;
let byte_count = 0;

let eint13_command = protocol::EnlightenedInt13Command {
command: IdeCommand::WRITE_DMA_ALT,
device_head: DeviceHeadReg::new().with_lba(true),
flags: 0,
result_status: 0,
lba_low: 0,
lba_high: 0,
block_count: SECTOR_COUNT,
byte_count,
data_buffer,
skip_bytes_head: 0,
skip_bytes_tail: 0,
};
test_guest_mem.write_plain(0, &eint13_command).unwrap();

let dev_path = IdePath::default();
let (mut ide_device, _disk, _file_contents, _geometry) =
ide_test_setup(Some(test_guest_mem.clone()), DriveType::Hard);

// select device [0,0] = primary channel, primary drive
device_select(&mut ide_device, &dev_path).await;
prep_ide_channel(&mut ide_device, DriveType::Hard, &dev_path);

// READ SECTORS - enlightened
let r = ide_device.io_write(IdeIoPort::PRI_ENLIGHTENED.0, 0_u32.as_bytes()); // read from gpa 0

match r {
IoResult::Defer(mut deferred) => {
poll_fn(|cx| {
ide_device.poll_device(cx);
deferred.poll_write(cx)
})
.await
.unwrap();
}
_ => panic!("{:?}", r),
}
}

#[async_test]
async fn identify_test_cd() {
let dev_path = IdePath::default();
Expand Down