Skip to content

Commit

Permalink
#40 Just one websocket connection per client which can subscribe to t…
Browse files Browse the repository at this point in the history
…opics
  • Loading branch information
helgoboss committed Oct 2, 2020
1 parent 39c8a9f commit bef4ec2
Showing 1 changed file with 103 additions and 27 deletions.
130 changes: 103 additions & 27 deletions main/src/infrastructure/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use reaper_high::Reaper;
use rxrust::prelude::*;
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::fmt::Debug;
use std::path::PathBuf;
use std::rc::Rc;
Expand Down Expand Up @@ -151,11 +152,24 @@ fn controller_not_found() -> Response<&'static str> {
}

fn not_found(msg: &'static str) -> Response<&'static str> {
Response::builder().status(404).body(msg).unwrap()
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(msg)
.unwrap()
}

fn internal_server_error(msg: &'static str) -> Response<&'static str> {
Response::builder().status(500).body(msg).unwrap()
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(msg)
.unwrap()
}

fn bad_request(msg: &'static str) -> Response<&'static str> {
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(msg)
.unwrap()
}

fn handle_controller_route(session_id: String) -> Result<Json, Response<&'static str>> {
Expand Down Expand Up @@ -189,12 +203,21 @@ async fn start_server(port: u16, clients: ServerClients) {
});
let ws_route = {
let clients = warp::any().map(move || clients.clone());
warp::path!("realearn" / String / "projection")
warp::path::end()
.and(warp::ws())
.and(warp::query::<WebSocketRequest>())
.and(clients)
.map(|realearn_session_id, ws: warp::ws::Ws, clients| {
ws.on_upgrade(move |ws| client_connected(ws, realearn_session_id, clients))
})
.map(
|ws: warp::ws::Ws, req: WebSocketRequest, clients| -> Box<dyn Reply> {
let topics: Result<HashSet<_>, _> =
req.topics.split(',').map(Topic::try_from).collect();
if let Ok(topics) = topics {
Box::new(ws.on_upgrade(move |ws| client_connected(ws, topics, clients)))
} else {
Box::new(bad_request("at least one of the given topics is invalid"))
}
},
)
};
let routes = controller_route
.or(controller_routing_route)
Expand All @@ -203,6 +226,11 @@ async fn start_server(port: u16, clients: ServerClients) {
warp::serve(routes).run(([0, 0, 0, 0], port)).await;
}

#[derive(Deserialize)]
struct WebSocketRequest {
topics: String,
}

#[derive(Deserialize)]
struct PatchRequest {
op: PatchRequestOp,
Expand All @@ -216,7 +244,9 @@ enum PatchRequestOp {
Replace,
}

async fn client_connected(ws: WebSocket, realearn_session_id: String, clients: ServerClients) {
type Topics = HashSet<Topic>;

async fn client_connected(ws: WebSocket, topics: Topics, clients: ServerClients) {
use futures::{FutureExt, StreamExt};
use warp::Filter;
let (ws_sender_sink, mut ws_receiver_stream) = ws.split();
Expand All @@ -228,14 +258,14 @@ async fn client_connected(ws: WebSocket, realearn_session_id: String, clients: S
}
}));
let client_id = NEXT_CLIENT_ID.fetch_add(1, Ordering::Relaxed);
let client = ProjectionClient {
let client = WebSocketClient {
id: client_id,
realearn_session_id,
topics,
sender: client_sender,
};
clients.write().unwrap().insert(client_id, client.clone());
Reaper::get().do_later_in_main_thread_asap(move || {
send_initial_controller_projection(&client);
send_initial_data(&client);
});
// Keep receiving websocket receiver stream messages
while let Some(result) = ws_receiver_stream.next().await {
Expand All @@ -253,22 +283,26 @@ async fn client_connected(ws: WebSocket, realearn_session_id: String, clients: S
}

#[derive(Clone)]
pub struct ProjectionClient {
pub id: usize,
pub realearn_session_id: String,
pub sender: mpsc::UnboundedSender<std::result::Result<Message, warp::Error>>,
pub struct WebSocketClient {
id: usize,
topics: Topics,
sender: mpsc::UnboundedSender<std::result::Result<Message, warp::Error>>,
}

impl ProjectionClient {
pub fn send(&self, text: &str) -> Result<(), &'static str> {
impl WebSocketClient {
fn send(&self, text: &str) -> Result<(), &'static str> {
self.sender
.send(Ok(Message::text(text)))
.map_err(|_| "couldn't send")
}

fn is_subscribed_to(&self, topic: &Topic) -> bool {
self.topics.contains(topic)
}
}

// We don't take the async RwLock by Tokio because we need to access this in sync code, too!
pub type ServerClients = Arc<std::sync::RwLock<HashMap<usize, ProjectionClient>>>;
pub type ServerClients = Arc<std::sync::RwLock<HashMap<usize, WebSocketClient>>>;

pub fn keep_informing_clients(shared_session: &SharedSession) {
let session = shared_session.borrow();
Expand All @@ -281,11 +315,11 @@ pub fn keep_informing_clients(shared_session: &SharedSession) {
)
.with(Rc::downgrade(shared_session))
.do_async(|session, _| {
let _ = send_updated_controller_projection(&session.borrow());
let _ = send_updated_controller_routing(&session.borrow());
});
}

fn send_updated_controller_projection(session: &Session) -> Result<(), &'static str> {
fn send_updated_controller_routing(session: &Session) -> Result<(), &'static str> {
let clients = App::get().server().borrow().clients()?.clone();
let clients = clients
.read()
Expand All @@ -294,18 +328,60 @@ fn send_updated_controller_projection(session: &Session) -> Result<(), &'static
return Ok(());
}
let json = get_controller_routing_as_json(session)?;
for client in clients.values() {
if client.realearn_session_id != session.id() {
continue;
}
for client in clients.values().filter(|c| {
c.is_subscribed_to(&Topic::ControllerRouting {
session_id: session.id().to_string(),
})
}) {
let _ = client.send(&json);
}
Ok(())
}

fn send_initial_controller_projection(client: &ProjectionClient) -> Result<(), &'static str> {
let session = session_manager::find_session_by_id(&client.realearn_session_id)
.ok_or("couldn't find that session")?;
fn send_initial_data(client: &WebSocketClient) {
for topic in &client.topics {
let _ = send_topic(client, topic).unwrap();
}
}

fn send_topic(client: &WebSocketClient, topic: &Topic) -> Result<(), &'static str> {
use Topic::*;
match topic {
ControllerRouting { session_id } => send_initial_controller_routing(client, session_id),
ActiveController { .. } => todo!(),
}
}

#[derive(Clone, Eq, PartialEq, Hash, Debug)]
enum Topic {
ActiveController { session_id: String },
ControllerRouting { session_id: String },
}

impl TryFrom<&str> for Topic {
type Error = &'static str;

fn try_from(topic_expression: &str) -> Result<Self, Self::Error> {
let topic_segments: Vec<_> = topic_expression.split('/').skip(1).collect();
let topic = match topic_segments.as_slice() {
["realearn", "session", id, "controller-routing"] => Topic::ControllerRouting {
session_id: id.to_string(),
},
["realearn", "session", id, "controller"] => Topic::ActiveController {
session_id: id.to_string(),
},
_ => return Err("invalid topic expression"),
};
Ok(topic)
}
}

fn send_initial_controller_routing(
client: &WebSocketClient,
session_id: &str,
) -> Result<(), &'static str> {
let session =
session_manager::find_session_by_id(session_id).ok_or("couldn't find that session")?;
let json = get_controller_routing_as_json(&session.borrow())?;
client.send(&json)
}
Expand Down

0 comments on commit bef4ec2

Please sign in to comment.