Skip to content

Commit

Permalink
fix: work around clap-rs/clap#5127
Browse files Browse the repository at this point in the history
  • Loading branch information
ctron committed Sep 15, 2023
1 parent 3a4fd97 commit 540a838
Showing 1 changed file with 103 additions and 17 deletions.
120 changes: 103 additions & 17 deletions infrastructure/src/app/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use actix_web::{
use actix_web_prom::{PrometheusMetrics, PrometheusMetricsBuilder};
use anyhow::{anyhow, Context};
use bytesize::ByteSize;
use clap::{value_parser, Arg, ArgMatches, Args, Command, Error, FromArgMatches};
use openssl::ssl::SslFiletype;
use prometheus::Registry;
use std::fmt::{Debug, Display, Formatter};
Expand Down Expand Up @@ -53,6 +54,67 @@ impl FromStr for BinaryByteSize {
}
}

#[derive(Clone, Debug)]
pub struct BindPort<E: Endpoint> {
/// The port to listen on
pub bind_port: u16,

_marker: Marker<E>,
}

impl<E: Endpoint> Deref for BindPort<E> {
type Target = u16;

fn deref(&self) -> &Self::Target {
&self.bind_port
}
}

impl<E: Endpoint> Default for BindPort<E> {
fn default() -> Self {
Self {
bind_port: E::PORT,
_marker: Default::default(),
}
}
}

impl<E: Endpoint> Args for BindPort<E> {
fn augment_args(cmd: Command) -> Command {
Self::augment_args_for_update(cmd)
}

fn augment_args_for_update(cmd: Command) -> Command {
cmd.arg(
Arg::new("http-server-bind-port")
.short('p')
.long("http-server-bind-port")
.help("The port to listen on")
.value_parser(value_parser!(u16))
.default_value(E::PORT.to_string()),
)
}
}

impl<E: Endpoint> FromArgMatches for BindPort<E> {
fn from_arg_matches(matches: &ArgMatches) -> Result<Self, Error> {
Ok(Self {
bind_port: matches
.get_one::<u16>("http-server-bind-port")
.cloned()
.unwrap_or(E::port()),
_marker: Default::default(),
})
}

fn update_from_arg_matches(&mut self, matches: &ArgMatches) -> Result<(), Error> {
if let Some(port) = matches.get_one::<u16>("port") {
self.bind_port = *port;
}
Ok(())
}
}

#[derive(Clone, Debug, clap::Args)]
#[command(rename_all_env = "SCREAMING_SNAKE_CASE", next_help_heading = "HTTP endpoint")]
pub struct HttpServerConfig<E>
Expand All @@ -72,14 +134,9 @@ where
)]
pub bind_addr: String,

/// The port to listen on
#[arg(
id = "http-server-bind-port",
long,
env = "HTTP_SERVER_BIND_PORT",
default_value_t = E::PORT
)]
pub bind_port: u16,
// This is required due to: https://github.com/clap-rs/clap/issues/5127
#[command(flatten)]
pub bind_port: BindPort<E>,

/// The overall request limit
#[arg(
Expand Down Expand Up @@ -129,7 +186,7 @@ mod default {
use super::*;

pub fn bind_addr() -> String {
"[::1]".to_string()
"::1".to_string()
}

pub const fn request_limit() -> BinaryByteSize {
Expand All @@ -149,7 +206,7 @@ where
Self {
workers: 0,
bind_addr: default::bind_addr().to_string(),
bind_port: E::PORT,
bind_port: BindPort::<E>::default(),
request_limit: default::request_limit(),
json_limit: default::json_limit(),
tls_enabled: false,
Expand Down Expand Up @@ -182,9 +239,14 @@ where
type Error = anyhow::Error;

fn try_from(value: HttpServerConfig<E>) -> Result<Self, Self::Error> {
let addr = SocketAddr::new(
IpAddr::from_str(&value.bind_addr).context("parse bind address")?,
value.bind_port.bind_port,
);

let mut result = HttpServerBuilder::new()
.workers(value.workers)
.bind(SocketAddr::from_str(&value.bind_addr).context("parse bind address")?)
.bind(addr)
.request_limit(value.request_limit.0 .0 as _)
.json_limit(value.json_limit.0 .0 as _);

Expand Down Expand Up @@ -374,14 +436,20 @@ impl HttpServerBuilder {
});

if self.workers > 0 {
log::info!("Using {} worker(s)", self.workers);
http = http.workers(self.workers);
}

let tls = match self.tls {
Some(tls) => {
log::info!("Enabling TLS support");
let mut acceptor = SslAcceptor::mozilla_modern_v5(SslMethod::tls_server())?;
acceptor.set_certificate_chain_file(tls.certificate)?;
acceptor.set_private_key_file(tls.key, SslFiletype::PEM)?;
acceptor
.set_certificate_chain_file(tls.certificate)
.context("setting certificate chain")?;
acceptor
.set_private_key_file(tls.key, SslFiletype::PEM)
.context("setting private key")?;
Some(acceptor)
}
None => None,
Expand All @@ -391,19 +459,37 @@ impl HttpServerBuilder {
Bind::Listener(listener) => {
log::info!("Binding to provided listener: {listener:?}");
http = match tls {
Some(tls) => http.listen_openssl(listener, tls)?,
None => http.listen(listener)?,
Some(tls) => http.listen_openssl(listener, tls).context("listen with TLS")?,
None => http.listen(listener).context("listen")?,
};
}
Bind::Address(addr) => {
log::info!("Binding to: {addr}");
http = match tls {
Some(tls) => http.bind_openssl(addr, tls)?,
None => http.bind(addr)?,
Some(tls) => http.bind_openssl(addr, tls).context("bind with TLS")?,
None => http.bind(addr).context("bind")?,
};
}
}

Ok(http.run().await?)
}
}

#[cfg(test)]
mod test {
use super::*;

#[derive(Debug)]
pub struct MockEndpoint;

impl Endpoint for MockEndpoint {
const PORT: u16 = 1234;
const PATH: &'static str = "";
}

#[test]
fn default_config_converts() {
HttpServerBuilder::try_from(HttpServerConfig::<MockEndpoint>::default()).unwrap();
}
}

0 comments on commit 540a838

Please sign in to comment.