From deeab0a6ba7bced5803dd5b6f44a5ead266ae9b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Sun, 10 Sep 2023 11:54:12 +0800 Subject: [PATCH 1/2] Move RefCell to runtime. --- README.md | 2 +- examples/driver.rs | 2 +- src/driver/iocp/mod.rs | 41 ++++++++++++++++++++--------------------- src/driver/mod.rs | 13 +++++++------ src/event/iocp.rs | 4 ++-- src/task/runtime.rs | 18 +++++++++--------- tests/driver.rs | 2 +- 7 files changed, 41 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index d144c27a..636cefe7 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ use compio::{ op::ReadAt, }; -let driver = Driver::new().unwrap(); +let mut driver = Driver::new().unwrap(); let file = File::open("Cargo.toml").unwrap(); // Attach the `RawFd` to driver first. driver.attach(file.as_raw_fd()).unwrap(); diff --git a/examples/driver.rs b/examples/driver.rs index f3d8780a..53ce3c02 100644 --- a/examples/driver.rs +++ b/examples/driver.rs @@ -4,7 +4,7 @@ use compio::{ }; fn main() { - let driver = Driver::new().unwrap(); + let mut driver = Driver::new().unwrap(); let file = compio::fs::File::open("Cargo.toml").unwrap(); driver.attach(file.as_raw_fd()).unwrap(); diff --git a/src/driver/iocp/mod.rs b/src/driver/iocp/mod.rs index f1c4aaa1..5ae2a09a 100644 --- a/src/driver/iocp/mod.rs +++ b/src/driver/iocp/mod.rs @@ -1,5 +1,4 @@ use std::{ - cell::RefCell, collections::{HashMap, HashSet, VecDeque}, ffi::c_void, io, @@ -119,9 +118,9 @@ pub trait OpCode { /// Low-level driver of IOCP. pub struct Driver { port: OwnedHandle, - operations: RefCell>, - submit_map: RefCell>, - cancelled: RefCell>, + operations: VecDeque<(*mut dyn OpCode, Overlapped)>, + submit_map: HashMap, + cancelled: HashSet<*mut OVERLAPPED>, } impl Driver { @@ -134,9 +133,9 @@ impl Driver { .map_err(|_| io::Error::last_os_error())?; Ok(Self { port, - operations: RefCell::new(VecDeque::with_capacity(Self::DEFAULT_CAPACITY)), - submit_map: RefCell::default(), - cancelled: RefCell::default(), + operations: VecDeque::with_capacity(Self::DEFAULT_CAPACITY), + submit_map: HashMap::default(), + cancelled: HashSet::default(), }) } } @@ -232,7 +231,7 @@ fn ntstatus_from_win32(x: i32) -> NTSTATUS { } impl Poller for Driver { - fn attach(&self, fd: RawFd) -> io::Result<()> { + fn attach(&mut self, fd: RawFd) -> io::Result<()> { detach_iocp(fd)?; let port = unsafe { CreateIoCompletionPort(fd as _, self.port.as_raw_handle() as _, 0, 0) }; if port == 0 { @@ -242,29 +241,31 @@ impl Poller for Driver { } } - unsafe fn push(&self, op: &mut (impl OpCode + 'static), user_data: usize) -> io::Result<()> { - self.operations - .borrow_mut() - .push_back((op, Overlapped::new(user_data))); + unsafe fn push( + &mut self, + op: &mut (impl OpCode + 'static), + user_data: usize, + ) -> io::Result<()> { + self.operations.push_back((op, Overlapped::new(user_data))); Ok(()) } - fn cancel(&self, user_data: usize) { - if let Some(ptr) = self.submit_map.borrow_mut().remove(&user_data) { + fn cancel(&mut self, user_data: usize) { + if let Some(ptr) = self.submit_map.remove(&user_data) { // TODO: should we call CancelIoEx? - self.cancelled.borrow_mut().insert(ptr); + self.cancelled.insert(ptr); } } fn poll( - &self, + &mut self, timeout: Option, entries: &mut [MaybeUninit], ) -> io::Result { if entries.is_empty() { return Ok(0); } - while let Some((op, overlapped)) = self.operations.borrow_mut().pop_front() { + while let Some((op, overlapped)) = self.operations.pop_front() { let overlapped = Box::new(overlapped); let user_data = overlapped.user_data; let overlapped_ptr = Box::into_raw(overlapped); @@ -272,9 +273,7 @@ impl Poller for Driver { if let Poll::Ready(result) = result { post_driver_raw(self.port.as_raw_handle(), result, overlapped_ptr.cast())?; } else { - self.submit_map - .borrow_mut() - .insert(user_data, overlapped_ptr.cast()); + self.submit_map.insert(user_data, overlapped_ptr.cast()); } } @@ -307,7 +306,7 @@ impl Poller for Driver { let transferred = iocp_entry.dwNumberOfBytesTransferred; let overlapped_ptr = iocp_entry.lpOverlapped; let overlapped = unsafe { Box::from_raw(overlapped_ptr.cast::()) }; - if self.cancelled.borrow_mut().remove(&overlapped_ptr) { + if self.cancelled.remove(&overlapped_ptr) { continue; } let res = if matches!( diff --git a/src/driver/mod.rs b/src/driver/mod.rs index a3449c35..51bb41d0 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -43,7 +43,7 @@ cfg_if::cfg_if! { /// socket.connect(second_addr).unwrap(); /// other_socket.connect(first_addr).unwrap(); /// -/// let driver = Driver::new().unwrap(); +/// let mut driver = Driver::new().unwrap(); /// driver.attach(socket.as_raw_fd()).unwrap(); /// driver.attach(other_socket.as_raw_fd()).unwrap(); /// @@ -72,7 +72,7 @@ pub trait Poller { /// ## Platform specific /// * IOCP: it will be attached to the IOCP completion port. /// * io-uring: it will do nothing and return `Ok(())` - fn attach(&self, fd: RawFd) -> io::Result<()>; + fn attach(&mut self, fd: RawFd) -> io::Result<()>; /// Push an operation with user-defined data. /// The data could be retrived from [`Entry`] when polling. @@ -80,10 +80,11 @@ pub trait Poller { /// # Safety /// /// - `op` should be alive until [`Poller::poll`] returns its result. - unsafe fn push(&self, op: &mut (impl OpCode + 'static), user_data: usize) -> io::Result<()>; + unsafe fn push(&mut self, op: &mut (impl OpCode + 'static), user_data: usize) + -> io::Result<()>; /// Cancel an operation with the pushed user-defined data. - fn cancel(&self, user_data: usize); + fn cancel(&mut self, user_data: usize); /// Poll the driver with an optional timeout. /// @@ -96,7 +97,7 @@ pub trait Poller { /// /// [`Event`]: crate::event::Event fn poll( - &self, + &mut self, timeout: Option, entries: &mut [MaybeUninit], ) -> io::Result; @@ -104,7 +105,7 @@ pub trait Poller { /// Poll the driver and get only one entry back. /// /// See [`Poller::poll`]. - fn poll_one(&self, timeout: Option) -> io::Result { + fn poll_one(&mut self, timeout: Option) -> io::Result { let mut entry = MaybeUninit::uninit(); let polled = self.poll(timeout, std::slice::from_mut(&mut entry))?; debug_assert_eq!(polled, 1); diff --git a/src/event/iocp.rs b/src/event/iocp.rs index dec236a7..faca18f3 100644 --- a/src/event/iocp.rs +++ b/src/event/iocp.rs @@ -1,7 +1,7 @@ use std::{io, marker::PhantomData}; use crate::{ - driver::{post_driver, AsRawFd, RawFd}, + driver::{post_driver, RawFd}, key::Key, task::{op::OpFuture, RUNTIME}, }; @@ -45,7 +45,7 @@ unsafe impl Sync for EventHandle<'_> {} impl<'a> EventHandle<'a> { pub(crate) fn new(user_data: &'a Key<()>) -> Self { - let handle = RUNTIME.with(|runtime| runtime.driver().as_raw_fd()); + let handle = RUNTIME.with(|runtime| runtime.raw_driver()); Self { user_data: **user_data, handle, diff --git a/src/task/runtime.rs b/src/task/runtime.rs index 1521e3bc..6510ce67 100644 --- a/src/task/runtime.rs +++ b/src/task/runtime.rs @@ -13,13 +13,13 @@ use futures_util::future::Either; #[cfg(feature = "time")] use crate::task::time::{TimerFuture, TimerRuntime}; use crate::{ - driver::{Driver, Entry, OpCode, Poller, RawFd}, + driver::{AsRawFd, Driver, Entry, OpCode, Poller, RawFd}, task::op::{OpFuture, OpRuntime}, Key, }; pub(crate) struct Runtime { - driver: Driver, + driver: RefCell, runnables: RefCell>, op_runtime: RefCell, #[cfg(feature = "time")] @@ -29,7 +29,7 @@ pub(crate) struct Runtime { impl Runtime { pub fn new() -> io::Result { Ok(Self { - driver: Driver::new()?, + driver: RefCell::new(Driver::new()?), runnables: RefCell::default(), op_runtime: RefCell::default(), #[cfg(feature = "time")] @@ -38,8 +38,8 @@ impl Runtime { } #[allow(dead_code)] - pub fn driver(&self) -> &Driver { - &self.driver + pub fn raw_driver(&self) -> RawFd { + self.driver.borrow().as_raw_fd() } unsafe fn spawn_unchecked(&self, future: F) -> (Runnable, Task) { @@ -77,7 +77,7 @@ impl Runtime { } pub fn attach(&self, fd: RawFd) -> io::Result<()> { - self.driver.attach(fd) + self.driver.borrow_mut().attach(fd) } pub fn submit( @@ -86,7 +86,7 @@ impl Runtime { ) -> impl Future, T)> { let mut op_runtime = self.op_runtime.borrow_mut(); let (user_data, op) = op_runtime.insert(op); - let res = unsafe { self.driver.push(op.as_mut::(), *user_data) }; + let res = unsafe { self.driver.borrow_mut().push(op.as_mut::(), *user_data) }; match res { Ok(()) => { let (runnable, task) = unsafe { self.spawn_unchecked(OpFuture::new(user_data)) }; @@ -116,7 +116,7 @@ impl Runtime { } pub fn cancel_op(&self, user_data: Key) { - self.driver.cancel(*user_data); + self.driver.borrow_mut().cancel(*user_data); self.op_runtime.borrow_mut().cancel(user_data); } @@ -175,7 +175,7 @@ impl Runtime { const UNINIT_ENTRY: MaybeUninit = MaybeUninit::uninit(); let mut entries = [UNINIT_ENTRY; 16]; - match self.driver.poll(timeout, &mut entries) { + match self.driver.borrow_mut().poll(timeout, &mut entries) { Ok(len) => { for entry in &mut entries[..len] { let entry = unsafe { std::mem::replace(entry, UNINIT_ENTRY).assume_init() }; diff --git a/tests/driver.rs b/tests/driver.rs index be1969c4..d5443591 100644 --- a/tests/driver.rs +++ b/tests/driver.rs @@ -2,7 +2,7 @@ use compio::driver::{Driver, Poller}; #[test] fn poll_zero() { - let driver = Driver::new().unwrap(); + let mut driver = Driver::new().unwrap(); let polled = driver.poll(None, &mut []).unwrap(); assert_eq!(polled, 0); } From 2277839a0e9d322a00ae5c08264ad9368ca76ddd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Sun, 10 Sep 2023 12:01:44 +0800 Subject: [PATCH 2/2] Update impl for io-uring. --- src/driver/iour/mod.rs | 67 ++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/src/driver/iour/mod.rs b/src/driver/iour/mod.rs index 27643d5b..04f0c223 100644 --- a/src/driver/iour/mod.rs +++ b/src/driver/iour/mod.rs @@ -1,6 +1,6 @@ #[doc(no_inline)] pub use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; -use std::{cell::RefCell, collections::VecDeque, io, mem::MaybeUninit, time::Duration}; +use std::{collections::VecDeque, io, mem::MaybeUninit, time::Duration}; use io_uring::{ cqueue, @@ -24,8 +24,8 @@ pub trait OpCode { /// Low-level driver of io-uring. pub struct Driver { inner: IoUring, - squeue: RefCell>, - cqueue: RefCell>, + squeue: VecDeque, + cqueue: VecDeque, } impl Driver { @@ -38,25 +38,27 @@ impl Driver { pub fn with_entries(entries: u32) -> io::Result { Ok(Self { inner: IoUring::new(entries)?, - squeue: RefCell::new(VecDeque::with_capacity(entries as usize)), - cqueue: RefCell::new(VecDeque::with_capacity(entries as usize)), + squeue: VecDeque::with_capacity(entries as usize), + cqueue: VecDeque::with_capacity(entries as usize), }) } - unsafe fn submit(&self, timeout: Option) -> io::Result<()> { + fn submit(&mut self, timeout: Option) -> io::Result<()> { // Anyway we need to submit once, no matter there are entries in squeue. loop { - let mut inner_squeue = self.inner.submission_shared(); - while !inner_squeue.is_full() { - if let Some(entry) = self.squeue.borrow_mut().pop_front() { - inner_squeue.push(&entry).unwrap(); - } else { - break; + { + let mut inner_squeue = self.inner.submission(); + while !inner_squeue.is_full() { + if let Some(entry) = self.squeue.pop_front() { + unsafe { inner_squeue.push(&entry) }.unwrap(); + } else { + break; + } } + inner_squeue.sync(); } - inner_squeue.sync(); - let res = if self.squeue.borrow().is_empty() { + let res = if self.squeue.is_empty() { // Last part of submission queue, wait till timeout. if let Some(duration) = timeout { let timespec = timespec(duration); @@ -68,7 +70,6 @@ impl Driver { } else { self.inner.submit() }; - inner_squeue.sync(); match res { Ok(_) => Ok(()), Err(e) => match e.raw_os_error() { @@ -78,7 +79,7 @@ impl Driver { }, }?; - for entry in self.inner.completion_shared() { + for entry in self.inner.completion() { let entry = create_entry(entry); if entry.user_data() == u64::MAX as _ { // This is a cancel operation. @@ -90,45 +91,47 @@ impl Driver { continue; } } - self.cqueue.borrow_mut().push_back(entry); + self.cqueue.push_back(entry); } - if self.squeue.borrow().is_empty() && inner_squeue.is_empty() { + if self.squeue.is_empty() && self.inner.submission().is_empty() { break; } } Ok(()) } - fn poll_entries(&self, entries: &mut [MaybeUninit]) -> usize { - let mut cqueue = self.cqueue.borrow_mut(); - let len = cqueue.len().min(entries.len()); + fn poll_entries(&mut self, entries: &mut [MaybeUninit]) -> usize { + let len = self.cqueue.len().min(entries.len()); for entry in &mut entries[..len] { - entry.write(cqueue.pop_front().unwrap()); + entry.write(self.cqueue.pop_front().unwrap()); } len } } impl Poller for Driver { - fn attach(&self, _fd: RawFd) -> io::Result<()> { + fn attach(&mut self, _fd: RawFd) -> io::Result<()> { Ok(()) } - unsafe fn push(&self, op: &mut (impl OpCode + 'static), user_data: usize) -> io::Result<()> { + unsafe fn push( + &mut self, + op: &mut (impl OpCode + 'static), + user_data: usize, + ) -> io::Result<()> { let entry = op.create_entry().user_data(user_data as _); - self.squeue.borrow_mut().push_back(entry); + self.squeue.push_back(entry); Ok(()) } - fn cancel(&self, user_data: usize) { + fn cancel(&mut self, user_data: usize) { self.squeue - .borrow_mut() .push_back(AsyncCancel::new(user_data as _).build().user_data(u64::MAX)); } fn poll( - &self, + &mut self, timeout: Option, entries: &mut [MaybeUninit], ) -> io::Result { @@ -139,12 +142,18 @@ impl Poller for Driver { if len > 0 { return Ok(len); } - unsafe { self.submit(timeout) }?; + self.submit(timeout)?; let len = self.poll_entries(entries); Ok(len) } } +impl AsRawFd for Driver { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } +} + fn create_entry(entry: cqueue::Entry) -> Entry { let result = entry.result(); let result = if result < 0 {