Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: reduce tokio tasks #45

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
32 changes: 16 additions & 16 deletions src/axum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use axum_crate as axum;

use crate::CloseCode;
use crate::CloseFrame;
use crate::RawMessage;
use crate::Message;
use crate::Server;
use crate::ServerExt;
use crate::Socket;
Expand Down Expand Up @@ -121,34 +121,34 @@ where
}
}

impl From<ws::Message> for RawMessage {
impl From<ws::Message> for Message {
fn from(message: ws::Message) -> Self {
match message {
ws::Message::Text(text) => RawMessage::Text(text),
ws::Message::Binary(binary) => RawMessage::Binary(binary),
ws::Message::Ping(ping) => RawMessage::Ping(ping),
ws::Message::Pong(pong) => RawMessage::Pong(pong),
ws::Message::Close(Some(close)) => RawMessage::Close(Some(CloseFrame {
ws::Message::Text(text) => Message::Text(text),
ws::Message::Binary(binary) => Message::Binary(binary),
ws::Message::Ping(ping) => Message::Ping(ping),
ws::Message::Pong(pong) => Message::Pong(pong),
ws::Message::Close(Some(close)) => Message::Close(Some(CloseFrame {
code: CloseCode::try_from(close.code).unwrap(),
reason: close.reason.into(),
})),
ws::Message::Close(None) => RawMessage::Close(None),
ws::Message::Close(None) => Message::Close(None),
}
}
}

impl From<RawMessage> for ws::Message {
fn from(message: RawMessage) -> Self {
impl From<Message> for ws::Message {
fn from(message: Message) -> Self {
match message {
RawMessage::Text(text) => ws::Message::Text(text),
RawMessage::Binary(binary) => ws::Message::Binary(binary),
RawMessage::Ping(ping) => ws::Message::Ping(ping),
RawMessage::Pong(pong) => ws::Message::Pong(pong),
RawMessage::Close(Some(close)) => ws::Message::Close(Some(ws::CloseFrame {
Message::Text(text) => ws::Message::Text(text),
Message::Binary(binary) => ws::Message::Binary(binary),
Message::Ping(ping) => ws::Message::Ping(ping),
Message::Pong(pong) => ws::Message::Pong(pong),
Message::Close(Some(close)) => ws::Message::Close(Some(ws::CloseFrame {
code: close.code.into(),
reason: close.reason.into(),
})),
RawMessage::Close(None) => ws::Message::Close(None),
Message::Close(None) => ws::Message::Close(None),
}
}
}
Expand Down
10 changes: 6 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ pub async fn connect<E: ClientExt + 'static>(
}
let socket = Socket::new(stream, Config::default());
tracing::info!("connected to {}", config.url);
let mut actor = ClientActor {
let actor = ClientActor {
client,
socket_receiver,
call_receiver,
Expand All @@ -287,12 +287,13 @@ struct ClientActor<E: ClientExt> {
}

impl<E: ClientExt> ClientActor<E> {
async fn run(&mut self) -> Result<(), Error> {
async fn run(mut self) -> Result<(), Error> {
loop {
tokio::select! {
Some(message) = self.socket_receiver.recv() => {
self.socket.send(message.clone()).await;
if let Message::Close(_frame) = message {
let is_closing = matches!(&message, Message::Close(_));
self.socket.sink.send(message).await?;
if is_closing {
return Ok(())
}
}
Expand All @@ -309,6 +310,7 @@ impl<E: ClientExt> ClientActor<E> {
self.client.on_close().await?;
self.reconnect().await;
}
_ => {}
};
}
Some(Err(error)) => {
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
//!
//! Refer to [`client`] or [`server`] module for detailed implementation guides.

extern crate core;

mod socket;

pub use socket::CloseCode;
pub use socket::CloseFrame;
pub use socket::Message;
pub use socket::RawMessage;
pub use socket::Sink;
pub use socket::Socket;
pub use socket::Stream;
Expand Down
50 changes: 16 additions & 34 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ use crate::Session;
use crate::SessionExt;
use crate::Socket;
use async_trait::async_trait;
use std::any::Any;
use std::net::SocketAddr;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
Expand All @@ -104,16 +105,16 @@ struct NewConnection<E: ServerExt> {
respond_to: oneshot::Sender<<E::Session as SessionExt>::ID>,
}

struct Disconnected<E: ServerExt> {
id: <E::Session as SessionExt>::ID,
result: Result<Option<CloseFrame>, Error>,
pub(crate) struct Disconnected<I> {
pub(crate) id: I,
pub(crate) result: Result<Option<CloseFrame>, Error>,
}

struct ServerActor<E: ServerExt> {
connections: mpsc::UnboundedReceiver<NewConnection<E>>,
disconnections: mpsc::UnboundedReceiver<Disconnected<E>>,
disconnections: mpsc::UnboundedReceiver<Box<dyn Any + Send>>,
disconnections_tx: mpsc::UnboundedSender<Box<dyn Any + Send>>,
calls: mpsc::UnboundedReceiver<E::Call>,
server: Server<E>,
extension: E,
}

Expand All @@ -127,21 +128,14 @@ where
loop {
if let Err(err) = async {
tokio::select! {
Some(NewConnection{socket, address, respond_to, request}) = self.connections.recv() => {
Some(NewConnection{mut socket, address, respond_to, request}) = self.connections.recv() => {
socket.disconnected = Some(self.disconnections_tx.clone());
let session = self.extension.on_connect(socket, request, address).await?;
let session_id = session.id.clone();
tracing::info!("connection from {address} accepted");
respond_to.send(session_id.clone()).unwrap();

tokio::spawn({
let server = self.server.clone();
async move {
let result = session.closed().await;
server.disconnected(session_id, result).await;
}
});
let _ = respond_to.send(session.id);
}
Some(Disconnected{id, result}) = self.disconnections.recv() => {
Some(x) = self.disconnections.recv() => {
let Disconnected{id, result}: Disconnected<<E::Session as SessionExt>::ID> = *x.downcast().unwrap();
self.extension.on_disconnect(id.clone()).await?;
match result {
Ok(Some(CloseFrame { code, reason })) => {
Expand All @@ -152,7 +146,9 @@ where
};
}
Some(call) = self.calls.recv() => {
self.extension.on_call(call).await?
if let Err(err) = self.extension.on_call(call).await {
tracing::error!("error when calling {:?}", err);
}
}
}
Ok::<_, Error>(())
Expand Down Expand Up @@ -193,7 +189,6 @@ pub trait ServerExt: Send {
#[derive(Debug)]
pub struct Server<E: ServerExt> {
connections: mpsc::UnboundedSender<NewConnection<E>>,
disconnections: mpsc::UnboundedSender<Disconnected<E>>,
calls: mpsc::UnboundedSender<E::Call>,
}

Expand All @@ -211,15 +206,14 @@ impl<E: ServerExt + 'static> Server<E> {
let handle = Self {
connections: connection_sender,
calls: call_sender,
disconnections: disconnection_sender,
};
let extension = create(handle.clone());
let actor = ServerActor {
connections: connection_receiver,
disconnections: disconnection_receiver,
disconnections_tx: disconnection_sender,
calls: call_receiver,
extension,
server: handle.clone(),
};
let future = tokio::spawn(actor.run());

Expand Down Expand Up @@ -248,17 +242,6 @@ impl<E: ServerExt> Server<E> {
receiver.await.unwrap()
}

pub(crate) async fn disconnected(
&self,
id: <E::Session as SessionExt>::ID,
result: Result<Option<CloseFrame>, Error>,
) {
self.disconnections
.send(Disconnected { id, result })
.map_err(|_| ())
.unwrap();
}

pub fn call(&self, call: E::Call) {
self.calls.send(call).map_err(|_| ()).unwrap();
}
Expand All @@ -277,11 +260,10 @@ impl<E: ServerExt> Server<E> {
}
}

impl<E: ServerExt> std::clone::Clone for Server<E> {
impl<E: ServerExt> Clone for Server<E> {
fn clone(&self) -> Self {
Self {
connections: self.connections.clone(),
disconnections: self.disconnections.clone(),
calls: self.calls.clone(),
}
}
Expand Down
Loading