Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: reduce tokio tasks #45

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
33 changes: 16 additions & 17 deletions src/axum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,11 @@

use axum_crate as axum;

use crate::CloseCode;
use crate::CloseFrame;
use crate::RawMessage;
use crate::Server;
use crate::ServerExt;
use crate::Socket;
use crate::{CloseCode, Message};
gbaranski marked this conversation as resolved.
Show resolved Hide resolved
use async_trait::async_trait;
use axum::extract::ws;
use axum::extract::ws::rejection::*;
Expand Down Expand Up @@ -121,34 +120,34 @@ where
}
}

impl From<ws::Message> for RawMessage {
impl From<ws::Message> for Message {
fn from(message: ws::Message) -> Self {
match message {
ws::Message::Text(text) => RawMessage::Text(text),
ws::Message::Binary(binary) => RawMessage::Binary(binary),
ws::Message::Ping(ping) => RawMessage::Ping(ping),
ws::Message::Pong(pong) => RawMessage::Pong(pong),
ws::Message::Close(Some(close)) => RawMessage::Close(Some(CloseFrame {
ws::Message::Text(text) => Message::Text(text),
ws::Message::Binary(binary) => Message::Binary(binary),
ws::Message::Ping(ping) => Message::Ping(ping),
ws::Message::Pong(pong) => Message::Pong(pong),
ws::Message::Close(Some(close)) => Message::Close(Some(CloseFrame {
code: CloseCode::try_from(close.code).unwrap(),
reason: close.reason.into(),
})),
ws::Message::Close(None) => RawMessage::Close(None),
ws::Message::Close(None) => Message::Close(None),
}
}
}

impl From<RawMessage> for ws::Message {
fn from(message: RawMessage) -> Self {
impl From<Message> for ws::Message {
fn from(message: Message) -> Self {
match message {
RawMessage::Text(text) => ws::Message::Text(text),
RawMessage::Binary(binary) => ws::Message::Binary(binary),
RawMessage::Ping(ping) => ws::Message::Ping(ping),
RawMessage::Pong(pong) => ws::Message::Pong(pong),
RawMessage::Close(Some(close)) => ws::Message::Close(Some(ws::CloseFrame {
Message::Text(text) => ws::Message::Text(text),
Message::Binary(binary) => ws::Message::Binary(binary),
Message::Ping(ping) => ws::Message::Ping(ping),
Message::Pong(pong) => ws::Message::Pong(pong),
Message::Close(Some(close)) => ws::Message::Close(Some(ws::CloseFrame {
code: close.code.into(),
reason: close.reason.into(),
})),
RawMessage::Close(None) => ws::Message::Close(None),
Message::Close(None) => ws::Message::Close(None),
}
}
}
Expand Down
16 changes: 12 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ pub async fn connect<E: ClientExt + 'static>(
}
let socket = Socket::new(stream, Config::default());
tracing::info!("connected to {}", config.url);
let mut actor = ClientActor {
let actor = ClientActor {
client,
socket_receiver,
call_receiver,
Expand All @@ -287,12 +287,19 @@ struct ClientActor<E: ClientExt> {
}

impl<E: ClientExt> ClientActor<E> {
async fn run(&mut self) -> Result<(), Error> {
async fn run(mut self) -> Result<(), Error> {
loop {
tokio::select! {
Some(message) = self.socket_receiver.recv() => {
self.socket.send(message.clone()).await;
if let Message::Close(_frame) = message {
let close =
gbaranski marked this conversation as resolved.
Show resolved Hide resolved
if let Message::Close(_frame) = &message {
true
} else {
false
};

self.socket.sink.send(message).await?;
if close {
return Ok(())
}
}
Expand All @@ -309,6 +316,7 @@ impl<E: ClientExt> ClientActor<E> {
self.client.on_close().await?;
self.reconnect().await;
}
_ => {}
};
}
Some(Err(error)) => {
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
//!
//! Refer to [`client`] or [`server`] module for detailed implementation guides.

extern crate core;

mod socket;

pub use socket::CloseCode;
pub use socket::CloseFrame;
pub use socket::Message;
pub use socket::RawMessage;
pub use socket::Sink;
pub use socket::Socket;
pub use socket::Stream;
Expand Down
49 changes: 17 additions & 32 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ use crate::Session;
use crate::SessionExt;
use crate::Socket;
use async_trait::async_trait;
use std::any::Any;
use std::net::SocketAddr;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
Expand All @@ -104,16 +105,16 @@ struct NewConnection<E: ServerExt> {
respond_to: oneshot::Sender<<E::Session as SessionExt>::ID>,
}

struct Disconnected<E: ServerExt> {
id: <E::Session as SessionExt>::ID,
result: Result<Option<CloseFrame>, Error>,
pub(crate) struct Disconnected<I> {
pub(crate) id: I,
pub(crate) result: Result<Option<CloseFrame>, Error>,
}

struct ServerActor<E: ServerExt> {
connections: mpsc::UnboundedReceiver<NewConnection<E>>,
disconnections: mpsc::UnboundedReceiver<Disconnected<E>>,
disconnections: mpsc::UnboundedReceiver<Box<dyn Any + Send>>,
disconnections_tx: mpsc::UnboundedSender<Box<dyn Any + Send>>,
calls: mpsc::UnboundedReceiver<E::Call>,
server: Server<E>,
extension: E,
}

Expand All @@ -127,21 +128,17 @@ where
loop {
if let Err(err) = async {
tokio::select! {
Some(NewConnection{socket, address, respond_to, request}) = self.connections.recv() => {
Some(NewConnection{mut socket, address, respond_to, request}) = self.connections.recv() => {
socket.disconnected = Some(self.disconnections_tx.clone());
let session = self.extension.on_connect(socket, request, address).await?;
let session_id = session.id.clone();
gbaranski marked this conversation as resolved.
Show resolved Hide resolved
tracing::info!("connection from {address} accepted");
respond_to.send(session_id.clone()).unwrap();
let _ = respond_to.send(session_id.clone());
gbaranski marked this conversation as resolved.
Show resolved Hide resolved


tokio::spawn({
let server = self.server.clone();
async move {
let result = session.closed().await;
server.disconnected(session_id, result).await;
}
});
}
Some(Disconnected{id, result}) = self.disconnections.recv() => {
Some(x) = self.disconnections.recv() => {
let Disconnected{id, result}: Disconnected<<E::Session as SessionExt>::ID> = *x.downcast().unwrap();
self.extension.on_disconnect(id.clone()).await?;
match result {
Ok(Some(CloseFrame { code, reason })) => {
Expand All @@ -152,7 +149,9 @@ where
};
}
Some(call) = self.calls.recv() => {
self.extension.on_call(call).await?
if let Err(err) = self.extension.on_call(call).await {
tracing::error!("Error when calling {:?}", err);
gbaranski marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Ok::<_, Error>(())
Expand Down Expand Up @@ -193,7 +192,6 @@ pub trait ServerExt: Send {
#[derive(Debug)]
pub struct Server<E: ServerExt> {
connections: mpsc::UnboundedSender<NewConnection<E>>,
disconnections: mpsc::UnboundedSender<Disconnected<E>>,
calls: mpsc::UnboundedSender<E::Call>,
}

Expand All @@ -211,15 +209,14 @@ impl<E: ServerExt + 'static> Server<E> {
let handle = Self {
connections: connection_sender,
calls: call_sender,
disconnections: disconnection_sender,
};
let extension = create(handle.clone());
let actor = ServerActor {
connections: connection_receiver,
disconnections: disconnection_receiver,
disconnections_tx: disconnection_sender,
calls: call_receiver,
extension,
server: handle.clone(),
};
let future = tokio::spawn(actor.run());

Expand Down Expand Up @@ -248,17 +245,6 @@ impl<E: ServerExt> Server<E> {
receiver.await.unwrap()
}

pub(crate) async fn disconnected(
&self,
id: <E::Session as SessionExt>::ID,
result: Result<Option<CloseFrame>, Error>,
) {
self.disconnections
.send(Disconnected { id, result })
.map_err(|_| ())
.unwrap();
}

pub fn call(&self, call: E::Call) {
self.calls.send(call).map_err(|_| ()).unwrap();
}
Expand All @@ -277,11 +263,10 @@ impl<E: ServerExt> Server<E> {
}
}

impl<E: ServerExt> std::clone::Clone for Server<E> {
impl<E: ServerExt> Clone for Server<E> {
fn clone(&self) -> Self {
Self {
connections: self.connections.clone(),
disconnections: self.disconnections.clone(),
calls: self.calls.clone(),
}
}
Expand Down
Loading