Add per-endpoint rate limiter to proxy
kelvich committed Dec 13, 2023
1 parent 7c2c87a commit 8460654
Showing 8 changed files with 117 additions and 0 deletions.
8 changes: 8 additions & 0 deletions proxy/src/auth.rs
@@ -62,6 +62,9 @@ pub enum AuthErrorImpl {
Please add it to the allowed list in the Neon console."
)]
IpAddressNotAllowed,

#[error("Too many connections to this endpoint. Please try again later.")]
TooManyConnections,
}

#[derive(Debug, Error)]
@@ -80,6 +83,10 @@ impl AuthError {
pub fn ip_address_not_allowed() -> Self {
AuthErrorImpl::IpAddressNotAllowed.into()
}

pub fn too_many_connections() -> Self {
AuthErrorImpl::TooManyConnections.into()
}
}

impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
@@ -102,6 +109,7 @@ impl UserFacingError for AuthError {
MissingEndpointName => self.to_string(),
Io(_) => "Internal error".to_string(),
IpAddressNotAllowed => self.to_string(),
TooManyConnections => self.to_string(),
}
}
}
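
The new variant follows the thiserror-style pattern already used in this enum: the `#[error(...)]` attribute supplies the Display message, and the `UserFacingError` impl simply forwards `to_string()` for this case. A minimal standalone sketch of that pattern (illustrative toy names, not the proxy's real types):

    use thiserror::Error;

    #[derive(Debug, Error)]
    enum DemoAuthError {
        // The attribute string becomes the Display output for this variant.
        #[error("Too many connections to this endpoint. Please try again later.")]
        TooManyConnections,
    }

    fn main() {
        let err = DemoAuthError::TooManyConnections;
        // `to_string()`/Display yields the #[error] message, which is what gets
        // shown to the client when the rate limiter rejects a connection.
        println!("{err}");
    }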
4 changes: 4 additions & 0 deletions proxy/src/bin/proxy.rs
@@ -112,6 +112,9 @@ struct ProxyCliArgs {
/// Timeout for rate limiter. If it didn't manage to acquire a permit in this time, it will return an error.
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
rate_limiter_timeout: tokio::time::Duration,
/// Endpoint rate limiter max number of requests per second.
#[clap(long, default_value_t = 300)]
endpoint_rps_limit: u32,
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
#[clap(long, default_value_t = 100)]
initial_limit: usize,
@@ -317,6 +320,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
authentication_config,
require_client_ip: args.require_client_ip,
disable_ip_check_for_http: args.disable_ip_check_for_http,
endpoint_rps_limit: args.endpoint_rps_limit,
}));

Ok(config)
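For reference, the new clap argument surfaces as a `--endpoint-rps-limit` flag with a default of 300. A minimal sketch of how such a flag parses (a toy struct assuming clap with the derive feature, not the proxy's real `ProxyCliArgs`):

    use clap::Parser;

    #[derive(Parser, Debug)]
    struct DemoArgs {
        /// Endpoint rate limiter max number of requests per second.
        #[clap(long, default_value_t = 300)]
        endpoint_rps_limit: u32,
    }

    fn main() {
        // Equivalent to running: proxy --endpoint-rps-limit 500
        let args = DemoArgs::parse_from(["proxy", "--endpoint-rps-limit", "500"]);
        assert_eq!(args.endpoint_rps_limit, 500);

        // Omitting the flag falls back to the default of 300.
        let defaults = DemoArgs::parse_from(["proxy"]);
        assert_eq!(defaults.endpoint_rps_limit, 300);
    }

The parsed value is then threaded into `ProxyConfig` in `build_config`, as shown in the hunk above.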
1 change: 1 addition & 0 deletions proxy/src/config.rs
@@ -20,6 +20,7 @@ pub struct ProxyConfig {
pub authentication_config: AuthenticationConfig,
pub require_client_ip: bool,
pub disable_ip_check_for_http: bool,
pub endpoint_rps_limit: u32,
}

#[derive(Debug)]
21 changes: 21 additions & 0 deletions proxy/src/proxy.rs
@@ -9,6 +9,7 @@ use crate::{
console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
http::StatusCode,
protocol2::WithClientIp,
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
usage_metrics::{Ids, USAGE_METRICS},
};
@@ -307,6 +308,7 @@ pub async fn task_main(

let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancel_map = Arc::new(CancelMap::default());
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(config.endpoint_rps_limit));

while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -315,6 +317,8 @@

let session_id = uuid::Uuid::new_v4();
let cancel_map = Arc::clone(&cancel_map);
let endpoint_rate_limiter = endpoint_rate_limiter.clone();

connections.spawn(
async move {
info!("accepted postgres client connection");
@@ -340,6 +344,7 @@
socket,
ClientMode::Tcp,
peer_addr.ip(),
endpoint_rate_limiter,
)
.await
}
@@ -415,6 +420,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mode: ClientMode,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
info!(
protocol = mode.protocol_label(),
@@ -463,6 +469,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
&params,
session_id,
mode.allow_self_signed_compute(config),
endpoint_rate_limiter,
);
cancel_map
.with_session(|session| client.connect_to_db(session, mode, &config.authentication_config))
@@ -928,6 +935,8 @@ struct Client<'a, S> {
session_id: uuid::Uuid,
/// Allow self-signed certificates (for testing).
allow_self_signed_compute: bool,
/// Rate limiter for endpoints
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}

impl<'a, S> Client<'a, S> {
@@ -938,13 +947,15 @@ impl<'a, S> Client<'a, S> {
params: &'a StartupMessageParams,
session_id: uuid::Uuid,
allow_self_signed_compute: bool,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Self {
Self {
stream,
creds,
params,
session_id,
allow_self_signed_compute,
endpoint_rate_limiter,
}
}
}
@@ -966,8 +977,18 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
params,
session_id,
allow_self_signed_compute,
endpoint_rate_limiter,
} = self;

// check rate limit
if let Some(ep) = creds.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await;
}
}

let proto = mode.protocol_label();
let extra = console::ConsoleReqExtra {
session_id, // aka this connection's id
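To summarize the wiring in this file: `task_main` builds one `EndpointRateLimiter` behind an `Arc`, every accepted connection task gets a clone of that `Arc`, and `connect_to_db` consults it (keyed by the endpoint name from the startup credentials) before any authentication or control-plane work happens. A rough standalone sketch of the same sharing pattern (stub types and a made-up endpoint name; assumes tokio, not the proxy's real APIs):

    use std::sync::Arc;

    // Stand-in for EndpointRateLimiter from proxy/src/rate_limiter/limiter.rs.
    struct SharedLimiter;
    impl SharedLimiter {
        fn check(&self, _endpoint: &str) -> bool {
            true // the real limiter counts per-endpoint hits per second
        }
    }

    #[tokio::main]
    async fn main() {
        let limiter = Arc::new(SharedLimiter);
        let mut handles = Vec::new();
        for conn_id in 0..3 {
            // Clone the Arc before moving it into the per-connection task,
            // mirroring `endpoint_rate_limiter.clone()` in task_main above.
            let limiter = limiter.clone();
            handles.push(tokio::spawn(async move {
                if !limiter.check("ep-demo-123") {
                    eprintln!("connection {conn_id}: too many connections to this endpoint");
                    return;
                }
                // ... proceed with TLS, authentication, and the compute connection ...
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
    }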
1 change: 1 addition & 0 deletions proxy/src/rate_limiter.rs
@@ -3,4 +3,5 @@ mod limit_algorithm;
mod limiter;
pub use aimd::Aimd;
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
pub use limiter::EndpointRateLimiter;
pub use limiter::Limiter;
71 changes: 71 additions & 0 deletions proxy/src/rate_limiter/limiter.rs
@@ -6,6 +6,9 @@
time::Duration,
};

use dashmap::DashMap;
use parking_lot::Mutex;
use smol_str::SmolStr;
use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
use tokio::time::{timeout, Instant};
use tracing::info;
@@ -15,6 +18,74 @@ use super::{
RateLimiterConfig,
};

// Simple per-endpoint rate limiter.
//
// Check that the number of connections to the endpoint is below `max_rps` rps.
// Purposefully ignore the user name and database name, as clients can reconnect
// with different names and we would still end up sending some HTTP requests to
// the control plane.
//
// We also may save quite a lot of CPU (I think) by bailing out right after we
// see the SNI, before doing the TLS handshake. The user-side error message in
// that case does not look very nice (`SSL SYSCALL error: Undefined error: 0`),
// so for now I went with a more expensive way that yields user-friendlier
// error messages.
//
// TODO: add better bucketing here, e.g. not more than 300 requests per second,
// and not more than 1000 requests per 10 seconds, etc. Short bursts of reconnects
// are normal during redeployments, so we should not block them.
pub struct EndpointRateLimiter {
map: DashMap<SmolStr, Arc<Mutex<(chrono::NaiveTime, u32)>>>,
max_rps: u32,
access_count: AtomicUsize,
}

impl EndpointRateLimiter {
pub fn new(max_rps: u32) -> Self {
Self {
map: DashMap::new(),
max_rps,
access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
}
}

/// Check that the number of connections to the endpoint is below `max_rps` rps.
pub fn check(&self, endpoint: SmolStr) -> bool {
// do GC every 100k requests (worst case memory usage is about 10MB)
if self.access_count.fetch_add(1, Ordering::AcqRel) % 100_000 == 0 {
self.do_gc();
}

let now = chrono::Utc::now().naive_utc().time();
let entry = self
.map
.entry(endpoint)
.or_insert_with(|| Arc::new(Mutex::new((now, 0))));
let mut entry = entry.lock();
let (last_time, count) = *entry;

if now - last_time < chrono::Duration::seconds(1) {
if count >= self.max_rps {
return false;
}
*entry = (last_time, count + 1);
} else {
*entry = (now, 1);
}
true
}

/// Clean the map. Simple strategy: remove all entries. At worst, we'll
/// double the effective max_rps during the cleanup. But that way deletion
/// does not acquire a mutex on each entry access.
pub fn do_gc(&self) {
info!(
"cleaning up endpoint rate limiter, current size = {}",
self.map.len()
);
self.map.clear();
}
}

/// Limits the number of concurrent jobs.
///
/// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
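Read as a whole, `EndpointRateLimiter::check` implements a fixed one-second window per endpoint: the first hit creates a `(timestamp, count)` entry, later hits within the same second bump the counter until it reaches `max_rps` (after which `check` returns `false`), and a hit after the window expires resets it. A rough self-contained sketch of that fixed-window scheme using only the standard library (toy types; the real implementation above uses `DashMap`, `parking_lot::Mutex`, `chrono`, and `SmolStr` for concurrent access):

    use std::collections::HashMap;
    use std::time::{Duration, Instant};

    // Toy fixed-window limiter with the same shape as EndpointRateLimiter:
    // one (window_start, count) pair per endpoint, reset when the window expires.
    struct ToyLimiter {
        map: HashMap<String, (Instant, u32)>,
        max_rps: u32,
    }

    impl ToyLimiter {
        fn new(max_rps: u32) -> Self {
            Self { map: HashMap::new(), max_rps }
        }

        fn check(&mut self, endpoint: &str) -> bool {
            let now = Instant::now();
            let entry = self.map.entry(endpoint.to_string()).or_insert((now, 0));
            if now.duration_since(entry.0) < Duration::from_secs(1) {
                if entry.1 >= self.max_rps {
                    return false; // over the per-second budget for this endpoint
                }
                entry.1 += 1;
            } else {
                *entry = (now, 1); // start a new one-second window
            }
            true
        }
    }

    fn main() {
        let mut limiter = ToyLimiter::new(3);
        let results: Vec<bool> = (0..5).map(|_| limiter.check("ep-demo")).collect();
        // With max_rps = 3, the 4th and 5th calls within the same second are rejected.
        assert_eq!(results, vec![true, true, true, false, false]);
    }

The real limiter also clears the whole map every 100,000 checks (the `do_gc` call) rather than expiring entries individually, trading a briefly doubled effective limit during cleanup for a deletion path that never takes per-entry mutexes; the TODO above notes that multi-window bucketing (e.g. a per-second cap plus a per-10-seconds cap) would handle redeploy reconnect bursts better.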
7 changes: 7 additions & 0 deletions proxy/src/serverless.rs
@@ -14,6 +14,7 @@ use tokio_util::task::TaskTracker;

use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER};
use crate::rate_limiter::EndpointRateLimiter;
use crate::{cancellation::CancelMap, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
@@ -43,6 +44,7 @@ pub async fn task_main(
}

let conn_pool = conn_pool::GlobalConnPool::new(config);
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(config.endpoint_rps_limit));

// shutdown the connection pool
tokio::spawn({
@@ -91,6 +93,7 @@
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();

async move {
let peer_addr = match client_addr {
@@ -103,6 +106,7 @@
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();

async move {
let cancel_map = Arc::new(CancelMap::default());
@@ -117,6 +121,7 @@
session_id,
sni_name,
peer_addr.ip(),
endpoint_rate_limiter,
)
.instrument(info_span!(
"serverless",
@@ -190,6 +195,7 @@
session_id: uuid::Uuid,
sni_hostname: Option<String>,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Body>, ApiError> {
let host = request
.headers()
@@ -214,6 +220,7 @@
session_id,
host,
peer_addr,
endpoint_rate_limiter,
)
.await
{
4 changes: 4 additions & 0 deletions proxy/src/serverless/websocket.rs
@@ -3,6 +3,7 @@ use crate::{
config::ProxyConfig,
error::io_error,
proxy::{handle_client, ClientMode},
rate_limiter::EndpointRateLimiter,
};
use bytes::{Buf, Bytes};
use futures::{Sink, Stream};
@@ -13,6 +14,7 @@ use pin_project_lite::pin_project;
use std::{
net::IpAddr,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
@@ -134,6 +136,7 @@ pub async fn serve_websocket(
session_id: uuid::Uuid,
hostname: Option<String>,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
handle_client(
@@ -143,6 +146,7 @@
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
peer_addr,
endpoint_rate_limiter,
)
.await?;
Ok(())

1 comment on commit 8460654

github-actions bot commented on 8460654, Dec 13, 2023


2252 tests run: 2161 passed, 0 failed, 91 skipped (full report)


Flaky tests (5)

Postgres 16

Postgres 14

  • test_cannot_branch_from_non_uploaded_branch: debug
  • test_ondemand_download_large_rel: debug
  • test_ondemand_download_timetravel: debug
  • test_multi_attach: release

Code coverage (full report)

  • functions: 55.2% (9432 of 17093 functions)
  • lines: 82.3% (54630 of 66404 lines)

The comment gets automatically updated with the latest test results
8460654 at 2023-12-13T06:31:32.223Z :recycle:
