-
Notifications
You must be signed in to change notification settings - Fork 16
/
main.rs
201 lines (178 loc) · 6.77 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
use futures::FutureExt;
use getopts::Options;
use std::env;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::broadcast;
/// Error type returned by `main` and `forward`: any boxed error that can cross threads.
type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
/// Size in bytes of the buffer each relay direction reads into before writing out.
const BUF_SIZE: usize = 1024;
/// Set from the `-d`/`--debug` flag; when true, per-connection byte counts are logged.
static DEBUG: AtomicBool = AtomicBool::new(false);
/// Prints the usage banner followed by the option summary to stdout.
///
/// `program` is normally `argv[0]`; only its file stem is shown (e.g. `tcpproxy` for
/// `/usr/bin/tcpproxy`). When no stem can be extracted (an empty string, or a path
/// such as `..`), the raw `program` string is shown instead of panicking.
fn print_usage(program: &str, opts: Options) {
    let program_name = std::path::Path::new(program)
        .file_stem()
        .map(|stem| stem.to_string_lossy())
        .unwrap_or_else(|| program.into());
    let brief = format!(
        "Usage: {} REMOTE_HOST:PORT [-b BIND_ADDR] [-l LOCAL_PORT]",
        program_name
    );
    print!("{}", opts.usage(&brief));
}
// Entry point: parses the command line, then runs the proxy until `forward` fails.
//
// Usage errors (unknown flags, a missing or extra REMOTE argument, a remote without a
// port, or an invalid local port) are reported on stderr and terminate the process with
// `std::process::exit(-1)`, consistently for every kind of bad argument.
#[tokio::main]
async fn main() -> Result<(), BoxedError> {
    let args: Vec<String> = env::args().collect();
    // argv[0] is conventionally the program path, but the OS does not guarantee it is
    // present, so neither it nor the `1..` slice below is indexed unchecked.
    let program = args.first().map_or("tcpproxy", String::as_str);
    let mut opts = Options::new();
    opts.optopt(
        "b",
        "bind",
        "The address on which to listen for incoming requests, defaulting to localhost",
        "BIND_ADDR",
    );
    opts.optopt(
        "l",
        "local-port",
        "The local port to which tcpproxy should bind to, randomly chosen otherwise",
        "LOCAL_PORT",
    );
    opts.optflag("d", "debug", "Enable debug mode");
    let matches = match opts.parse(args.get(1..).unwrap_or_default()) {
        Ok(matches) => matches,
        Err(e) => {
            eprintln!("{}", e);
            print_usage(program, opts);
            std::process::exit(-1);
        }
    };
    // Exactly one positional argument (REMOTE_HOST:PORT) is accepted.
    let remote = match matches.free.as_slice() {
        [remote] => remote.clone(),
        _ => {
            print_usage(program, opts);
            std::process::exit(-1);
        }
    };
    if !remote.contains(':') {
        eprintln!("A remote port is required (REMOTE_ADDR:PORT)");
        std::process::exit(-1);
    }
    DEBUG.store(matches.opt_present("d"), Ordering::Relaxed);
    // Parse as u16 so out-of-range or negative ports (e.g. 70000, -1) are rejected here
    // with a clear message, instead of yielding a bind address that cannot be parsed.
    // 0 asks the OS to pick a free port.
    let local_port: u16 = match matches.opt_str("l") {
        None => 0,
        Some(raw) => match raw.parse() {
            Ok(port) => port,
            Err(e) => {
                eprintln!("Invalid local port {raw:?}: {e}");
                std::process::exit(-1);
            }
        },
    };
    let bind_addr = matches
        .opt_str("b")
        .unwrap_or_else(|| "127.0.0.1".to_owned());
    forward(&bind_addr, i32::from(local_port), remote).await
}
async fn forward(bind_ip: &str, local_port: i32, remote: String) -> Result<(), BoxedError> {
// Listen on the specified IP and port
let bind_addr = if !bind_ip.starts_with('[') && bind_ip.contains(':') {
// Correctly format for IPv6 usage
format!("[{}]:{}", bind_ip, local_port)
} else {
format!("{}:{}", bind_ip, local_port)
};
let bind_sock = bind_addr
.parse::<std::net::SocketAddr>()
.expect("Failed to parse bind address");
let listener = TcpListener::bind(&bind_sock).await?;
println!("Listening on {}", listener.local_addr().unwrap());
// `remote` should be either the host name or ip address, with the port appended.
// It doesn't get tested/validated until we get our first connection, though!
// We leak `remote` instead of wrapping it in an Arc to share it with future tasks since
// `remote` is going to live for the lifetime of the server in all cases.
// (This reduces MESI/MOESI cache traffic between CPU cores.)
let remote: &str = Box::leak(remote.into_boxed_str());
async fn copy_with_abort<R, W>(
read: &mut R,
write: &mut W,
mut abort: broadcast::Receiver<()>,
) -> tokio::io::Result<usize>
where
R: tokio::io::AsyncRead + Unpin,
W: tokio::io::AsyncWrite + Unpin,
{
let mut copied = 0;
let mut buf = [0u8; BUF_SIZE];
loop {
let bytes_read;
tokio::select! {
biased;
result = read.read(&mut buf) => {
use std::io::ErrorKind::{ConnectionReset, ConnectionAborted};
bytes_read = result.or_else(|e| match e.kind() {
// Consider these to be part of the proxy life, not errors
ConnectionReset | ConnectionAborted => Ok(0),
_ => Err(e)
})?;
},
_ = abort.recv() => {
break;
}
}
if bytes_read == 0 {
break;
}
write.write_all(&buf[0..bytes_read]).await?;
copied += bytes_read;
}
Ok(copied)
}
loop {
let (mut client, client_addr) = listener.accept().await?;
tokio::spawn(async move {
println!("New connection from {}", client_addr);
// Establish connection to upstream for each incoming client connection
let mut remote = match TcpStream::connect(remote).await {
Ok(result) => result,
Err(e) => {
eprintln!("Error establishing upstream connection: {e}");
return;
}
};
let (mut client_read, mut client_write) = client.split();
let (mut remote_read, mut remote_write) = remote.split();
let (cancel, _) = broadcast::channel::<()>(1);
let (remote_copied, client_copied) = tokio::join! {
copy_with_abort(&mut remote_read, &mut client_write, cancel.subscribe())
.then(|r| { let _ = cancel.send(()); async { r } }),
copy_with_abort(&mut client_read, &mut remote_write, cancel.subscribe())
.then(|r| { let _ = cancel.send(()); async { r } }),
};
match client_copied {
Ok(count) => {
if DEBUG.load(Ordering::Relaxed) {
eprintln!(
"Transferred {} bytes from remote client {} to upstream server",
count, client_addr
);
}
}
Err(err) => {
eprintln!(
"Error writing bytes from remote client {} to upstream server",
client_addr
);
eprintln!("{}", err);
}
};
match remote_copied {
Ok(count) => {
if DEBUG.load(Ordering::Relaxed) {
eprintln!(
"Transferred {} bytes from upstream server to remote client {}",
count, client_addr
);
}
}
Err(err) => {
eprintln!(
"Error writing from upstream server to remote client {}!",
client_addr
);
eprintln!("{}", err);
}
};
()
});
}
}