Skip to content

Commit

Permalink
Feat: Add config file support (#90)
Browse files Browse the repository at this point in the history
* Feat: Add config file support

* Fix formatting

* Fix clippy warnings

* Add default value for `disable_nodelay` and `fastopen`

* Change default json values to match clap

* Fix default function for listening on [::]:443
  • Loading branch information
erfan-khadem committed May 16, 2023
1 parent 5454cb4 commit 5484ecc
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 8 deletions.
45 changes: 45 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Expand Up @@ -28,6 +28,8 @@ sha2 = "0.10"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
webpki-roots = "0.22"
serde = { version = "1.0.163", features = ["derive"] }
serde_json = "1.0.96"

[profile.release]
lto = true
Expand Down
13 changes: 13 additions & 0 deletions examples/client_config.json
@@ -0,0 +1,13 @@
{
"disable_nodelay": false,
"fastopen": true,
"v3": true,
"strict": true,
"client": {
"listen": "localhost:8080",
"server_addr": "256.256.256.256:443",
"tls_names": "captive.apple.com;cloudflare.com",
"password": "12345678"
}
}

21 changes: 21 additions & 0 deletions examples/server_config.json
@@ -0,0 +1,21 @@
{
"disable_nodelay": false,
"fastopen": true,
"v3": true,
"strict": true,
"server": {
"listen": "localhost:8443",
"server_addr": "localhost:24000",
"tls_addr": {
"wildcard_sni": "off",
"dispatch": {
"cloudflare.com": "1.1.1.1:443",
"captive.apple.com": "captive.apple.com:443"
},
"fallback": "cloud.tencent.com:443"
},
"password": "12345678",
"wildcard_sni": "authed"
}
}

36 changes: 36 additions & 0 deletions src/client.rs
Expand Up @@ -14,6 +14,7 @@ use monoio::{
use monoio_rustls_fork_shadow_tls::TlsConnector;
use rand::{prelude::Distribution, seq::SliceRandom, Rng};
use rustls_fork_shadow_tls::{OwnedTrustAnchor, RootCertStore, ServerName};
use serde::{de::Visitor, Deserialize};

use crate::{
helper_v2::{copy_with_application_data, copy_without_application_data, HashedReadStream},
Expand Down Expand Up @@ -64,6 +65,41 @@ impl TryFrom<&str> for TlsNames {
}
}

struct TlsNamesVisitor;

impl<'de> Visitor<'de> for TlsNamesVisitor {
type Value = TlsNames;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a semicolon seperated list of domains and ip addresses")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
match Self::Value::try_from(v) {
Err(e) => Err(E::custom(e.to_string())),
Ok(u) => Ok(u),
}
}

fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Self.visit_str(&v)
}
}

impl<'de> Deserialize<'de> for TlsNames {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_string(TlsNamesVisitor)
}
}

impl std::fmt::Display for TlsNames {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
Expand Down
78 changes: 72 additions & 6 deletions src/main.rs
@@ -1,6 +1,6 @@
#![feature(type_alias_impl_trait)]

use std::{collections::HashMap, process::exit};
use std::{collections::HashMap, path::PathBuf, process::exit};

use clap::{Parser, Subcommand, ValueEnum};
use tracing_subscriber::{filter::LevelFilter, fmt, prelude::*, EnvFilter};
Expand All @@ -10,7 +10,9 @@ use shadow_tls::{
WildcardSNI,
};

#[derive(Parser, Debug)]
use serde::Deserialize;

#[derive(Parser, Debug, Deserialize)]
#[clap(
author,
version,
Expand All @@ -19,34 +21,56 @@ use shadow_tls::{
)]
struct Args {
#[clap(subcommand)]
#[serde(flatten)]
cmd: Commands,
#[clap(flatten)]
#[serde(flatten)]
opts: Opts,
}

#[derive(Parser, Debug, Default, Clone)]
macro_rules! default_function {
($name: ident, $type: ident, $val: expr) => {
fn $name() -> $type {
$val
}
};
}

// default_function!(default_true, bool, true);
default_function!(default_false, bool, false);
default_function!(default_8080, String, "[::1]:8080".to_string());
default_function!(default_443, String, "[::]:443".to_string());
default_function!(default_wildcard_sni, WildcardSNI, WildcardSNI::Off);

#[derive(Parser, Debug, Default, Clone, Deserialize)]
struct Opts {
#[clap(short, long, help = "Set parallelism manually")]
threads: Option<u8>,
#[serde(default = "default_false")]
#[clap(long, help = "Disable TCP_NODELAY")]
disable_nodelay: bool,
#[serde(default = "default_false")]
#[clap(long, help = "Enable TCP_FASTOPEN")]
fastopen: bool,
#[serde(default = "default_false")]
#[clap(long, help = "Use v3 protocol")]
v3: bool,
#[serde(default = "default_false")]
#[clap(long, help = "Strict mode(only for v3 protocol)")]
strict: bool,
}

