Skip to content

Commit

Permalink
Add ws support for DI server (#2604)
Browse files Browse the repository at this point in the history
* init

* add impl

* fix compile

* fix socketaddr

* fix tests

* fix tests

* use ws

* update lockfile

* use ws for cli by default
  • Loading branch information
Kailai-Wang committed Mar 21, 2024
1 parent 7539aff commit 6e45b49
Show file tree
Hide file tree
Showing 23 changed files with 333 additions and 205 deletions.
2 changes: 1 addition & 1 deletion local-setup/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def get_flags(index, worker):

return list(filter(None, [
"--clean-reset",
"-T", "wss://localhost",
"-T", "ws://localhost",
"-P", ports['trusted_worker_port'],
"-w", ports['untrusted_worker_port'],
"-r", ports['mura_port'],
Expand Down
19 changes: 0 additions & 19 deletions tee-worker/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tee-worker/cli/lit_ts_integration_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ NPORT=${NPORT:-9912}
NODEURL=${NODEURL:-"ws://litentry-node"}
NODEHTTPURL=${NODEHTTPURL:-"http://litentry-node"}
WORKER1PORT=${WORKER1PORT:-2011}
WORKER1URL=${WORKER1URL:-"wss://litentry-worker-1"}
WORKER1URL=${WORKER1URL:-"ws://litentry-worker-1"}

CLIENT_BIN=${CLIENT_BIN:-"/usr/local/bin/litentry-cli"}

Expand Down
2 changes: 1 addition & 1 deletion tee-worker/cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub struct Cli {
node_port: String,

/// worker url
#[clap(short = 'U', long, default_value_t = String::from("wss://127.0.0.1"))]
#[clap(short = 'U', long, default_value_t = String::from("ws://127.0.0.1"))]
worker_url: String,

/// worker direct invocation port
Expand Down
138 changes: 71 additions & 67 deletions tee-worker/core/tls-websocket-server/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
use crate::sgx_reexport_prelude::*;

use crate::{
error::WebSocketError, stream_state::StreamState, WebSocketConnection, WebSocketMessageHandler,
WebSocketResult,
error::WebSocketError,
stream_state::{MaybeServerTlsStream, StreamState},
WebSocketConnection, WebSocketMessageHandler, WebSocketResult,
};
use log::*;
use mio::{event::Event, net::TcpStream, Poll, Ready, Token};
Expand All @@ -35,7 +36,7 @@ use tungstenite::Message;

/// A web-socket connection object.
pub struct TungsteniteWsConnection<Handler> {
stream_state: StreamState,
stream_state: StreamState<TcpStream>,
connection_token: Token,
connection_handler: Arc<Handler>,
is_closed: bool,
Expand All @@ -47,78 +48,81 @@ where
{
pub fn new(
tcp_stream: TcpStream,
server_session: ServerSession,
maybe_server_session: Option<ServerSession>,
connection_token: Token,
handler: Arc<Handler>,
) -> WebSocketResult<Self> {
let stream_state = match maybe_server_session {
Some(sess) => StreamState::new_rustls_stream(sess, tcp_stream),
None => StreamState::new_plain_stream(tcp_stream),
};
Ok(TungsteniteWsConnection {
stream_state: StreamState::from_stream(rustls::StreamOwned::new(
server_session,
tcp_stream,
)),
stream_state,
connection_token,
connection_handler: handler,
is_closed: false,
})
}

fn do_tls_read(&mut self) -> ConnectionState {
let tls_stream = match self.stream_state.internal_stream_mut() {
None => return ConnectionState::Closing,
Some(s) => s,
};

let tls_session = &mut tls_stream.sess;

match tls_session.read_tls(&mut tls_stream.sock) {
Ok(r) =>
if r == 0 {
return ConnectionState::Closing
},
Err(err) => {
if let std::io::ErrorKind::WouldBlock = err.kind() {
debug!("TLS session is blocked (connection {})", self.connection_token.0);
return ConnectionState::Blocked
fn maybe_do_tls_read(&mut self) -> ConnectionState {
match self.stream_state.internal_stream_mut() {
None => ConnectionState::Closing,
Some(MaybeServerTlsStream::Plain(_)) => ConnectionState::Alive, // noop for non-TLS ws server
Some(MaybeServerTlsStream::Rustls(s)) => {
let tls_session = &mut s.sess;
match tls_session.read_tls(&mut s.sock) {
Ok(r) =>
if r == 0 {
return ConnectionState::Closing
},
Err(err) => {
if let std::io::ErrorKind::WouldBlock = err.kind() {
debug!(
"TLS session is blocked (connection {})",
self.connection_token.0
);
return ConnectionState::Blocked
}
warn!(
"I/O error after reading TLS data (connection {}): {:?}",
self.connection_token.0, err
);
return ConnectionState::Closing
},
}
warn!(
"I/O error after reading TLS data (connection {}): {:?}",
self.connection_token.0, err
);
return ConnectionState::Closing
},
}

match tls_session.process_new_packets() {
Ok(_) => {
if tls_session.is_handshaking() {
return ConnectionState::TlsHandshake
match tls_session.process_new_packets() {
Ok(_) => {
if tls_session.is_handshaking() {
return ConnectionState::TlsHandshake
}
ConnectionState::Alive
},
Err(e) => {
error!("cannot process TLS packet(s), closing connection: {:?}", e);
ConnectionState::Closing
},
}
ConnectionState::Alive
},
Err(e) => {
error!("cannot process TLS packet(s), closing connection: {:?}", e);
ConnectionState::Closing
},
}
}

fn do_tls_write(&mut self) -> ConnectionState {
let tls_stream = match self.stream_state.internal_stream_mut() {
None => return ConnectionState::Closing,
Some(s) => s,
};

match tls_stream.sess.write_tls(&mut tls_stream.sock) {
Ok(_) => {
trace!("TLS write successful, connection {} is alive", self.connection_token.0);
if tls_stream.sess.is_handshaking() {
return ConnectionState::TlsHandshake
}
ConnectionState::Alive
},
Err(e) => {
error!("TLS write error (connection {}): {:?}", self.connection_token.0, e);
ConnectionState::Closing
fn maybe_do_tls_write(&mut self) -> ConnectionState {
match self.stream_state.internal_stream_mut() {
None => ConnectionState::Closing,
Some(MaybeServerTlsStream::Plain(_)) => ConnectionState::Alive, // noop for non-TLS ws server
Some(MaybeServerTlsStream::Rustls(s)) => match s.sess.write_tls(&mut s.sock) {
Ok(_) => {
trace!("TLS write successful, connection {} is alive", self.connection_token.0);
if s.sess.is_handshaking() {
return ConnectionState::TlsHandshake
}
ConnectionState::Alive
},
Err(e) => {
error!("TLS write error (connection {}): {:?}", self.connection_token.0, e);
ConnectionState::Closing
},
},
}
}
Expand All @@ -127,7 +131,7 @@ where
///
/// Returns a boolean 'connection should be closed'.
fn read_or_initialize_websocket(&mut self) -> WebSocketResult<bool> {
if let StreamState::EstablishedWebsocket(web_socket) = &mut self.stream_state {
if let StreamState::Established(web_socket) = &mut self.stream_state {
trace!(
"Read is possible for connection {}: {}",
self.connection_token.0,
Expand Down Expand Up @@ -196,7 +200,7 @@ where
"Received close frame, driving web-socket connection {} to close",
self.connection_token.0
);
if let StreamState::EstablishedWebsocket(web_socket) = &mut self.stream_state {
if let StreamState::Established(web_socket) = &mut self.stream_state {
// Send a close frame back and then flush the send queue.
if let Err(e) = web_socket.close(None) {
match e {
Expand Down Expand Up @@ -226,7 +230,7 @@ where

pub(crate) fn write_message(&mut self, message: String) -> WebSocketResult<()> {
match &mut self.stream_state {
StreamState::EstablishedWebsocket(web_socket) => {
StreamState::Established(web_socket) => {
if !web_socket.can_write() {
return Err(WebSocketError::ConnectionClosed)
}
Expand All @@ -248,15 +252,15 @@ where
type Socket = TcpStream;

fn socket(&self) -> Option<&Self::Socket> {
self.stream_state.internal_stream().map(|s| &s.sock)
self.stream_state.internal_stream().map(|s| s.inner())
}

fn get_session_readiness(&self) -> Ready {
match self.stream_state.internal_stream() {
None => mio::Ready::empty(),
Some(s) => {
let wants_read = s.sess.wants_read();
let wants_write = s.sess.wants_write();
let wants_read = s.wants_read();
let wants_write = s.wants_write();

if wants_read && wants_write {
mio::Ready::readable() | mio::Ready::writable()
Expand All @@ -275,7 +279,7 @@ where
if event.readiness().is_readable() {
trace!("Connection ({:?}) is readable", self.token());

let connection_state = self.do_tls_read();
let connection_state = self.maybe_do_tls_read();

if connection_state.is_alive() {
is_closing = self.read_or_initialize_websocket()?;
Expand All @@ -287,10 +291,10 @@ where
if event.readiness().is_writable() {
trace!("Connection ({:?}) is writable", self.token());

let connection_state = self.do_tls_write();
let connection_state = self.maybe_do_tls_write();

if connection_state.is_alive() {
if let StreamState::EstablishedWebsocket(web_socket) = &mut self.stream_state {
if let StreamState::Established(web_socket) = &mut self.stream_state {
trace!("Web-socket, write pending messages");
if let Err(e) = web_socket.write_pending() {
match e {
Expand Down
23 changes: 1 addition & 22 deletions tee-worker/core/tls-websocket-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,11 @@ pub mod sgx_reexport_prelude {
use crate::sgx_reexport_prelude::*;

use crate::{
config_provider::FromFileConfigProvider,
connection_id_generator::{ConnectionId, ConnectionIdGenerator},
error::{WebSocketError, WebSocketResult},
ws_server::TungsteniteWsServer,
};
use mio::{event::Evented, Token};
use std::{
fmt::Debug,
string::{String, ToString},
sync::Arc,
};
use std::{fmt::Debug, string::String};

pub mod certificate_generation;
pub mod config_provider;
Expand Down Expand Up @@ -160,18 +154,3 @@ pub(crate) trait WebSocketConnection: Send + Sync {
}
}
}

pub fn create_ws_server<Handler>(
addr_plain: &str,
private_key: &str,
certificate: &str,
handler: Arc<Handler>,
) -> Arc<TungsteniteWsServer<Handler, FromFileConfigProvider>>
where
Handler: WebSocketMessageHandler,
{
let config_provider =
Arc::new(FromFileConfigProvider::new(private_key.to_string(), certificate.to_string()));

Arc::new(TungsteniteWsServer::new(addr_plain.to_string(), config_provider, handler))
}

0 comments on commit 6e45b49

Please sign in to comment.