Skip to content

Commit

Permalink
Bump
Browse files Browse the repository at this point in the history
  • Loading branch information
erebe committed Oct 16, 2023
1 parent e6c8f9d commit 3496804
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 95 deletions.
195 changes: 120 additions & 75 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,20 @@ use clap::Parser;
use futures_util::{pin_mut, stream, Stream, StreamExt, TryStreamExt};

Check warning on line 12 in src/main.rs

View workflow job for this annotation

GitHub Actions / Build - Windows x86_64

unused import: `stream`
use hyper::http::HeaderValue;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap};
use std::io;
use std::io::ErrorKind;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;

use tokio_rustls::rustls::server::DnsName;
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerName};

use tracing::{debug, error, instrument, Instrument, Span};
use tracing::{debug, error, span, Instrument, Level};

use tracing_subscriber::EnvFilter;
use url::{Host, Url};
Expand All @@ -52,7 +50,7 @@ enum Commands {
struct Client {
/// Listen on local and forwards traffic from remote
/// Can be specified multiple times
#[arg(short='L', long, value_name = "{tcp,udp}://[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)]
#[arg(short='L', long, value_name = "{tcp,udp,socks5}://[BIND:]PORT:HOST:PORT", value_parser = parse_tunnel_arg)]
local_to_remote: Vec<LocalToRemote>,

/// (linux only) Mark network packet with SO_MARK sockoption with the specified value.
Expand Down Expand Up @@ -138,24 +136,17 @@ struct Server {
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
enum L4Protocol {
enum LocalProtocol {
Tcp,
Udp { timeout: Option<Duration> },
Stdio,
}

impl L4Protocol {
fn new_udp() -> L4Protocol {
L4Protocol::Udp {
timeout: Some(Duration::from_secs(30)),
}
}
Socks5,
}

#[derive(Clone, Debug)]
pub struct LocalToRemote {
socket_so_mark: Option<i32>,
protocol: L4Protocol,
local_protocol: LocalProtocol,
local: SocketAddr,
remote: (Host<String>, u16),
}
Expand All @@ -173,18 +164,9 @@ fn parse_duration_sec(arg: &str) -> Result<Duration, io::Error> {
Ok(Duration::from_secs(secs))
}

fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
fn parse_local_bind(arg: &str) -> Result<(SocketAddr, &str), io::Error> {
use std::io::Error;

let (mut protocol, arg) = match &arg[..6] {
"tcp://" => (L4Protocol::Tcp, &arg[6..]),
"udp://" => (L4Protocol::new_udp(), &arg[6..]),
_ => match &arg[..8] {
"stdio://" => (L4Protocol::Stdio, &arg[8..]),
_ => (L4Protocol::Tcp, arg),
},
};

let (bind, remaining) = if arg.starts_with('[') {
// ipv6 bind
let Some((ipv6_str, remaining)) = arg.split_once(']') else {
Expand Down Expand Up @@ -217,12 +199,8 @@ fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
}
};

let Some((port_str, remaining)) = remaining.trim_start_matches(':').split_once(':') else {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("cannot parse bind port from {}", remaining),
));
};
let remaining = remaining.trim_start_matches(':');
let (port_str, remaining) = remaining.split_once([':', '?']).unwrap_or((remaining, ""));

let Ok(bind_port): Result<u16, _> = port_str.parse() else {
return Err(Error::new(
Expand All @@ -231,6 +209,14 @@ fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
));
};

Ok((SocketAddr::new(bind, bind_port), remaining))
}

fn parse_tunnel_dest(
remaining: &str,
) -> Result<(Host<String>, u16, BTreeMap<String, String>), io::Error> {
use std::io::Error;

let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
return Err(Error::new(
ErrorKind::InvalidInput,
Expand All @@ -252,14 +238,30 @@ fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
));
};

let options: BTreeMap<Cow<'_, str>, Cow<'_, str>> = remote.query_pairs().collect();
match &mut protocol {
L4Protocol::Stdio => {}
L4Protocol::Tcp => {}
L4Protocol::Udp {
ref mut timeout, ..
} => {
if let Some(duration) = options
let options: BTreeMap<String, String> = remote.query_pairs().into_owned().collect();
Ok((remote_host.to_owned(), remote_port, options))
}

fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
use std::io::Error;

match &arg[..6] {
"tcp://" => {
let (local_bind, remaining) = parse_local_bind(&arg[6..])?;
let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?;
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Tcp,
local: local_bind,
remote: (dest_host, dest_port),
})
}
"udp://" => {
let (local_bind, remaining) = parse_local_bind(&arg[6..])?;
let (dest_host, dest_port, options) = parse_tunnel_dest(remaining)?;
let timeout = options
.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|d| {
Expand All @@ -269,20 +271,48 @@ fn parse_env_var(arg: &str) -> Result<LocalToRemote, io::Error> {
Some(Duration::from_secs(d))
}
})
{
*timeout = duration;
}
.unwrap_or(Some(Duration::from_secs(30)));

Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Udp { timeout },
local: local_bind,
remote: (dest_host, dest_port),
})
}
};

Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
protocol,
local: SocketAddr::new(bind, bind_port),
remote: (remote_host.to_owned(), remote_port),
})
_ => match &arg[..8] {
"socks5:/" => {
let (local_bind, remaining) = parse_local_bind(&arg[9..])?;
let x = format!("0.0.0.0:0?{}", remaining);
let (dest_host, dest_port, options) = parse_tunnel_dest(&x)?;
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Socks5,
local: local_bind,
remote: (dest_host, dest_port),
})
}
"stdio://" => {
let (dest_host, dest_port, options) = parse_tunnel_dest(&arg[8..])?;
Ok(LocalToRemote {
socket_so_mark: options
.get("socket_so_mark")
.and_then(|x| x.parse::<i32>().ok()),
local_protocol: LocalProtocol::Stdio,
local: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(0), 0)),
remote: (dest_host, dest_port),
})
}
_ => Err(Error::new(
ErrorKind::InvalidInput,
format!("Invalid local protocol for tunnel {}", arg),
)),
},
}
}

