Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use compio::{
op::ReadAt,
};

let driver = Driver::new().unwrap();
let mut driver = Driver::new().unwrap();
let file = File::open("Cargo.toml").unwrap();
// Attach the `RawFd` to driver first.
driver.attach(file.as_raw_fd()).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion examples/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use compio::{
};

fn main() {
let driver = Driver::new().unwrap();
let mut driver = Driver::new().unwrap();
let file = compio::fs::File::open("Cargo.toml").unwrap();
driver.attach(file.as_raw_fd()).unwrap();

Expand Down
41 changes: 20 additions & 21 deletions src/driver/iocp/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::{
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
ffi::c_void,
io,
Expand Down Expand Up @@ -119,9 +118,9 @@ pub trait OpCode {
/// Low-level driver of IOCP.
pub struct Driver {
port: OwnedHandle,
operations: RefCell<VecDeque<(*mut dyn OpCode, Overlapped)>>,
submit_map: RefCell<HashMap<usize, *mut OVERLAPPED>>,
cancelled: RefCell<HashSet<*mut OVERLAPPED>>,
operations: VecDeque<(*mut dyn OpCode, Overlapped)>,
submit_map: HashMap<usize, *mut OVERLAPPED>,
cancelled: HashSet<*mut OVERLAPPED>,
}

impl Driver {
Expand All @@ -134,9 +133,9 @@ impl Driver {
.map_err(|_| io::Error::last_os_error())?;
Ok(Self {
port,
operations: RefCell::new(VecDeque::with_capacity(Self::DEFAULT_CAPACITY)),
submit_map: RefCell::default(),
cancelled: RefCell::default(),
operations: VecDeque::with_capacity(Self::DEFAULT_CAPACITY),
submit_map: HashMap::default(),
cancelled: HashSet::default(),
})
}
}
Expand Down Expand Up @@ -232,7 +231,7 @@ fn ntstatus_from_win32(x: i32) -> NTSTATUS {
}

impl Poller for Driver {
fn attach(&self, fd: RawFd) -> io::Result<()> {
fn attach(&mut self, fd: RawFd) -> io::Result<()> {
detach_iocp(fd)?;
let port = unsafe { CreateIoCompletionPort(fd as _, self.port.as_raw_handle() as _, 0, 0) };
if port == 0 {
Expand All @@ -242,39 +241,39 @@ impl Poller for Driver {
}
}

unsafe fn push(&self, op: &mut (impl OpCode + 'static), user_data: usize) -> io::Result<()> {
self.operations
.borrow_mut()
.push_back((op, Overlapped::new(user_data)));
unsafe fn push(
&mut self,
op: &mut (impl OpCode + 'static),
user_data: usize,
) -> io::Result<()> {
self.operations.push_back((op, Overlapped::new(user_data)));
Ok(())
}

fn cancel(&self, user_data: usize) {
if let Some(ptr) = self.submit_map.borrow_mut().remove(&user_data) {
fn cancel(&mut self, user_data: usize) {
if let Some(ptr) = self.submit_map.remove(&user_data) {
// TODO: should we call CancelIoEx?
self.cancelled.borrow_mut().insert(ptr);
self.cancelled.insert(ptr);
}
}

fn poll(
&self,
&mut self,
timeout: Option<Duration>,
entries: &mut [MaybeUninit<Entry>],
) -> io::Result<usize> {
if entries.is_empty() {
return Ok(0);
}
while let Some((op, overlapped)) = self.operations.borrow_mut().pop_front() {
while let Some((op, overlapped)) = self.operations.pop_front() {
let overlapped = Box::new(overlapped);
let user_data = overlapped.user_data;
let overlapped_ptr = Box::into_raw(overlapped);
let result = unsafe { op.as_mut().unwrap().operate(overlapped_ptr.cast()) };
if let Poll::Ready(result) = result {
post_driver_raw(self.port.as_raw_handle(), result, overlapped_ptr.cast())?;
} else {
self.submit_map
.borrow_mut()
.insert(user_data, overlapped_ptr.cast());
self.submit_map.insert(user_data, overlapped_ptr.cast());
}
}

Expand Down Expand Up @@ -307,7 +306,7 @@ impl Poller for Driver {
let transferred = iocp_entry.dwNumberOfBytesTransferred;
let overlapped_ptr = iocp_entry.lpOverlapped;
let overlapped = unsafe { Box::from_raw(overlapped_ptr.cast::<Overlapped>()) };
if self.cancelled.borrow_mut().remove(&overlapped_ptr) {
if self.cancelled.remove(&overlapped_ptr) {
continue;
}
let res = if matches!(
Expand Down
67 changes: 38 additions & 29 deletions src/driver/iour/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[doc(no_inline)]
pub use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::{cell::RefCell, collections::VecDeque, io, mem::MaybeUninit, time::Duration};
use std::{collections::VecDeque, io, mem::MaybeUninit, time::Duration};

use io_uring::{
cqueue,
Expand All @@ -24,8 +24,8 @@ pub trait OpCode {
/// Low-level driver of io-uring.
pub struct Driver {
inner: IoUring,
squeue: RefCell<VecDeque<squeue::Entry>>,
cqueue: RefCell<VecDeque<Entry>>,
squeue: VecDeque<squeue::Entry>,
cqueue: VecDeque<Entry>,
}

impl Driver {
Expand All @@ -38,25 +38,27 @@ impl Driver {
pub fn with_entries(entries: u32) -> io::Result<Self> {
Ok(Self {
inner: IoUring::new(entries)?,
squeue: RefCell::new(VecDeque::with_capacity(entries as usize)),
cqueue: RefCell::new(VecDeque::with_capacity(entries as usize)),
squeue: VecDeque::with_capacity(entries as usize),
cqueue: VecDeque::with_capacity(entries as usize),
})
}

unsafe fn submit(&self, timeout: Option<Duration>) -> io::Result<()> {
fn submit(&mut self, timeout: Option<Duration>) -> io::Result<()> {
// Anyway we need to submit once, no matter there are entries in squeue.
loop {
let mut inner_squeue = self.inner.submission_shared();
while !inner_squeue.is_full() {
if let Some(entry) = self.squeue.borrow_mut().pop_front() {
inner_squeue.push(&entry).unwrap();
} else {
break;
{
let mut inner_squeue = self.inner.submission();
while !inner_squeue.is_full() {
if let Some(entry) = self.squeue.pop_front() {
unsafe { inner_squeue.push(&entry) }.unwrap();
} else {
break;
}
}
inner_squeue.sync();
}
inner_squeue.sync();

let res = if self.squeue.borrow().is_empty() {
let res = if self.squeue.is_empty() {
// Last part of submission queue, wait till timeout.
if let Some(duration) = timeout {
let timespec = timespec(duration);
Expand All @@ -68,7 +70,6 @@ impl Driver {
} else {
self.inner.submit()
};
inner_squeue.sync();
match res {
Ok(_) => Ok(()),
Err(e) => match e.raw_os_error() {
Expand All @@ -78,7 +79,7 @@ impl Driver {
},
}?;

for entry in self.inner.completion_shared() {
for entry in self.inner.completion() {
let entry = create_entry(entry);
if entry.user_data() == u64::MAX as _ {
// This is a cancel operation.
Expand All @@ -90,45 +91,47 @@ impl Driver {
continue;
}
}
self.cqueue.borrow_mut().push_back(entry);
self.cqueue.push_back(entry);
}

if self.squeue.borrow().is_empty() && inner_squeue.is_empty() {
if self.squeue.is_empty() && self.inner.submission().is_empty() {
break;
}
}
Ok(())
}

fn poll_entries(&self, entries: &mut [MaybeUninit<Entry>]) -> usize {
let mut cqueue = self.cqueue.borrow_mut();
let len = cqueue.len().min(entries.len());
fn poll_entries(&mut self, entries: &mut [MaybeUninit<Entry>]) -> usize {
let len = self.cqueue.len().min(entries.len());
for entry in &mut entries[..len] {
entry.write(cqueue.pop_front().unwrap());
entry.write(self.cqueue.pop_front().unwrap());
}
len
}
}

impl Poller for Driver {
fn attach(&self, _fd: RawFd) -> io::Result<()> {
fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
Ok(())
}

unsafe fn push(&self, op: &mut (impl OpCode + 'static), user_data: usize) -> io::Result<()> {
unsafe fn push(
&mut self,
op: &mut (impl OpCode + 'static),
user_data: usize,
) -> io::Result<()> {
let entry = op.create_entry().user_data(user_data as _);
self.squeue.borrow_mut().push_back(entry);
self.squeue.push_back(entry);
Ok(())
}

fn cancel(&self, user_data: usize) {
fn cancel(&mut self, user_data: usize) {
self.squeue
.borrow_mut()
.push_back(AsyncCancel::new(user_data as _).build().user_data(u64::MAX));
}

fn poll(
&self,
&mut self,
timeout: Option<Duration>,
entries: &mut [MaybeUninit<Entry>],
) -> io::Result<usize> {
Expand All @@ -139,12 +142,18 @@ impl Poller for Driver {
if len > 0 {
return Ok(len);
}
unsafe { self.submit(timeout) }?;
self.submit(timeout)?;
let len = self.poll_entries(entries);
Ok(len)
}
}

impl AsRawFd for Driver {
fn as_raw_fd(&self) -> RawFd {
self.inner.as_raw_fd()
}
}

fn create_entry(entry: cqueue::Entry) -> Entry {
let result = entry.result();
let result = if result < 0 {
Expand Down
13 changes: 7 additions & 6 deletions src/driver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ cfg_if::cfg_if! {
/// socket.connect(second_addr).unwrap();
/// other_socket.connect(first_addr).unwrap();
///
/// let driver = Driver::new().unwrap();
/// let mut driver = Driver::new().unwrap();
/// driver.attach(socket.as_raw_fd()).unwrap();
/// driver.attach(other_socket.as_raw_fd()).unwrap();
///
Expand Down Expand Up @@ -72,18 +72,19 @@ pub trait Poller {
/// ## Platform specific
/// * IOCP: it will be attached to the IOCP completion port.
/// * io-uring: it will do nothing and return `Ok(())`
fn attach(&self, fd: RawFd) -> io::Result<()>;
fn attach(&mut self, fd: RawFd) -> io::Result<()>;

/// Push an operation with user-defined data.
/// The data could be retrived from [`Entry`] when polling.
///
/// # Safety
///
/// - `op` should be alive until [`Poller::poll`] returns its result.
unsafe fn push(&self, op: &mut (impl OpCode + 'static), user_data: usize) -> io::Result<()>;
unsafe fn push(&mut self, op: &mut (impl OpCode + 'static), user_data: usize)
-> io::Result<()>;

/// Cancel an operation with the pushed user-defined data.
fn cancel(&self, user_data: usize);
fn cancel(&mut self, user_data: usize);

/// Poll the driver with an optional timeout.
///
Expand All @@ -96,15 +97,15 @@ pub trait Poller {
///
/// [`Event`]: crate::event::Event
fn poll(
&self,
&mut self,
timeout: Option<Duration>,
entries: &mut [MaybeUninit<Entry>],
) -> io::Result<usize>;

/// Poll the driver and get only one entry back.
///
/// See [`Poller::poll`].
fn poll_one(&self, timeout: Option<Duration>) -> io::Result<Entry> {
fn poll_one(&mut self, timeout: Option<Duration>) -> io::Result<Entry> {
let mut entry = MaybeUninit::uninit();
let polled = self.poll(timeout, std::slice::from_mut(&mut entry))?;
debug_assert_eq!(polled, 1);
Expand Down
4 changes: 2 additions & 2 deletions src/event/iocp.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{io, marker::PhantomData};

use crate::{
driver::{post_driver, AsRawFd, RawFd},
driver::{post_driver, RawFd},
key::Key,
task::{op::OpFuture, RUNTIME},
};
Expand Down Expand Up @@ -45,7 +45,7 @@ unsafe impl Sync for EventHandle<'_> {}

impl<'a> EventHandle<'a> {
pub(crate) fn new(user_data: &'a Key<()>) -> Self {
let handle = RUNTIME.with(|runtime| runtime.driver().as_raw_fd());
let handle = RUNTIME.with(|runtime| runtime.raw_driver());
Self {
user_data: **user_data,
handle,
Expand Down
Loading