Skip to content

Commit

Permalink
feat: add api_key for request authorization (#211)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Mar 21, 2024
1 parent 1d6f288 commit 5e60d06
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 16 deletions.
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,21 @@ Options:
[env: PAYLOAD_LIMIT=]
[default: 2000000]
--api-key <API_KEY>
Set an api key for request authorization.
By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
[env: API_KEY=]
--json-output
Outputs the logs in JSON format (useful for telemetry)
[env: JSON_OUTPUT=]
--otlp-endpoint <OTLP_ENDPOINT>
The grpc endpoint for opentelemetry. Telemetry is sent to this endpoint as OTLP over gRPC.
e.g. `http://localhost:4317`
The grpc endpoint for opentelemetry. Telemetry is sent to this endpoint as OTLP over gRPC. e.g. `http://localhost:4317`
[env: OTLP_ENDPOINT=]
--cors-allow-origin <CORS_ALLOW_ORIGIN>
Expand Down
17 changes: 17 additions & 0 deletions docs/source/en/cli_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,29 @@ Options:
[env: HUGGINGFACE_HUB_CACHE=/data]
--payload-limit <PAYLOAD_LIMIT>
Payload size limit in bytes
Default is 2MB
[env: PAYLOAD_LIMIT=]
[default: 2000000]
--api-key <API_KEY>
Set an api key for request authorization.
By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
[env: API_KEY=]
--json-output
Outputs the logs in JSON format (useful for telemetry)
[env: JSON_OUTPUT=]
--otlp-endpoint <OTLP_ENDPOINT>
The grpc endpoint for opentelemetry. Telemetry is sent to this endpoint as OTLP over gRPC. e.g. `http://localhost:4317`
[env: OTLP_ENDPOINT=]
--cors-allow-origin <CORS_ALLOW_ORIGIN>
Expand Down
50 changes: 40 additions & 10 deletions router/src/grpc/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,7 @@ pub async fn run(
info: Info,
addr: SocketAddr,
prom_builder: PrometheusBuilder,
api_key: Option<String>,
) -> Result<(), anyhow::Error> {
prom_builder.install()?;
tracing::info!("Serving Prometheus metrics: 0.0.0.0:9000");
Expand Down Expand Up @@ -1431,17 +1432,46 @@ pub async fn run(
let service = TextEmbeddingsService::new(infer, info);

// Create gRPC server
let server = if let Some(api_key) = api_key {
let mut prefix = "Bearer ".to_string();
prefix.push_str(&api_key);

// Leak to allow FnMut
let api_key: &'static str = prefix.leak();

let auth = move |req: Request<()>| -> Result<Request<()>, Status> {
match req.metadata().get("authorization") {
Some(t) if t == api_key => Ok(req),
_ => Err(Status::unauthenticated("No valid auth token")),
}
};

Server::builder()
.add_service(health_service)
.add_service(reflection_service)
.add_service(grpc::InfoServer::with_interceptor(service.clone(), auth))
.add_service(grpc::TokenizeServer::with_interceptor(
service.clone(),
auth,
))
.add_service(grpc::EmbedServer::with_interceptor(service.clone(), auth))
.add_service(grpc::PredictServer::with_interceptor(service.clone(), auth))
.add_service(grpc::RerankServer::with_interceptor(service, auth))
.serve_with_shutdown(addr, shutdown::shutdown_signal())
} else {
Server::builder()
.add_service(health_service)
.add_service(reflection_service)
.add_service(grpc::InfoServer::new(service.clone()))
.add_service(grpc::TokenizeServer::new(service.clone()))
.add_service(grpc::EmbedServer::new(service.clone()))
.add_service(grpc::PredictServer::new(service.clone()))
.add_service(grpc::RerankServer::new(service))
.serve_with_shutdown(addr, shutdown::shutdown_signal())
};

tracing::info!("Starting gRPC server: {}", &addr);
Server::builder()
.add_service(health_service)
.add_service(reflection_service)
.add_service(grpc::InfoServer::new(service.clone()))
.add_service(grpc::TokenizeServer::new(service.clone()))
.add_service(grpc::EmbedServer::new(service.clone()))
.add_service(grpc::PredictServer::new(service.clone()))
.add_service(grpc::RerankServer::new(service))
.serve_with_shutdown(addr, shutdown::shutdown_signal())
.await?;
server.await?;

Ok(())
}
Expand Down
26 changes: 25 additions & 1 deletion router/src/http/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use futures::future::join_all;
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::net::SocketAddr;
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -1263,6 +1264,7 @@ pub async fn run(
addr: SocketAddr,
prom_builder: PrometheusBuilder,
payload_limit: usize,
api_key: Option<String>,
cors_allow_origin: Option<Vec<String>>,
) -> Result<(), anyhow::Error> {
// OpenAPI documentation
Expand Down Expand Up @@ -1434,13 +1436,35 @@ pub async fn run(
}
}

let app = app
app = app
.layer(Extension(infer))
.layer(Extension(info))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(cors_layer);

if let Some(api_key) = api_key {
let mut prefix = "Bearer ".to_string();
prefix.push_str(&api_key);

// Leak to allow FnMut
let api_key: &'static str = prefix.leak();

let auth = move |headers: HeaderMap,
request: axum::extract::Request,
next: axum::middleware::Next| async move {
match headers.get(AUTHORIZATION) {
Some(token) if token == api_key => {
let response = next.run(request).await;
Ok(response)
}
_ => Err(StatusCode::UNAUTHORIZED),
}
};

app = app.layer(axum::middleware::from_fn(auth));
}

// Run server
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();

Expand Down
10 changes: 7 additions & 3 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub async fn run(
uds_path: Option<String>,
huggingface_hub_cache: Option<String>,
payload_limit: usize,
api_key: Option<String>,
otlp_endpoint: Option<String>,
cors_allow_origin: Option<Vec<String>>,
) -> Result<()> {
Expand Down Expand Up @@ -275,6 +276,7 @@ pub async fn run(
addr,
prom_builder,
payload_limit,
api_key,
cors_allow_origin,
)
.await
Expand All @@ -285,10 +287,12 @@ pub async fn run(

#[cfg(feature = "grpc")]
{
// cors_allow_origin is not used for gRPC servers
// cors_allow_origin and payload_limit are not used for gRPC servers
let _ = cors_allow_origin;
let server =
tokio::spawn(async move { grpc::server::run(infer, info, addr, prom_builder).await });
let _ = payload_limit;
let server = tokio::spawn(async move {
grpc::server::run(infer, info, addr, prom_builder, api_key).await
});
tracing::info!("Ready");
server.await??;
}
Expand Down
7 changes: 7 additions & 0 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ struct Args {
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,

/// Set an api key for request authorization.
///
/// By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
#[clap(long, env)]
api_key: Option<String>,

/// Outputs the logs in JSON format (useful for telemetry)
#[clap(long, env)]
json_output: bool,
Expand Down Expand Up @@ -143,6 +149,7 @@ async fn main() -> Result<()> {
Some(args.uds_path),
args.huggingface_hub_cache,
args.payload_limit,
args.api_key,
args.otlp_endpoint,
args.cors_allow_origin,
)
Expand Down

0 comments on commit 5e60d06

Please sign in to comment.