fn parse_sni_override(arg: &str) -> Result<DnsName, io::Error> {
Expand Down Expand Up @@ -432,7 +462,7 @@ async fn main() {
if args
.local_to_remote
.iter()
.filter(|x| x.protocol == L4Protocol::Stdio)
.filter(|x| x.local_protocol == LocalProtocol::Stdio)
.count()
> 0 => {}
_ => {
Expand Down Expand Up @@ -474,36 +504,54 @@ async fn main() {
for tunnel in args.local_to_remote.into_iter() {
let server_config = server_config.clone();

match &tunnel.protocol {
L4Protocol::Tcp => {
match &tunnel.local_protocol {
LocalProtocol::Tcp => {
let remote = tunnel.remote.clone();
let server = tcp::run_server(tunnel.local)
.await
.unwrap_or_else(|err| {
panic!("Cannot start TCP server on {}: {}", tunnel.local, err)
})
.map_ok(TcpStream::into_split);
.map_err(anyhow::Error::new)
.map_ok(move |stream| (stream.into_split(), remote.clone()));

tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
error!("{:?}", err);
}
});
}
L4Protocol::Udp { timeout } => {
LocalProtocol::Udp { timeout } => {
let remote = tunnel.remote.clone();
let server = udp::run_server(tunnel.local, *timeout)
.await
.unwrap_or_else(|err| {
panic!("Cannot start UDP server on {}: {}", tunnel.local, err)
})
.map_ok(tokio::io::split);
.map_err(anyhow::Error::new)
.map_ok(move |stream| (tokio::io::split(stream), remote.clone()));

tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
error!("{:?}", err);
}
});
}
LocalProtocol::Socks5 => {
let server = socks5::run_server(tunnel.local)
.await
.unwrap_or_else(|err| {
panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)
})
.map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest));

tokio::spawn(async move {
if let Err(err) = run_tunnel(server_config, tunnel, server).await {
error!("{:?}", err);
}
});
}
L4Protocol::Stdio => {
LocalProtocol::Stdio => {
#[cfg(target_family = "unix")]
{
let server = stdio::run_server().await.unwrap_or_else(|err| {
Expand All @@ -512,8 +560,8 @@ async fn main() {
tokio::spawn(async move {
if let Err(err) = run_tunnel(
server_config,
tunnel,
stream::once(async move { Ok(server) }),
tunnel.clone(),
stream::once(async move { Ok((server, tunnel.remote)) }),
)
.await
{
Expand Down Expand Up @@ -573,31 +621,28 @@ async fn main() {
tokio::signal::ctrl_c().await.unwrap();
}

#[instrument(name="tunnel", level="info", skip_all, fields(id=tracing::field::Empty, remote=tracing::field::Empty))]
async fn run_tunnel<T, R, W>(
server_config: Arc<WsClientConfig>,
tunnel: LocalToRemote,
incoming_cnx: T,
) -> anyhow::Result<()>
where
T: Stream<Item = io::Result<(R, W)>>,
T: Stream<Item = anyhow::Result<((R, W), (Host, u16))>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
let span = Span::current();
let request_id = Uuid::now_v7();
span.record("id", request_id.to_string());
span.record(
"remote",
&format!("{}:{}", tunnel.remote.0, tunnel.remote.1),
);

let tunnel = Arc::new(tunnel);
pin_mut!(incoming_cnx);

while let Some(Ok(cnx_stream)) = incoming_cnx.next().await {
while let Some(Ok((cnx_stream, remote_dest))) = incoming_cnx.next().await {
let request_id = Uuid::now_v7();
let span = span!(
Level::INFO,
"tunnel",
id = request_id.to_string(),
remote = format!("{}:{}", remote_dest.0, remote_dest.1)
);
let server_config = server_config.clone();
let tunnel = tunnel.clone();
let mut tunnel = tunnel.clone();
tunnel.remote = remote_dest;

tokio::spawn(
async move {
Expand Down
Loading

0 comments on commit 3496804

Please sign in to comment.