#[derive(Subcommand, Debug)]
#[derive(Subcommand, Debug, Deserialize)]
enum Commands {
#[clap(about = "Run client side")]
#[serde(rename = "client")]
Client {
#[clap(
long = "listen",
default_value = "[::1]:8080",
help = "Shadow-tls client listen address(like \"[::1]:8080\")"
)]
#[serde(default = "default_8080")]
listen: String,
#[clap(
long = "server",
Expand All @@ -69,12 +93,14 @@ enum Commands {
alpn: Option<Vec<String>>,
},
#[clap(about = "Run server side")]
#[serde(rename = "server")]
Server {
#[clap(
long = "listen",
default_value = "[::]:443",
help = "Shadow-tls server listen address(like \"[::]:443\")"
)]
#[serde(default = "default_443")]
listen: String,
#[clap(
long = "server",
Expand All @@ -94,8 +120,15 @@ enum Commands {
default_value = "off",
help = "Use sni:443 as handshake server without predefining mapping(useful for bypass billing system like airplane wifi without modifying server config)"
)]
#[serde(default = "default_wildcard_sni")]
wildcard_sni: WildcardSNI,
},
#[serde(skip)]
Config {
#[serde(skip)]
#[clap(short, long, value_name = "FILE", help = "Path to config file")]
config: PathBuf,
},
}

fn parse_client_names(addrs: &str) -> anyhow::Result<TlsNames> {
Expand All @@ -106,6 +139,23 @@ fn parse_server_addrs(arg: &str) -> anyhow::Result<TlsAddrs> {
TlsAddrs::try_from(arg)
}

fn read_config_file(filename: String) -> Args {
let file = std::fs::File::open(filename);
match file {
Err(e) => {
tracing::error!("cannot open config file: {}", e);
exit(-1);
}
Ok(f) => match serde_json::from_reader(f) {
Err(e) => {
tracing::error!("cannot read config file: {}", e);
exit(-1);
}
Ok(res) => res,
},
}
}

impl From<Args> for RunningArgs {
fn from(args: Args) -> Self {
let v3 = match (args.opts.v3, args.opts.strict) {
Expand Down Expand Up @@ -149,6 +199,9 @@ impl From<Args> for RunningArgs {
v3,
}
}
Commands::Config { config: _ } => {
unreachable!()
}
}
}
}
Expand All @@ -173,8 +226,18 @@ pub(crate) fn get_sip003_arg() -> Option<Args> {
Some(val) => val,
}
};
(optional $key: expr) => {
match std::env::var($key).ok() {
None => "".to_string(),
Some(val) if val.is_empty() => "".to_string(),
Some(val) => val,
}
};
}
let config_file = env!(optional "CONFIG_FILE");
if !config_file.is_empty() {
return Some(read_config_file(config_file));
}

let ss_remote_host = env!("SS_REMOTE_HOST");
let ss_remote_port = env!("SS_REMOTE_PORT");
let ss_local_host = env!("SS_LOCAL_HOST");
Expand Down Expand Up @@ -246,7 +309,10 @@ fn main() {
.add_directive("rustls=off".parse().unwrap()),
)
.init();
let args = get_sip003_arg().unwrap_or_else(Args::parse);
let mut args = get_sip003_arg().unwrap_or_else(Args::parse);
if let Commands::Config { config } = args.cmd {
args = read_config_file(config.to_str().unwrap().to_string());
}
let parallelism = get_parallelism(&args);
let running_args = RunningArgs::from(args);
tracing::info!("Start {parallelism}-thread {running_args}");
Expand Down
3 changes: 2 additions & 1 deletion src/server.rs
Expand Up @@ -17,6 +17,7 @@ use monoio::{
},
net::TcpStream,
};
use serde::Deserialize;

use crate::{
helper_v2::{
Expand All @@ -42,7 +43,7 @@ pub struct ShadowTlsServer<LA, TA> {
v3: V3Mode,
}

#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Deserialize)]
pub struct TlsAddrs {
dispatch: rustc_hash::FxHashMap<String, String>,
fallback: String,
Expand Down
6 changes: 5 additions & 1 deletion src/util.rs
Expand Up @@ -15,6 +15,7 @@ use monoio::{

use hmac::Mac;
use rand::Rng;
use serde::Deserialize;
use sha2::{Digest, Sha256};

use prelude::*;
Expand Down Expand Up @@ -73,13 +74,16 @@ impl V3Mode {
}
}

#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, clap::ValueEnum)]
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, clap::ValueEnum, Deserialize)]
pub enum WildcardSNI {
/// Disabled
#[serde(rename = "off")]
Off,
/// For authenticated client only(may be differentiable); in v2 protocol it is eq to all.
#[serde(rename = "authed")]
Authed,
/// For all request(may cause service abused but not differentiable)
#[serde(rename = "all")]
All,
}

Expand Down

0 comments on commit 5484ecc

Please sign in to comment.