// main.rs — warp WebSocket chat server.
// Original listing: 91 lines (72 loc), 2.51 KB.
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use futures::{FutureExt, StreamExt};
use tokio::sync::{mpsc, RwLock};
use tokio_stream::wrappers::UnboundedReceiverStream;
use warp::ws::{Message, WebSocket};
use warp::Filter;
static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1);
static INDEX_HTML: &str = std::include_str!("../../../static/index.html");
type Users = Arc<RwLock<HashMap<usize, mpsc::UnboundedSender<Result<Message, warp::Error>>>>>;
/// Entry point: serves the chat page at `/` and a websocket chat endpoint at `/chat`
/// on 127.0.0.1:8000.
#[tokio::main]
async fn main() {
    // One registry shared by every connection; the filter hands each request its own
    // clone of the `Arc`.
    let registry = Users::default();
    let with_users = warp::any().map(move || registry.clone());

    // Root path: the embedded HTML client.
    let index = warp::path::end().map(|| warp::reply::html(INDEX_HTML));

    // `chat` path: perform the websocket upgrade, then run the per-connection handler.
    let chat = warp::path("chat")
        .and(warp::ws())
        .and(with_users)
        .map(|ws: warp::ws::Ws, users| {
            ws.on_upgrade(move |socket| user_connected(socket, users))
        });

    warp::serve(index.or(chat))
        .run(([127, 0, 0, 1], 8000))
        .await;
}
/// Handles a single websocket connection for its whole lifetime.
///
/// The handler:
/// - assigns the connection a fresh user id;
/// - registers an outbound channel for it in `users`;
/// - hands every incoming message to `user_message`;
/// - calls `user_disconnected` once the incoming stream ends or yields an error.
async fn user_connected(ws: WebSocket, users: Users) {
    // Relaxed ordering is enough: the counter only has to produce unique ids, which
    // `fetch_add` guarantees atomically regardless of the ordering chosen.
    let my_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed);
    eprintln!("new chat user: {}", my_id);

    // Split the socket so outgoing and incoming traffic can be driven independently.
    let (user_ws_tx, mut user_ws_rx) = ws.split();

    // Messages destined for this user are pushed into an unbounded channel; a spawned
    // task forwards everything received on it to the websocket sink.
    let (tx, rx) = mpsc::unbounded_channel();
    let rx = UnboundedReceiverStream::new(rx);
    tokio::task::spawn(rx.forward(user_ws_tx).map(|result| {
        if let Err(e) = result {
            eprintln!("websocket send error: {}", e);
        }
    }));

    // Publish the sender so other users' broadcasts reach this connection. The write
    // guard is a temporary and is released at the end of this statement.
    users.write().await.insert(my_id, tx);

    // Process incoming messages until the stream ends or reports an error.
    while let Some(result) = user_ws_rx.next().await {
        let msg = match result {
            Ok(msg) => msg,
            Err(e) => {
                eprintln!("websocket error(uid={}): {}", my_id, e);
                break;
            }
        };
        user_message(my_id, msg, &users).await;
    }

    // `users` is still owned here (the loop only borrowed it), so no extra clone is
    // needed to unregister the user.
    user_disconnected(my_id, &users).await;
}
/// Broadcasts a text message from user `my_id` to every other registered user.
///
/// Messages that `Message::to_str` rejects (i.e. non-text frames) are dropped silently.
/// The broadcast text is formatted as `<User#ID>: TEXT`. Send failures are ignored; a
/// user whose connection ends is removed by `user_disconnected`.
async fn user_message(my_id: usize, msg: Message, users: &Users) {
    // Skip any non-Text messages...
    let text = match msg.to_str() {
        Ok(s) => s,
        Err(_) => return,
    };
    let outgoing = format!("<User#{}>: {}", my_id, text);

    // Hold the read guard for the whole broadcast.
    let registry = users.read().await;
    for (uid, recipient) in registry.iter() {
        if *uid == my_id {
            // Do not echo the message back to its author.
            continue;
        }
        // Nothing to do on error here; `user_disconnected` handles departed users.
        let _ = recipient.send(Ok(Message::text(outgoing.clone())));
    }
}
/// Logs the departure of user `my_id` and removes that user's sender from `users`.
///
/// The removed sender is the one `user_connected` inserted for this id. Dropping it lets
/// that connection's outbound stream terminate.
async fn user_disconnected(my_id: usize, users: &Users) {
    eprintln!("good bye user: {}", my_id);
    let mut registry = users.write().await;
    let _departed = registry.remove(&my_id);
}