From ddea2764643d6c027cadd8e652d589a1acc3ff0a Mon Sep 17 00:00:00 2001 From: Haardik H Date: Tue, 10 Jun 2025 13:52:18 -0400 Subject: [PATCH 1/2] configurable rate limits per authed app --- crates/websocket-proxy/src/auth.rs | 93 +- crates/websocket-proxy/src/main.rs | 15 +- crates/websocket-proxy/src/rate_limit.rs | 948 +++++++++++++++++--- crates/websocket-proxy/src/server.rs | 10 +- crates/websocket-proxy/tests/integration.rs | 26 +- 5 files changed, 928 insertions(+), 164 deletions(-) diff --git a/crates/websocket-proxy/src/auth.rs b/crates/websocket-proxy/src/auth.rs index 74d3172c..4750193a 100644 --- a/crates/websocket-proxy/src/auth.rs +++ b/crates/websocket-proxy/src/auth.rs @@ -1,12 +1,13 @@ use crate::auth::AuthenticationParseError::{ DuplicateAPIKeyArgument, DuplicateApplicationArgument, MissingAPIKeyArgument, - MissingApplicationArgument, NoData, TooManyComponents, + MissingApplicationArgument, MissingRateLimitArgument, NoData, TooManyComponents, }; use std::collections::{HashMap, HashSet}; #[derive(Clone, Debug)] pub struct Authentication { key_to_application: HashMap<String, String>, + app_to_rate_limit: HashMap<String, usize>, } #[derive(Debug, PartialEq)] @@ -14,6 +15,7 @@ pub enum AuthenticationParseError { NoData(), MissingApplicationArgument(String), MissingAPIKeyArgument(String), + MissingRateLimitArgument(String), TooManyComponents(String), DuplicateApplicationArgument(String), DuplicateAPIKeyArgument(String), @@ -23,9 +25,10 @@ impl std::fmt::Display for AuthenticationParseError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { NoData() => write!(f, "No API Keys Provided"), - MissingApplicationArgument(arg) => write!(f, "Missing application argument: [{arg}]"), - MissingAPIKeyArgument(app) => write!(f, "Missing API Key argument: [{app}]"), - TooManyComponents(app) => write!(f, "Too many components: [{app}]"), + MissingApplicationArgument(arg) => write!(f, "Missing application argument: [{}]", arg), + MissingAPIKeyArgument(app) => write!(f, "Missing API Key argument: [{}]", app), + MissingRateLimitArgument(app) => write!(f, "Missing rate limit argument: [{}]", app), + TooManyComponents(app) => write!(f, "Too many components: [{}]", app), DuplicateApplicationArgument(app) => { write!(f, "Duplicate application argument: [{app}]") } @@ -42,6 +45,7 @@ impl TryFrom<Vec<String>> for Authentication { fn try_from(args: Vec<String>) -> Result<Self, Self::Error> { let mut applications = HashSet::new(); let mut key_to_application: HashMap<String, String> = HashMap::new(); + let mut app_to_rate_limit: HashMap<String, usize> = HashMap::new(); if args.is_empty() { return Err(NoData()); } @@ -61,6 +65,13 @@ impl TryFrom<Vec<String>> for Authentication { return Err(MissingAPIKeyArgument(app.to_string())); } + let rate_limit = parts + .next() + .ok_or(MissingRateLimitArgument(app.to_string()))?; + if rate_limit.is_empty() { + return Err(MissingRateLimitArgument(app.to_string())); + } + if parts.count() > 0 { return Err(TooManyComponents(app.to_string())); } @@ -75,9 +86,13 @@ impl TryFrom<Vec<String>> for Authentication { applications.insert(app.to_string()); key_to_application.insert(key.to_string(), app.to_string());
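+ // The limit was validated as non-empty above; a non-numeric value panics here at startup via `.parse().unwrap()`.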
+ app_to_rate_limit.insert(app.to_string(), rate_limit.parse().unwrap()); } - Ok(Self { key_to_application }) + Ok(Self { + key_to_application, + app_to_rate_limit, + }) } } @@ -85,19 +100,28 @@ impl Authentication { pub fn none() -> Self { Self { key_to_application: HashMap::new(), + app_to_rate_limit: HashMap::new(), } } #[allow(dead_code)] - pub fn new(api_keys: HashMap<String, String>) -> Self { + pub fn new( + api_keys: HashMap<String, String>, + app_to_rate_limit: HashMap<String, usize>, + ) -> Self { Self { key_to_application: api_keys, + app_to_rate_limit, } } pub fn get_application_for_key(&self, api_key: &String) -> Option<&String> { self.key_to_application.get(api_key) } + + pub fn get_rate_limits(&self) -> HashMap<String, usize> { + self.app_to_rate_limit.clone() + } } #[cfg(test)] @@ -107,9 +131,9 @@ mod tests { #[test] fn test_parsing() { let auth = Authentication::try_from(vec![ - "app1:key1".to_string(), - "app2:key2".to_string(), - "app3:key3".to_string(), + "app1:key1:10".to_string(), + "app2:key2:10".to_string(), + "app3:key3:10".to_string(), ]) .unwrap(); @@ -117,51 +141,71 @@ mod tests { assert_eq!(auth.key_to_application["key1"], "app1"); assert_eq!(auth.key_to_application["key2"], "app2"); assert_eq!(auth.key_to_application["key3"], "app3"); + assert_eq!(auth.app_to_rate_limit.len(), 3); + assert_eq!(auth.app_to_rate_limit["app1"], 10); + assert_eq!(auth.app_to_rate_limit["app2"], 10); + assert_eq!(auth.app_to_rate_limit["app3"], 10); let auth = Authentication::try_from(vec![ - "app1:key1".to_string(), + "app1:key1:10".to_string(), "".to_string(), - "app3:key3".to_string(), + "app3:key3:10".to_string(), ]); assert!(auth.is_err()); assert_eq!(auth.unwrap_err(), MissingApplicationArgument("".into())); let auth = Authentication::try_from(vec![ - "app1:key1".to_string(), + "app1:key1:10".to_string(), "app2".to_string(), - "app3:key3".to_string(), + "app3:key3:10".to_string(), ]); assert!(auth.is_err()); assert_eq!(auth.unwrap_err(), MissingAPIKeyArgument("app2".into())); let auth = Authentication::try_from(vec![ - "app1:key1".to_string(), - ":".to_string(), + "app1:key1:10".to_string(), + "app2:key2:10".to_string(), "app3:key3".to_string(), ]); assert!(auth.is_err()); + assert_eq!(auth.unwrap_err(), MissingRateLimitArgument("app3".into())); + + let auth = Authentication::try_from(vec![ + "app1:key1:10".to_string(), + ":".to_string(), + "app3:key3:10".to_string(), + ]); + assert!(auth.is_err()); assert_eq!(auth.unwrap_err(), MissingApplicationArgument(":".into())); let auth = Authentication::try_from(vec![ - "app1:key1".to_string(), + "app1:key1:10".to_string(), "app2:".to_string(), - "app3:key3".to_string(), + "app3:key3:10".to_string(), ]); assert!(auth.is_err()); assert_eq!(auth.unwrap_err(), MissingAPIKeyArgument("app2".into())); let auth = Authentication::try_from(vec![ - "app1:key1".to_string(), - "app2:key2:unexpected2".to_string(), - "app3:key3".to_string(), + "app1:key1:10".to_string(), + "app2:key2:10".to_string(), + "app3:key3:".to_string(), + ]); + assert!(auth.is_err()); + assert_eq!(auth.unwrap_err(), MissingRateLimitArgument("app3".into())); + + let auth = Authentication::try_from(vec![ + "app1:key1:10".to_string(), + "app2:key2:10:unexpected2".to_string(), + "app3:key3:10".to_string(), ]); assert!(auth.is_err()); assert_eq!(auth.unwrap_err(), TooManyComponents("app2".into())); let auth = Authentication::try_from(vec![ - "app1:key1".to_string(), - "app1:key3".to_string(), - "app2:key2".to_string(), + "app1:key1:10".to_string(), + "app1:key3:10".to_string(), + "app2:key2:10".to_string(), ]); assert!(auth.is_err()); assert_eq!( @@ -169,7 +213,8 @@ mod tests { DuplicateApplicationArgument("app1".into()) ); - let auth = Authentication::try_from(vec!["app1:key1".to_string(), "app2:key1".to_string()]); + let auth = + Authentication::try_from(vec!["app1:key1:10".to_string(), "app2:key1:10".to_string()]); assert!(auth.is_err()); assert_eq!(auth.unwrap_err(), DuplicateAPIKeyArgument("app2".into())); } diff --git a/crates/websocket-proxy/src/main.rs b/crates/websocket-proxy/src/main.rs index 6b117e22..789e7f45 100644 --- a/crates/websocket-proxy/src/main.rs +++ b/crates/websocket-proxy/src/main.rs @@ -15,6 +15,7 @@ use metrics_exporter_prometheus::PrometheusBuilder; use rate_limit::{InMemoryRateLimit, RateLimit, RedisRateLimit}; use registry::Registry; use server::Server; +use std::collections::HashMap; use std::io::Write; use std::net::SocketAddr; use std::sync::Arc; @@ -66,9 +67,10 @@ struct Args { long, env, default_value = "10", - help = "Maximum number of concurrently connected clients per IP" + help = "Maximum number of concurrently connected clients per IP. A value of 0 means no limit." )] per_ip_connection_limit: usize, + #[arg( long, env, @@ -96,7 +98,7 @@ struct Args { #[arg(long, env, default_value = "true")] metrics: bool, - /// API Keys, if not provided will be an unauthenticated endpoint, should be in the format <app>:<key>,<app>:<key>,.. + /// API Keys, if not provided will be an unauthenticated endpoint, should be in the format <app>:<key>:<rate_limit>,<app>:<key>:<rate_limit>,.. #[arg(long, env, value_delimiter = ',', help = "API keys to allow")] api_keys: Vec<String>, @@ -337,6 +339,12 @@ async fn main() { args.client_pong_timeout_ms, );
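+ // Per-app connection limits come from the parsed API key config; unauthenticated deployments get an empty map.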
+ let app_rate_limits = if let Some(auth) = &authentication { + auth.get_rate_limits() + } else { + HashMap::new() + }; + let rate_limiter = match &args.redis_url { Some(redis_url) => { info!(message = "Using Redis rate limiter", redis_url = redis_url); @@ -344,6 +352,7 @@ async fn main() { redis_url, args.instance_connection_limit, args.per_ip_connection_limit, + app_rate_limits.clone(), &args.redis_key_prefix, ) { Ok(limiter) => { @@ -359,6 +368,7 @@ async fn main() { Arc::new(InMemoryRateLimit::new( args.instance_connection_limit, args.per_ip_connection_limit, + app_rate_limits.clone(), )) as Arc<dyn RateLimit> } } @@ -368,6 +378,7 @@ async fn main() { Arc::new(InMemoryRateLimit::new( args.instance_connection_limit, args.per_ip_connection_limit, + app_rate_limits, )) as Arc<dyn RateLimit> } }; diff --git a/crates/websocket-proxy/src/rate_limit.rs b/crates/websocket-proxy/src/rate_limit.rs index 34148def..1821ae4d 100644 --- a/crates/websocket-proxy/src/rate_limit.rs +++ b/crates/websocket-proxy/src/rate_limit.rs @@ -20,38 +20,51 @@ pub enum RateLimitError { #[clippy::has_significant_drop] pub struct Ticket { addr: IpAddr, + app: Option<String>, _permit: OwnedSemaphorePermit, rate_limiter: Arc<dyn RateLimit>, } impl Drop for Ticket { fn drop(&mut self) { - self.rate_limiter.release(self.addr) + self.rate_limiter.release(self.addr, self.app.clone()) } } pub trait RateLimit: Send + Sync { - fn try_acquire(self: Arc<Self>, addr: IpAddr) -> Result<Ticket, RateLimitError>; + fn try_acquire( + self: Arc<Self>, + addr: IpAddr, + app: Option<String>, + ) -> Result<Ticket, RateLimitError>; - fn release(&self, ticket: IpAddr); + fn release(&self, addr: IpAddr, app: Option<String>); } struct Inner { - active_connections: HashMap<IpAddr, usize>, + active_connections_per_ip: HashMap<IpAddr, usize>,
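+ // Live connection counts keyed by authenticated application name.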
+ active_connections_per_app: HashMap<String, usize>, semaphore: Arc<Semaphore>, } pub struct InMemoryRateLimit { per_ip_limit: usize, + per_app_limit: HashMap<String, usize>, inner: Mutex<Inner>, } impl InMemoryRateLimit { - pub fn new(instance_limit: usize, per_ip_limit: usize) -> Self { + pub fn new( + instance_limit: usize, + per_ip_limit: usize, + per_app_limit: HashMap<String, usize>, + ) -> Self { Self { per_ip_limit, + per_app_limit, inner: Mutex::new(Inner { - active_connections: HashMap::new(), + active_connections_per_ip: HashMap::new(), + active_connections_per_app: HashMap::new(), semaphore: Arc::new(Semaphore::new(instance_limit)), }), } @@ -59,7 +72,11 @@ impl InMemoryRateLimit { } impl RateLimit for InMemoryRateLimit { - fn try_acquire(self: Arc<Self>, addr: IpAddr) -> Result<Ticket, RateLimitError> { + fn try_acquire( + self: Arc<Self>, + addr: IpAddr, + app: Option<String>, + ) -> Result<Ticket, RateLimitError> { let mut inner = self.inner.lock().unwrap(); let permit = @@ -71,54 +88,102 @@ impl RateLimit for InMemoryRateLimit { reason: "Global limit".to_owned(), })?; - let current_count = match inner.active_connections.get(&addr) { - Some(count) => *count, - None => 0, - }; + if self.per_ip_limit > 0 { + let current_count = match inner.active_connections_per_ip.get(&addr) { + Some(count) => *count, + None => 0, + }; - if current_count + 1 > self.per_ip_limit { - debug!( - message = "Rate limit exceeded, trying to acquire", - client = addr.to_string() - ); - return Err(RateLimitError::Limit { - reason: String::from("IP limit exceeded"), - }); + if current_count + 1 > self.per_ip_limit { + debug!( + message = "Rate limit exceeded, trying to acquire", + client = addr.to_string() + ); + return Err(RateLimitError::Limit { + reason: String::from("IP limit exceeded"), + }); + } + + let new_count = current_count + 1; + + inner.active_connections_per_ip.insert(addr, new_count); } - let new_count = current_count + 1; + if let Some(app) = app.clone() { + let current_count = match inner.active_connections_per_app.get(&app) { + Some(count) => *count, + None => 0, + }; - inner.active_connections.insert(addr, new_count);
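+ // Apps with no configured limit default to 0 and are always rejected.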
+ if current_count + 1 > *self.per_app_limit.get(&app).unwrap_or(&0) { + debug!( + message = "Rate limit exceeded, trying to acquire", + client = addr.to_string() + ); + return Err(RateLimitError::Limit { + reason: String::from("App limit exceeded"), + }); + } + + let new_count = current_count + 1; + + inner.active_connections_per_app.insert(app, new_count); + } Ok(Ticket { addr, + app, _permit: permit, rate_limiter: self.clone(), }) } - fn release(&self, addr: IpAddr) { + fn release(&self, addr: IpAddr, app: Option<String>) { let mut inner = self.inner.lock().unwrap(); - let current_count = match inner.active_connections.get(&addr) { - Some(count) => *count, - None => 0, - }; + if self.per_ip_limit > 0 { + let current_count = match inner.active_connections_per_ip.get(&addr) { + Some(count) => *count, + None => 0, + }; + let new_count = if current_count == 0 { + warn!( + message = "ip counting is not accurate -- unexpected underflow", + client = addr.to_string() + ); + 0 + } else { + current_count - 1 + }; + + if new_count == 0 { + inner.active_connections_per_ip.remove(&addr); + } else { + inner.active_connections_per_ip.insert(addr, new_count); + } } - let new_count = if current_count == 0 { - warn!( - message = "ip counting is not accurate -- unexpected underflow", - client = addr.to_string() - ); - 0 - } else { - current_count - 1 - }; + if let Some(app) = app { + let current_count = match inner.active_connections_per_app.get(&app) { + Some(count) => *count, + None => 0, + }; - if new_count == 0 { - inner.active_connections.remove(&addr); - } else { - inner.active_connections.insert(addr, new_count); + let new_count = if current_count == 0 { + warn!( + message = "app counting is not accurate -- unexpected underflow", + client = addr.to_string() + ); + 0 + } else { + current_count - 1 + }; + + if new_count == 0 { + inner.active_connections_per_app.remove(&app); + } else { + inner.active_connections_per_app.insert(app, new_count); } } } } @@ -127,6 +192,7 @@ pub struct RedisRateLimit { redis_client: Client, instance_limit: usize, per_ip_limit: usize, + per_app_limit: HashMap<String, usize>, semaphore: Arc<Semaphore>, key_prefix: String, instance_id: String, @@ -140,6 +206,7 @@ impl RedisRateLimit { redis_url: &str, instance_limit: usize, per_ip_limit: usize, + per_app_limit: HashMap<String, usize>, key_prefix: &str, ) -> Result<Self, RedisError> { let client = Client::open(redis_url)?; @@ -152,6 +219,7 @@ impl RedisRateLimit { redis_client: client, instance_limit, per_ip_limit, + per_app_limit, semaphore: Arc::new(Semaphore::new(instance_limit)), key_prefix: key_prefix.to_string(), instance_id, @@ -253,20 +321,50 @@ impl RedisRateLimit { let ip_instance_pattern = format!("{}:ip:*:instance:*:connections", self.key_prefix); let ip_instance_keys: Vec<String> = conn.keys(ip_instance_pattern)?; - let mut instance_ids_with_connections = std::collections::HashSet::new(); + let mut instance_ids_with_ip_connections = std::collections::HashSet::new(); for key in &ip_instance_keys { if let Some(instance_id) = key.split(':').nth(4) { - instance_ids_with_connections.insert(instance_id.to_string()); + instance_ids_with_ip_connections.insert(instance_id.to_string()); + } + } + + let app_instance_pattern = format!("{}:app:*:instance:*:connections", self.key_prefix); + let app_instance_keys: Vec<String> = conn.keys(app_instance_pattern)?; + + let mut instance_ids_with_app_connections = std::collections::HashSet::new(); + for key in &app_instance_keys { + if let Some(instance_id) = key.split(':').nth(4) { + instance_ids_with_app_connections.insert(instance_id.to_string()); } } debug!( message = "Checking for stale instances", - instances_with_connections = instance_ids_with_connections.len(), + instances_with_ip_connections = instance_ids_with_ip_connections.len(), + instances_with_app_connections = instance_ids_with_app_connections.len(), current_instance = self.instance_id ); - for instance_id in instance_ids_with_connections { + for instance_id in instance_ids_with_ip_connections { if instance_id == self.instance_id { debug!( message = "Skipping current instance", instance_id = instance_id ); continue; } + + if !active_instance_ids.contains(&instance_id) { + debug!( + message = "Found stale instance", + instance_id = instance_id, + reason = "Heartbeat key not found" + ); + self.cleanup_instance(&mut conn, &instance_id)?; + } + } + + for instance_id in instance_ids_with_app_connections { if instance_id == self.instance_id { debug!( message = "Skipping current instance", @@ -301,10 +399,17 @@ impl RedisRateLimit { ); let ip_instance_keys: Vec<String> = conn.keys(ip_instance_pattern)?; + let app_instance_pattern = format!( + "{}:app:*:instance:{}:connections", + self.key_prefix, instance_id + ); + let app_instance_keys: Vec<String> = conn.keys(app_instance_pattern)?; + debug!( message = "Cleaning up instance", instance_id = instance_id, - ip_key_count = ip_instance_keys.len() + ip_key_count = ip_instance_keys.len(), + app_key_count = app_instance_keys.len() ); for key in ip_instance_keys { @@ -312,6 +417,11 @@ impl RedisRateLimit { debug!(message = "Deleted IP instance key", key = key); } + for key in app_instance_keys { + conn.del::<_, ()>(&key)?; + debug!(message = "Deleted app instance key", key = key); + } + Ok(()) } @@ -322,6 +432,13 @@ impl RedisRateLimit { ) }
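+ // Mirrors ip_instance_key: one Redis counter per (app, instance) pair.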
+ fn app_instance_key(&self, app: &str) -> String { + format!( + "{}:app:{}:instance:{}:connections", + self.key_prefix, app, self.instance_id + ) + } + fn instance_heartbeat_key(&self) -> String { format!( "{}:instance:{}:heartbeat", @@ -331,7 +448,11 @@ impl RedisRateLimit { } impl RateLimit for RedisRateLimit { - fn try_acquire(self: Arc<Self>, addr: IpAddr) -> Result<Ticket, RateLimitError> { + fn try_acquire( + self: Arc<Self>, + addr: IpAddr, + app: Option<String>, + ) -> Result<Ticket, RateLimitError> { self.clone().start_background_tasks(); let permit = match self.semaphore.clone().try_acquire_owned() { @@ -356,44 +477,91 @@ impl RateLimit for RedisRateLimit { } }; - let ip_keys_pattern = format!("{}:ip:{}:instance:*:connections", self.key_prefix, addr); - let ip_keys: Vec<String> = match conn.keys(ip_keys_pattern) { - Ok(keys) => keys, - Err(e) => { - error!( - message = "Failed to get IP instance keys from Redis", - error = e.to_string() - ); + let mut ip_instance_connections: usize = 0; + let mut app_instance_connections: usize = 0; + let mut total_ip_connections: usize = 0; + let mut total_app_connections: usize = 0; + + if self.per_ip_limit > 0 { + let ip_keys_pattern = format!("{}:ip:{}:instance:*:connections", self.key_prefix, addr); + let ip_keys: Vec<String> = match conn.keys(ip_keys_pattern) { + Ok(keys) => keys, + Err(e) => { + error!( + message = "Failed to get IP instance keys from Redis", + error = e.to_string() + ); + return Err(RateLimitError::Limit { + reason: "Redis operation failed".to_string(), + }); + } + }; + + for key in &ip_keys { + let count: usize = conn.get(key).unwrap_or(0); + total_ip_connections += count; + } + + if total_ip_connections >= self.per_ip_limit { return Err(RateLimitError::Limit { - reason: "Redis operation failed".to_string(), + reason: format!("Per-IP connection limit reached for {}", addr), }); } - }; - let mut total_ip_connections: usize = 0; - for key in &ip_keys { - let count: usize = conn.get(key).unwrap_or(0); - total_ip_connections += count; + ip_instance_connections = match conn.incr(self.ip_instance_key(&addr), 1) { + Ok(count) => count, + Err(e) => { + error!( + message = "Failed to increment per-instance IP counter in Redis", + error = e.to_string() + ); + return Err(RateLimitError::Limit { + reason: "Redis operation failed".to_string(), + }); + } + }; } - if total_ip_connections >= self.per_ip_limit { - return Err(RateLimitError::Limit { - reason: format!("Per-IP connection limit reached for {addr}"), - }); - } + if let Some(app) = app.clone() { + let app_keys_pattern = + format!("{}:app:{}:instance:*:connections", self.key_prefix, app); + let app_keys: Vec<String> = match conn.keys(app_keys_pattern) { + Ok(keys) => keys, + Err(e) => { + error!( + message = "Failed to get app instance keys from Redis", + error = e.to_string() + ); + return Err(RateLimitError::Limit { + reason: "Redis operation failed".to_string(), + }); + } + }; - let ip_instance_connections: usize = match conn.incr(self.ip_instance_key(&addr), 1) { - Ok(count) => count, - Err(e) => { - error!( - message = "Failed to increment per-instance IP counter in Redis", - error = e.to_string() - ); + for key in &app_keys { + let count: usize = conn.get(key).unwrap_or(0); + total_app_connections += count; + } +
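+ // total_app_connections sums this app's live connections across every proxy instance.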
+ if total_app_connections >= *self.per_app_limit.get(&app).unwrap_or(&0) { return Err(RateLimitError::Limit { - reason: "Redis operation failed".to_string(), + reason: format!("Per-app connection limit reached for {}", app), }); } - }; + + app_instance_connections = match conn.incr(self.app_instance_key(&app), 1) { + Ok(count) => count, + Err(e) => { + error!( + message = "Failed to increment per-instance app counter in Redis", + error = e.to_string() + ); + return Err(RateLimitError::Limit { + reason: "Redis operation failed".to_string(), + }); + } + }; + } let total_instance_connections = self.instance_limit - self.semaphore.available_permits(); @@ -401,37 +569,61 @@ impl RateLimit for RedisRateLimit { message = "Connection established", ip = addr.to_string(), ip_instance_connections = ip_instance_connections, - total_ip_connections = total_ip_connections + 1, + total_ip_connections = total_ip_connections, + app_instance_connections = app_instance_connections, + total_app_connections = total_app_connections, total_instance_connections = total_instance_connections, instance_id = self.instance_id ); Ok(Ticket { addr, + app, _permit: permit, rate_limiter: self, }) } - fn release(&self, addr: IpAddr) { + fn release(&self, addr: IpAddr, app: Option<String>) { match self.redis_client.get_connection() { Ok(mut conn) => { - let ip_instance_connections: Result<usize, RedisError> = - conn.decr(self.ip_instance_key(&addr), 1); - - if let Err(ref e) = ip_instance_connections { - error!( - message = "Failed to decrement per-instance IP counter in Redis", - error = e.to_string() + if self.per_ip_limit > 0 { + let ip_instance_connections: Result<usize, RedisError> = + conn.decr(self.ip_instance_key(&addr), 1); + + if let Err(ref e) = ip_instance_connections { + error!( + message = "Failed to decrement per-instance IP counter in Redis", + error = e.to_string() + ); + } + + debug!( + message = "Connection released", + ip = addr.to_string(), + ip_instance_connections = ip_instance_connections.unwrap_or(0), + instance_id = self.instance_id ); } - debug!( - message = "Connection released", - ip = addr.to_string(), - ip_instance_connections = ip_instance_connections.unwrap_or(0), - instance_id = self.instance_id - ); + if let Some(app) = app.clone() { + let app_instance_connections: Result<usize, RedisError> = + conn.decr(self.app_instance_key(&app), 1); + + if let Err(ref e) = app_instance_connections { + error!( + message = "Failed to decrement per-instance app counter in Redis", + error = e.to_string() + ); + } + + debug!( + message = "Connection released", + app = app, + app_instance_connections = app_instance_connections.unwrap_or(0), + instance_id = self.instance_id + ); + } } Err(e) => { error!( @@ -455,10 +647,14 @@ mod tests { const PER_IP_LIMIT: usize = 2; #[tokio::test] - async fn test_tickets_are_released() { + async fn test_ip_tickets_are_released() { let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); - let rate_limiter = Arc::new(InMemoryRateLimit::new(GLOBAL_LIMIT, PER_IP_LIMIT)); + let rate_limiter = Arc::new(InMemoryRateLimit::new( + GLOBAL_LIMIT, + PER_IP_LIMIT, + HashMap::new(), + )); assert_eq!( rate_limiter .inner .lock() .unwrap() .semaphore .available_permits(), GLOBAL_LIMIT ); assert_eq!( - rate_limiter.inner.lock().unwrap().active_connections.len(), + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_ip + .len(), 0 ); - let c1 = rate_limiter.clone().try_acquire(user_1).unwrap(); + let c1 = rate_limiter.clone().try_acquire(user_1, None).unwrap(); assert_eq!( rate_limiter .inner .lock() .unwrap() .semaphore .available_permits(), GLOBAL_LIMIT - 1 ); assert_eq!( - rate_limiter.inner.lock().unwrap().active_connections.len(), + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_ip + .len(), 1 ); assert_eq!( - rate_limiter.inner.lock().unwrap().active_connections[&user_1], + rate_limiter.inner.lock().unwrap().active_connections_per_ip[&user_1], 1 ); @@ -506,7 +712,12 @@ mod tests { GLOBAL_LIMIT ); assert_eq!( - rate_limiter.inner.lock().unwrap().active_connections.len(), + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_ip + .len(), 0 ); } @@ -516,13 +727,17 @@ mod tests { let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); let user_2 = IpAddr::from_str("128.0.0.1").unwrap(); - let rate_limiter = Arc::new(InMemoryRateLimit::new(GLOBAL_LIMIT, PER_IP_LIMIT)); + let rate_limiter = Arc::new(InMemoryRateLimit::new( + GLOBAL_LIMIT, + PER_IP_LIMIT, + HashMap::new(), + )); - let
_c1 = rate_limiter.clone().try_acquire(user_1).unwrap(); + let _c1 = rate_limiter.clone().try_acquire(user_1, None).unwrap(); - let _c2 = rate_limiter.clone().try_acquire(user_2).unwrap(); + let _c2 = rate_limiter.clone().try_acquire(user_2, None).unwrap(); - let _c3 = rate_limiter.clone().try_acquire(user_1).unwrap(); + let _c3 = rate_limiter.clone().try_acquire(user_1, None).unwrap(); assert_eq!( rate_limiter @@ -534,7 +749,7 @@ mod tests { 0 ); - let c4 = rate_limiter.clone().try_acquire(user_2); + let c4 = rate_limiter.clone().try_acquire(user_2, None); assert!(c4.is_err()); assert_eq!( c4.err().unwrap().to_string(), @@ -543,7 +758,7 @@ mod tests { drop(_c3); - let c4 = rate_limiter.clone().try_acquire(user_2); + let c4 = rate_limiter.clone().try_acquire(user_2, None); assert!(c4.is_ok()); } @@ -552,24 +767,28 @@ mod tests { let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); - let rate_limiter = Arc::new(InMemoryRateLimit::new(GLOBAL_LIMIT, PER_IP_LIMIT)); + let rate_limiter = Arc::new(InMemoryRateLimit::new( + GLOBAL_LIMIT, + PER_IP_LIMIT, + HashMap::new(), + )); - let _c1 = rate_limiter.clone().try_acquire(user_1).unwrap(); - let _c2 = rate_limiter.clone().try_acquire(user_1).unwrap(); + let _c1 = rate_limiter.clone().try_acquire(user_1, None).unwrap(); + let _c2 = rate_limiter.clone().try_acquire(user_1, None).unwrap(); assert_eq!( - rate_limiter.inner.lock().unwrap().active_connections[&user_1], + rate_limiter.inner.lock().unwrap().active_connections_per_ip[&user_1], 2 ); - let c3 = rate_limiter.clone().try_acquire(user_1); + let c3 = rate_limiter.clone().try_acquire(user_1, None); assert!(c3.is_err()); assert_eq!( c3.err().unwrap().to_string(), "Rate Limit Reached: IP limit exceeded" ); - let c4 = rate_limiter.clone().try_acquire(user_2); + let c4 = rate_limiter.clone().try_acquire(user_2, None); assert!(c4.is_ok()); } @@ -579,13 +798,13 @@ mod tests { let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); let user_3 = IpAddr::from_str("127.0.0.3").unwrap(); - let rate_limiter = Arc::new(InMemoryRateLimit::new(4, 3)); + let rate_limiter = Arc::new(InMemoryRateLimit::new(4, 3, HashMap::new())); - let ticket_1_1 = rate_limiter.clone().try_acquire(user_1).unwrap(); - let ticket_1_2 = rate_limiter.clone().try_acquire(user_1).unwrap(); + let ticket_1_1 = rate_limiter.clone().try_acquire(user_1, None).unwrap(); + let ticket_1_2 = rate_limiter.clone().try_acquire(user_1, None).unwrap(); - let ticket_2_1 = rate_limiter.clone().try_acquire(user_2).unwrap(); - let ticket_2_2 = rate_limiter.clone().try_acquire(user_2).unwrap(); + let ticket_2_1 = rate_limiter.clone().try_acquire(user_2, None).unwrap(); + let ticket_2_2 = rate_limiter.clone().try_acquire(user_2, None).unwrap(); assert_eq!( rate_limiter @@ -598,7 +817,7 @@ mod tests { ); // Try user_3 - should fail due to global limit - let result = rate_limiter.clone().try_acquire(user_3); + let result = rate_limiter.clone().try_acquire(user_3, None); assert!(result.is_err()); assert_eq!( result.err().unwrap().to_string(), @@ -607,7 +826,7 @@ mod tests { drop(ticket_1_1); - let ticket_3_1 = rate_limiter.clone().try_acquire(user_3).unwrap(); + let ticket_3_1 = rate_limiter.clone().try_acquire(user_3, None).unwrap(); drop(ticket_1_2); drop(ticket_2_1); @@ -624,7 +843,12 @@ mod tests { 4 ); assert_eq!( - rate_limiter.inner.lock().unwrap().active_connections.len(), + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_ip + .len(), 0 ); } @@ -634,24 +858,24 @@ mod tests { let 
user_1 = IpAddr::from_str("127.0.0.1").unwrap(); let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); - let rate_limiter = Arc::new(InMemoryRateLimit::new(5, 2)); + let rate_limiter = Arc::new(InMemoryRateLimit::new(5, 2, HashMap::new())); - let ticket_1_1 = rate_limiter.clone().try_acquire(user_1).unwrap(); - let ticket_1_2 = rate_limiter.clone().try_acquire(user_1).unwrap(); + let ticket_1_1 = rate_limiter.clone().try_acquire(user_1, None).unwrap(); + let ticket_1_2 = rate_limiter.clone().try_acquire(user_1, None).unwrap(); - let result = rate_limiter.clone().try_acquire(user_1); + let result = rate_limiter.clone().try_acquire(user_1, None); assert!(result.is_err()); assert_eq!( result.err().unwrap().to_string(), "Rate Limit Reached: IP limit exceeded" ); - let ticket_2_1 = rate_limiter.clone().try_acquire(user_2).unwrap(); + let ticket_2_1 = rate_limiter.clone().try_acquire(user_2, None).unwrap(); drop(ticket_1_1); - let ticket_1_3 = rate_limiter.clone().try_acquire(user_1).unwrap(); + let ticket_1_3 = rate_limiter.clone().try_acquire(user_1, None).unwrap(); - let result = rate_limiter.clone().try_acquire(user_1); + let result = rate_limiter.clone().try_acquire(user_1, None); assert!(result.is_err()); assert_eq!( result.err().unwrap().to_string(), @@ -672,13 +896,18 @@ mod tests { 5 ); assert_eq!( - rate_limiter.inner.lock().unwrap().active_connections.len(), + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_ip + .len(), 0 ); } #[tokio::test] - async fn test_instance_tracking_and_cleanup() { + async fn test_redis_instance_ip_tracking_and_cleanup() { let container = Redis::default().start().await.unwrap(); let host_port = container.get_host_port_ipv4(6379).await.unwrap(); let client_addr = format!("redis://127.0.0.1:{}", host_port); @@ -695,6 +924,7 @@ mod tests { redis_client: Client::open(client_addr.as_str()).unwrap(), instance_limit: 10, per_ip_limit: 5, + per_app_limit: HashMap::new(), semaphore: Arc::new(Semaphore::new(10)), key_prefix: "test".to_string(), instance_id: "instance1".to_string(), @@ -704,8 +934,8 @@ mod tests { }); rate_limiter1.register_instance().unwrap(); - let _ticket1 = rate_limiter1.clone().try_acquire(user_1).unwrap(); - let _ticket2 = rate_limiter1.clone().try_acquire(user_2).unwrap(); + let _ticket1 = rate_limiter1.clone().try_acquire(user_1, None).unwrap(); + let _ticket2 = rate_limiter1.clone().try_acquire(user_2, None).unwrap(); // no drop on release (exit of block) std::mem::forget(_ticket1); std::mem::forget(_ticket2); @@ -770,6 +1000,7 @@ mod tests { redis_client: Client::open(client_addr.as_str()).unwrap(), instance_limit: 10, per_ip_limit: 5, + per_app_limit: HashMap::new(), semaphore: Arc::new(Semaphore::new(10)), key_prefix: "test".to_string(), instance_id: "instance2".to_string(), @@ -805,7 +1036,7 @@ mod tests { ); } - let _ticket3 = rate_limiter2.clone().try_acquire(user_1).unwrap(); + let _ticket3 = rate_limiter2.clone().try_acquire(user_1, None).unwrap(); { let mut conn = redis_client.get_connection().unwrap(); @@ -817,4 +1048,467 @@ mod tests { assert_eq!(ip1_instance2_count, 1, "IP1 instance2 count should be 1"); } } + + // API Key (App) Rate Limiting Tests + const PER_APP_LIMIT: usize = 2; + + #[tokio::test] + async fn test_app_tickets_are_released() { + let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); + let app_1 = "app_1".to_string(); + + let mut per_app_limits = HashMap::new(); + per_app_limits.insert(app_1.clone(), PER_APP_LIMIT); + + let rate_limiter = Arc::new(InMemoryRateLimit::new( + GLOBAL_LIMIT, + 0, // 
Disable IP rate limiting + per_app_limits, + )); + + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .semaphore + .available_permits(), + GLOBAL_LIMIT + ); + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_app + .len(), + 0 + ); + + let c1 = rate_limiter + .clone() + .try_acquire(user_1, Some(app_1.clone())) + .unwrap(); + + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .semaphore + .available_permits(), + GLOBAL_LIMIT - 1 + ); + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_app + .len(), + 1 + ); + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_app[&app_1], + 1 + ); + + drop(c1); + + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .semaphore + .available_permits(), + GLOBAL_LIMIT + ); + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_app + .len(), + 0 + ); + } + + #[tokio::test] + async fn test_per_app_limits() { + let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); + let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); + let app_1 = "app_1".to_string(); + let app_2 = "app_2".to_string(); + + let mut per_app_limits = HashMap::new(); + per_app_limits.insert(app_1.clone(), PER_APP_LIMIT); + per_app_limits.insert(app_2.clone(), PER_APP_LIMIT); + + let rate_limiter = Arc::new(InMemoryRateLimit::new( + GLOBAL_LIMIT, + 0, // Disable IP rate limiting + per_app_limits, + )); + + let _c1 = rate_limiter + .clone() + .try_acquire(user_1, Some(app_1.clone())) + .unwrap(); + let _c2 = rate_limiter + .clone() + .try_acquire(user_2, Some(app_1.clone())) + .unwrap(); + + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_app[&app_1], + 2 + ); + + let c3 = rate_limiter + .clone() + .try_acquire(user_1, Some(app_1.clone())); + assert!(c3.is_err()); + assert_eq!( + c3.err().unwrap().to_string(), + "Rate Limit Reached: App limit exceeded" + ); + + // Different app should still work + let c4 = rate_limiter + .clone() + .try_acquire(user_2, Some(app_2.clone())); + assert!(c4.is_ok()); + } + + #[tokio::test] + async fn test_global_limits_with_multiple_apps() { + let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); + let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); + let user_3 = IpAddr::from_str("127.0.0.3").unwrap(); + let app_1 = "app_1".to_string(); + let app_2 = "app_2".to_string(); + let app_3 = "app_3".to_string(); + + let mut per_app_limits = HashMap::new(); + per_app_limits.insert(app_1.clone(), PER_APP_LIMIT); + per_app_limits.insert(app_2.clone(), PER_APP_LIMIT); + per_app_limits.insert(app_3.clone(), PER_APP_LIMIT); + + let rate_limiter = Arc::new(InMemoryRateLimit::new(4, 0, per_app_limits)); + + let ticket_1_1 = rate_limiter + .clone() + .try_acquire(user_1, Some(app_1.clone())) + .unwrap(); + let ticket_1_2 = rate_limiter + .clone() + .try_acquire(user_1, Some(app_1.clone())) + .unwrap(); + + let ticket_2_1 = rate_limiter + .clone() + .try_acquire(user_2, Some(app_2.clone())) + .unwrap(); + let ticket_2_2 = rate_limiter + .clone() + .try_acquire(user_2, Some(app_2.clone())) + .unwrap(); + + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .semaphore + .available_permits(), + 0 + ); + + // Try app_3 - should fail due to global limit + let result = rate_limiter + .clone() + .try_acquire(user_3, Some(app_3.clone())); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "Rate Limit Reached: Global limit" + ); + + drop(ticket_1_1); + + let 
ticket_3_1 = rate_limiter + .clone() + .try_acquire(user_3, Some(app_3.clone())) + .unwrap(); + + drop(ticket_1_2); + drop(ticket_2_1); + drop(ticket_2_2); + drop(ticket_3_1); + + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .semaphore + .available_permits(), + 4 + ); + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_app + .len(), + 0 + ); + } + + #[tokio::test] + async fn test_per_app_limits_remain_enforced() { + let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); + let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); + let app_1 = "app_1".to_string(); + let app_2 = "app_2".to_string(); + + let mut per_app_limits = HashMap::new(); + per_app_limits.insert(app_1.clone(), PER_APP_LIMIT); + per_app_limits.insert(app_2.clone(), PER_APP_LIMIT); + + let rate_limiter = Arc::new(InMemoryRateLimit::new(5, 0, per_app_limits)); + + let ticket_1_1 = rate_limiter + .clone() + .try_acquire(user_1, Some(app_1.clone())) + .unwrap(); + let ticket_1_2 = rate_limiter + .clone() + .try_acquire(user_2, Some(app_1.clone())) + .unwrap(); + + let result = rate_limiter + .clone() + .try_acquire(user_1, Some(app_1.clone())); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "Rate Limit Reached: App limit exceeded" + ); + + let ticket_2_1 = rate_limiter + .clone() + .try_acquire(user_2, Some(app_2.clone())) + .unwrap(); + drop(ticket_1_1); + + let ticket_1_3 = rate_limiter + .clone() + .try_acquire(user_1, Some(app_1.clone())) + .unwrap(); + + let result = rate_limiter + .clone() + .try_acquire(user_2, Some(app_1.clone())); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "Rate Limit Reached: App limit exceeded" + ); + + drop(ticket_1_2); + drop(ticket_1_3); + drop(ticket_2_1); + + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .semaphore + .available_permits(), + 5 + ); + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .active_connections_per_app + .len(), + 0 + ); + } + + #[tokio::test] + #[cfg(all(feature = "integration", test))] + async fn test_redis_instance_app_tracking_and_cleanup() { + use redis_test::server::RedisServer; + use std::time::Duration; + + let server = RedisServer::new(); + let client_addr = format!("redis://{}", server.client_addr()); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); + let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); + let app_1 = "app_1".to_string(); + let app_2 = "app_2".to_string(); + + let mut per_app_limits = HashMap::new(); + per_app_limits.insert(app_1.clone(), 5); + per_app_limits.insert(app_2.clone(), 5); + + let redis_client = Client::open(client_addr.as_str()).unwrap(); + + { + let rate_limiter1 = Arc::new(RedisRateLimit { + redis_client: Client::open(client_addr.as_str()).unwrap(), + instance_limit: 10, + per_ip_limit: 0, // Disable IP rate limiting + per_app_limit: per_app_limits.clone(), + semaphore: Arc::new(Semaphore::new(10)), + key_prefix: "test".to_string(), + instance_id: "instance1".to_string(), + heartbeat_interval: Duration::from_millis(200), + heartbeat_ttl: Duration::from_secs(1), + background_tasks_started: AtomicBool::new(true), + }); + + rate_limiter1.register_instance().unwrap(); + let _ticket1 = rate_limiter1 + .clone() + .try_acquire(user_1, Some(app_1.clone())) + .unwrap(); + let _ticket2 = rate_limiter1 + .clone() + .try_acquire(user_2, Some(app_2.clone())) + .unwrap(); + // no drop on release (exit of block) + std::mem::forget(_ticket1); + 
std::mem::forget(_ticket2); + + { + let mut conn = redis_client.get_connection().unwrap(); + + let exists: bool = redis::cmd("EXISTS") + .arg(format!("test:instance:instance1:heartbeat")) + .query(&mut conn) + .unwrap(); + assert!(exists, "Instance1 heartbeat should exist initially"); + + let app1_instance1_count: usize = redis::cmd("GET") + .arg(format!("test:app:{}:instance:instance1:connections", app_1)) + .query(&mut conn) + .unwrap(); + let app2_instance1_count: usize = redis::cmd("GET") + .arg(format!("test:app:{}:instance:instance1:connections", app_2)) + .query(&mut conn) + .unwrap(); + + assert_eq!(app1_instance1_count, 1, "App1 count should be 1 initially"); + assert_eq!(app2_instance1_count, 1, "App2 count should be 1 initially"); + } + }; + + tokio::time::sleep(Duration::from_secs(1)).await; + + { + let mut conn = redis_client.get_connection().unwrap(); + + let exists: bool = redis::cmd("EXISTS") + .arg(format!("test:instance:instance1:heartbeat")) + .query(&mut conn) + .unwrap(); + assert!( + !exists, + "Instance1 heartbeat should be gone after TTL expiration" + ); + + let app1_instance1_count: usize = redis::cmd("GET") + .arg(format!("test:app:{}:instance:instance1:connections", app_1)) + .query(&mut conn) + .unwrap(); + let app2_instance1_count: usize = redis::cmd("GET") + .arg(format!("test:app:{}:instance:instance1:connections", app_2)) + .query(&mut conn) + .unwrap(); + + assert_eq!( + app1_instance1_count, 1, + "App1 instance1 count should still be 1 after instance1 crash" + ); + assert_eq!( + app2_instance1_count, 1, + "App2 instance1 count should still be 1 after crash" + ); + } + + let rate_limiter2 = Arc::new(RedisRateLimit { + redis_client: Client::open(client_addr.as_str()).unwrap(), + instance_limit: 10, + per_ip_limit: 0, // Disable IP rate limiting + per_app_limit: per_app_limits, + semaphore: Arc::new(Semaphore::new(10)), + key_prefix: "test".to_string(), + instance_id: "instance2".to_string(), + heartbeat_interval: Duration::from_millis(200), + heartbeat_ttl: Duration::from_secs(2), + background_tasks_started: AtomicBool::new(false), + }); + + rate_limiter2.register_instance().unwrap(); + rate_limiter2.cleanup_stale_instances().unwrap(); + + tokio::time::sleep(Duration::from_secs(1)).await; + + { + let mut conn = redis_client.get_connection().unwrap(); + + let app1_instance1_exists: bool = redis::cmd("EXISTS") + .arg(format!("test:app:{}:instance:instance1:connections", app_1)) + .query(&mut conn) + .unwrap(); + let app2_instance1_exists: bool = redis::cmd("EXISTS") + .arg(format!("test:app:{}:instance:instance1:connections", app_2)) + .query(&mut conn) + .unwrap(); + + assert!( + !app1_instance1_exists, + "App1 instance1 counter should be gone after cleanup" + ); + assert!( + !app2_instance1_exists, + "App2 instance1 counter should be gone after cleanup" + ); + } + + let _ticket3 = rate_limiter2 + .clone() + .try_acquire(user_1, Some(app_1.clone())) + .unwrap(); + + { + let mut conn = redis_client.get_connection().unwrap(); + let app1_instance2_count: usize = redis::cmd("GET") + .arg(format!("test:app:{}:instance:instance2:connections", app_1)) + .query(&mut conn) + .unwrap(); + + assert_eq!(app1_instance2_count, 1, "App1 instance2 count should be 1"); + } + } } diff --git a/crates/websocket-proxy/src/server.rs b/crates/websocket-proxy/src/server.rs index fadf6237..cebaad02 100644 --- a/crates/websocket-proxy/src/server.rs +++ b/crates/websocket-proxy/src/server.rs @@ -120,8 +120,9 @@ async fn authenticated_websocket_handler( .unwrap() } Some(app) => { - 
state.metrics.proxy_connections_by_app(app); - websocket_handler(state, ws, addr, headers) + let app = app.clone(); + state.metrics.proxy_connections_by_app(&app); + websocket_handler(state, ws, addr, headers, Some(app)) } } } @@ -132,7 +133,7 @@ async fn unauthenticated_websocket_handler( ConnectInfo(addr): ConnectInfo<SocketAddr>, headers: HeaderMap, ) -> impl IntoResponse { - websocket_handler(state, ws, addr, headers) + websocket_handler(state, ws, addr, headers, None) } fn websocket_handler( @@ -140,6 +141,7 @@ fn websocket_handler( ws: WebSocketUpgrade, addr: SocketAddr, headers: HeaderMap, + app: Option<String>, ) -> Response { let connect_addr = addr.ip(); @@ -148,7 +150,7 @@ fn websocket_handler( Some(value) => extract_addr(value, connect_addr), }; - let ticket = match state.rate_limiter.try_acquire(client_addr) { + let ticket = match state.rate_limiter.try_acquire(client_addr, app) { Ok(ticket) => ticket, Err(RateLimitError::Limit { reason }) => { state.metrics.rate_limited_requests.increment(1); diff --git a/crates/websocket-proxy/tests/integration.rs b/crates/websocket-proxy/tests/integration.rs index ae3d8837..37dab95b 100644 --- a/crates/websocket-proxy/tests/integration.rs +++ b/crates/websocket-proxy/tests/integration.rs @@ -44,7 +44,12 @@ impl TestHarness { let (sender, _) = broadcast::channel(5); let metrics = Arc::new(Metrics::default()); let registry = Registry::new(sender.clone(), metrics.clone(), false, 120000); - let rate_limited = Arc::new(InMemoryRateLimit::new(3, 10)); + let app_rate_limits = if let Some(auth) = &auth { + auth.get_rate_limits() + } else { + HashMap::new() + }; + let rate_limited = Arc::new(InMemoryRateLimit::new(3, 10, app_rate_limits)); Self { received_messages: Arc::new(Mutex::new(HashMap::new())), @@ -357,11 +362,18 @@ async fn test_authentication_disables_public_endpoint() { #[tokio::test] async fn test_authentication_allows_known_api_keys() { let addr = TestHarness::alloc_port().await; - let auth = Authentication::new(HashMap::from([ - ("key1".to_string(), "app1".to_string()), - ("key2".to_string(), "app2".to_string()), - ("key3".to_string(), "app3".to_string()), - ])); + let auth = Authentication::new( + HashMap::from([ + ("key1".to_string(), "app1".to_string()), + ("key2".to_string(), "app2".to_string()), + ("key3".to_string(), "app3".to_string()), + ]), + HashMap::from([ + ("app1".to_string(), 10), + ("app2".to_string(), 10), + ("app3".to_string(), 10), + ]), + ); let mut harness = TestHarness::new_with_auth(addr, Some(auth)); harness.start_server().await; @@ -379,7 +391,7 @@ async fn test_ping_timeout_disconnects_client() { let (sender, _) = broadcast::channel(5); let metrics = Arc::new(Metrics::default()); let registry = Registry::new(sender.clone(), metrics.clone(), true, 1000); - let rate_limited = Arc::new(InMemoryRateLimit::new(3, 10)); + let rate_limited = Arc::new(InMemoryRateLimit::new(3, 10, HashMap::new())); let mut harness = TestHarness { received_messages: Arc::new(Mutex::new(HashMap::new())), From b11759725e6448a4f8ed5ff85c73d4a4ba200a78 Mon Sep 17 00:00:00 2001 From: Haardik H Date: Thu, 12 Jun 2025 11:31:04 -0400 Subject: [PATCH 2/2] cleanup, resolve comments --- crates/websocket-proxy/src/rate_limit.rs | 81 +++++++++++------------- 1 file changed, 36 insertions(+), 45 deletions(-) diff --git a/crates/websocket-proxy/src/rate_limit.rs b/crates/websocket-proxy/src/rate_limit.rs index 1821ae4d..8c19401a 100644 --- a/crates/websocket-proxy/src/rate_limit.rs +++ b/crates/websocket-proxy/src/rate_limit.rs @@ -89,10 +89,7 @@ impl RateLimit for InMemoryRateLimit { })?; if self.per_ip_limit > 0 { - let current_count = match inner.active_connections_per_ip.get(&addr) { - Some(count) => *count, - None => 0, - }; + let current_count = *inner.active_connections_per_ip.get(&addr).unwrap_or(&0); if current_count + 1 > self.per_ip_limit { debug!( @@ -105,15 +102,11 @@ impl RateLimit for InMemoryRateLimit { } let new_count = current_count + 1; - inner.active_connections_per_ip.insert(addr, new_count); } if let Some(app) = app.clone() { - let current_count = match inner.active_connections_per_app.get(&app) { - Some(count) => *count, - None => 0, - }; + let current_count = *inner.active_connections_per_app.get(&app).unwrap_or(&0); if current_count + 1 > *self.per_app_limit.get(&app).unwrap_or(&0) { debug!( @@ -126,7 +119,6 @@ impl RateLimit for InMemoryRateLimit { } let new_count = current_count + 1; - inner.active_connections_per_app.insert(app, new_count); } @@ -142,47 +134,46 @@ impl RateLimit for InMemoryRateLimit { let mut inner = self.inner.lock().unwrap(); if self.per_ip_limit > 0 { - let current_count = match inner.active_connections_per_ip.get(&addr) { - Some(count) => *count, - None => 0, - }; - let new_count = if current_count == 0 { - warn!( - message = "ip counting is not accurate -- unexpected underflow", - client = addr.to_string() - ); - 0 - } else { - current_count - 1 - }; + let current_count = *inner.active_connections_per_ip.get(&addr).unwrap_or(&0); - if new_count == 0 { - inner.active_connections_per_ip.remove(&addr); - } else { - inner.active_connections_per_ip.insert(addr, new_count);
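+ // A count of 0 means release() ran without a matching acquire(); warn and remove instead of underflowing.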
+ match current_count { + 0 => { + warn!( + message = "ip counting is not accurate -- unexpected underflow", + client = addr.to_string() + ); + inner.active_connections_per_ip.remove(&addr); + } + 1 => { + inner.active_connections_per_ip.remove(&addr); + } + _ => { + inner + .active_connections_per_ip + .insert(addr, current_count - 1); + } } } if let Some(app) = app { - let current_count = match inner.active_connections_per_app.get(&app) { - Some(count) => *count, - None => 0, - }; - - let new_count = if current_count == 0 { - warn!( - message = "app counting is not accurate -- unexpected underflow", - client = addr.to_string() - ); - 0 - } else { - current_count - 1 - }; + let current_count = *inner.active_connections_per_app.get(&app).unwrap_or(&0); - if new_count == 0 { - inner.active_connections_per_app.remove(&app); - } else { - inner.active_connections_per_app.insert(app, new_count); + match current_count { + 0 => { + warn!( + message = "app counting is not accurate -- unexpected underflow", + app = app + ); + inner.active_connections_per_app.remove(&app); + } + 1 => { + inner.active_connections_per_app.remove(&app); + } + _ => { + inner + .active_connections_per_app + .insert(app, current_count - 1); + } } } }