Skip to content
This repository has been archived by the owner on Aug 15, 2023. It is now read-only.

Commit

Permalink
Implement port forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
gifnksm committed Apr 25, 2021
1 parent b41e2ba commit 5b0c95a
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 26 deletions.
3 changes: 2 additions & 1 deletion src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,15 @@ pub(crate) enum StreamId {
Stdin,
Stdout,
Stderr,
Forward(PortId, ConnId),
}

#[derive(Debug, Deserialize, Serialize, From)]
pub(crate) enum StreamAction {
Source(StreamId, SourceAction),
Sink(StreamId, SinkAction),
Listener(PortId, ListenerAction),
Connecter((PortId, ConnId), ConnecterAction),
Connecter(StreamId, ConnecterAction),
}

#[derive(Debug, Deserialize, Serialize, From)]
Expand Down
7 changes: 6 additions & 1 deletion src/server/execute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,12 @@ async fn serve(stream: TcpStream, mut param: ServeParam) -> Result<()> {
for (port_id, connect_addr) in (0..).map(PortId::new).zip(param.connect_addrs) {
let send_msg_tx = send_msg_tx.clone();
let span = info_span!("connecter", %connect_addr);
let task = connecter::Task::new(port_id, connect_addr, send_msg_tx.clone(), &recv_router);
let task = connecter::Task::new(
port_id,
connect_addr,
send_msg_tx.clone(),
recv_router.clone(),
);
let _ = task.spawn(span);
}

Expand Down
24 changes: 17 additions & 7 deletions src/stream/connecter.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::{
net::SocketAddrs,
prelude::*,
protocol::{ConnecterAction, ListenerAction, PortId, Response, StreamAction},
stream::RecvRouter,
protocol::{ConnecterAction, ListenerAction, PortId, Response, StreamAction, StreamId},
stream::{sink, source, RecvRouter},
};
use std::fmt::Debug;
use std::{fmt::Debug, sync::Arc};
use tokio::sync::mpsc;
use tracing::Span;

Expand All @@ -29,6 +29,7 @@ pub(crate) struct Task<T> {
addr: SocketAddrs,
tx: mpsc::Sender<T>,
rx: mpsc::Receiver<ListenerAction>,
recv_router: Arc<RecvRouter>,
}

impl<T> Task<T>
Expand All @@ -39,7 +40,7 @@ where
port_id: PortId,
addr: SocketAddrs,
tx: mpsc::Sender<T>,
recv_router: &RecvRouter,
recv_router: Arc<RecvRouter>,
) -> Self {
let (listen_tx, rx) = mpsc::channel(128);
recv_router.insert_connecter_tx(port_id, Sender(listen_tx));
Expand All @@ -48,6 +49,7 @@ where
addr,
tx,
rx,
recv_router,
}
}

Expand All @@ -63,6 +65,7 @@ where
addr,
tx,
mut rx,
recv_router,
} = self;

trace!("started");
Expand All @@ -71,13 +74,14 @@ where
trace!(?msg);
match msg {
ListenerAction::Connect(conn_id) => {
let resp_id = (port_id, conn_id);
let id = StreamId::from((port_id, conn_id));
let addr = addr.clone();
let tx = tx.clone();
let recv_router = recv_router.clone();
let _ = tokio::spawn(async move {
let res = addr.connect().await;
let resp = ConnecterAction::ConnectResponse(Response::new(&res));
let msg = T::from((resp_id, resp).into());
let msg = T::from((id, resp).into());
tx.send(msg).await?;
let stream = match res {
Ok(stream) => stream,
Expand All @@ -86,7 +90,13 @@ where
bail!(err);
}
};
// TODO

let (reader, writer) = stream.into_split();
let _ = source::Task::new(id, reader, tx.clone(), &recv_router)
.spawn(info_span!("forward_source", ?id));
let _ = sink::Task::new(id, writer, tx.clone(), &recv_router)
.spawn(info_span!("forward_sink", ?id));

Ok::<(), Error>(())
});
}
Expand Down
12 changes: 9 additions & 3 deletions src/stream/listener.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
prelude::*,
protocol::{ConnId, ConnecterAction, ListenerAction, PortId, StreamAction},
stream::RecvRouter,
protocol::{ConnId, ConnecterAction, ListenerAction, PortId, StreamAction, StreamId},
stream::{sink, source, RecvRouter},
};
use std::{fmt::Debug, sync::Arc};
use tokio::{net::TcpListener, sync::mpsc};
Expand Down Expand Up @@ -48,7 +48,7 @@ where
} = self;

for conn_id in (0..).map(ConnId::new) {
let id = (port_id, conn_id);
let id = StreamId::Forward(port_id, conn_id);
let (conn_tx, mut rx) = mpsc::channel(1);
recv_router.insert_listener_tx(id, conn_tx);
let (stream, peer_addr) = match listener.accept().await {
Expand All @@ -71,6 +71,12 @@ where
})?;
recv_router.remove_listener_tx(id);
debug!("connected");

let (reader, writer) = stream.into_split();
let _ = source::Task::new(id, reader, tx.clone(), &recv_router)
.spawn(info_span!("forward_source", ?id));
let _ = sink::Task::new(id, writer, tx.clone(), &recv_router)
.spawn(info_span!("forward_sink", ?id));
}

Ok(())
Expand Down
19 changes: 5 additions & 14 deletions src/stream/recv_router.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::{
prelude::*,
protocol::{
ConnId, ConnecterAction, ListenerAction, PortId, SinkAction, SourceAction, StreamAction,
StreamId,
ConnecterAction, ListenerAction, PortId, SinkAction, SourceAction, StreamAction, StreamId,
},
stream::{connecter, sink, source},
};
Expand All @@ -15,7 +14,7 @@ pub(crate) struct RecvRouter {
sink_tx_map: Mutex<HashMap<StreamId, sink::Sender>>,
source_tx_map: Mutex<HashMap<StreamId, source::Sender>>,
connecter_tx_map: Mutex<HashMap<PortId, connecter::Sender>>,
listener_tx_map: Mutex<HashMap<(PortId, ConnId), mpsc::Sender<ConnecterAction>>>,
listener_tx_map: Mutex<HashMap<StreamId, mpsc::Sender<ConnecterAction>>>,
}

impl RecvRouter {
Expand All @@ -35,15 +34,11 @@ impl RecvRouter {
assert!(self.connecter_tx_map.lock().insert(id, tx).is_none())
}

pub(crate) fn insert_listener_tx(
&self,
id: (PortId, ConnId),
tx: mpsc::Sender<ConnecterAction>,
) {
pub(crate) fn insert_listener_tx(&self, id: StreamId, tx: mpsc::Sender<ConnecterAction>) {
assert!(self.listener_tx_map.lock().insert(id, tx).is_none())
}

pub(crate) fn remove_listener_tx(&self, id: (PortId, ConnId)) {
pub(crate) fn remove_listener_tx(&self, id: StreamId) {
assert!(self.listener_tx_map.lock().remove(&id).is_some())
}

Expand Down Expand Up @@ -121,11 +116,7 @@ impl RecvRouter {
Ok(())
}

async fn send_connecter_action(
&self,
id: (PortId, ConnId),
action: ConnecterAction,
) -> Result<()> {
async fn send_connecter_action(&self, id: StreamId, action: ConnecterAction) -> Result<()> {
let tx = self
.listener_tx_map
.lock()
Expand Down

0 comments on commit 5b0c95a

Please sign in to comment.