From d8622f51d54fc4c84dcced465a33b7fba3924e21 Mon Sep 17 00:00:00 2001 From: Joe Parks <26990067+jowparks@users.noreply.github.com> Date: Tue, 29 Apr 2025 13:15:25 -0700 Subject: [PATCH 1/3] reset redis counters on boot --- src/rate_limit.rs | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/src/rate_limit.rs b/src/rate_limit.rs index 06bbf66..774f6aa 100644 --- a/src/rate_limit.rs +++ b/src/rate_limit.rs @@ -1,13 +1,12 @@ use std::collections::HashMap; use std::net::IpAddr; use std::sync::{Arc, Mutex}; -use tracing::{debug, warn}; +use tracing::{debug, error, info, warn}; use thiserror::Error; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use redis::{Client, Commands, RedisError}; -use tracing::error; #[derive(Error, Debug)] pub enum RateLimitError { @@ -138,13 +137,22 @@ impl RedisRateLimit { ) -> Result<Self, RedisError> { let client = Client::open(redis_url)?; - Ok(Self { + let limiter = Self { redis_client: client, global_limit, per_ip_limit, semaphore: Arc::new(Semaphore::new(global_limit)), key_prefix: key_prefix.to_string(), - }) + }; + + if let Err(e) = limiter.reset_counters() { + error!( + message = "Failed to reset Redis counters on startup", + error = e.to_string() + ); + } + + Ok(limiter) } /// Get Redis key for tracking global connections @@ -156,6 +164,29 @@ impl RedisRateLimit { fn ip_key(&self, addr: &IpAddr) -> String { format!("{}:ip:{}:connections", self.key_prefix, addr) } + + /// Reset all Redis counters associated with this rate limiter + pub fn reset_counters(&self) -> Result<(), RedisError> { + let mut conn = self.redis_client.get_connection()?; + + // Delete the global counter + let _: () = conn.del(self.global_key())?; + + // Find and delete all IP-specific counters with this prefix + let pattern = format!("{}:ip:*:connections", self.key_prefix); + let keys: Vec<String> = conn.keys(pattern)?; + + if !keys.is_empty() { + let _: () = conn.del(keys)?; + } + + info!( + message = "Reset all Redis rate limit counters", + prefix = self.key_prefix + ); + + Ok(()) + } } impl RateLimit for RedisRateLimit {
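A minimal sketch of how a service might construct the limiter at boot so the reset above runs exactly once per process. The URL, limits, and `gateway` prefix are placeholders, and the `rate_limit` module path is an assumption, not confirmed by the patch:

```rust
use std::sync::Arc;

// Hypothetical startup wiring: RedisRateLimit::new() opens the client and
// immediately calls reset_counters(), which deletes
// "<prefix>:global:connections" and every "<prefix>:ip:*:connections" key
// left over from a previous run.
fn main() -> Result<(), redis::RedisError> {
    let limiter = Arc::new(rate_limit::RedisRateLimit::new(
        "redis://127.0.0.1:6379", // assumed local Redis
        1024,                     // global connection limit (placeholder)
        64,                       // per-IP connection limit (placeholder)
        "gateway",                // key prefix (placeholder)
    )?);
    // ... hand `limiter` to the accept loop ...
    let _ = limiter;
    Ok(())
}
```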
"0.30.0" redis-test = { version = "0.10.0", optional = true } +uuid = { version = "1.16.0", features = ["v4"] } [dependencies.ring] diff --git a/src/rate_limit.rs b/src/rate_limit.rs index 774f6aa..da1f763 100644 --- a/src/rate_limit.rs +++ b/src/rate_limit.rs @@ -1,12 +1,15 @@ use std::collections::HashMap; use std::net::IpAddr; use std::sync::{Arc, Mutex}; -use tracing::{debug, error, info, warn}; +use tracing::{debug, error, warn}; use thiserror::Error; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use redis::{Client, Commands, RedisError}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::{Duration, SystemTime}; +use uuid::Uuid; #[derive(Error, Debug)] pub enum RateLimitError { @@ -126,6 +129,10 @@ pub struct RedisRateLimit { per_ip_limit: usize, semaphore: Arc, key_prefix: String, + instance_id: String, + heartbeat_interval: Duration, + heartbeat_ttl: Duration, + background_tasks_started: AtomicBool, } impl RedisRateLimit { @@ -136,62 +143,197 @@ impl RedisRateLimit { key_prefix: &str, ) -> Result { let client = Client::open(redis_url)?; + let instance_id = Uuid::new_v4().to_string(); - let limiter = Self { + let heartbeat_interval = Duration::from_secs(10); + let heartbeat_ttl = Duration::from_secs(30); + + let rate_limiter = Self { redis_client: client, global_limit, per_ip_limit, semaphore: Arc::new(Semaphore::new(global_limit)), key_prefix: key_prefix.to_string(), + instance_id, + heartbeat_interval, + heartbeat_ttl, + background_tasks_started: AtomicBool::new(false), }; - if let Err(e) = limiter.reset_counters() { + if let Err(e) = rate_limiter.register_instance() { error!( - message = "Failed to reset Redis counters on startup", + message = "Failed to register instance in Redis", error = e.to_string() ); } - Ok(limiter) + Ok(rate_limiter) + } + + pub fn start_background_tasks(self: Arc) { + if self.background_tasks_started.swap(true, Ordering::SeqCst) { + return; + } + + debug!( + message = "Starting background heartbeat and cleanup tasks", + instance_id = self.instance_id + ); + + let self_clone = self.clone(); + tokio::spawn(async move { + loop { + if let Err(e) = self_clone.update_heartbeat() { + error!( + message = "Failed to update heartbeat in background task", + error = e.to_string() + ); + } + + if let Err(e) = self_clone.cleanup_stale_instances() { + error!( + message = "Failed to cleanup stale instances in background task", + error = e.to_string() + ); + } + + tokio::time::sleep(self_clone.heartbeat_interval / 2).await; + } + }); } - /// Get Redis key for tracking global connections - fn global_key(&self) -> String { - format!("{}:global:connections", self.key_prefix) + fn register_instance(&self) -> Result<(), RedisError> { + self.update_heartbeat()?; + debug!( + message = "Registered instance in Redis", + instance_id = self.instance_id + ); + + Ok(()) } - /// Get Redis key for tracking connections per IP - fn ip_key(&self, addr: &IpAddr) -> String { - format!("{}:ip:{}:connections", self.key_prefix, addr) + fn update_heartbeat(&self) -> Result<(), RedisError> { + let now = SystemTime::now(); + let mut conn = self.redis_client.get_connection()?; + + let ttl = self.heartbeat_ttl.as_secs(); + conn.set_ex::<_, _, ()>( + self.instance_heartbeat_key(), + now.duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(), + ttl, + )?; + + debug!( + message = "Updated instance heartbeat", + instance_id = self.instance_id + ); + + Ok(()) } - /// Reset all Redis counters associated with this rate limiter - pub fn reset_counters(&self) -> Result<(), 
+ fn cleanup_stale_instances(&self) -> Result<(), RedisError> { let mut conn = self.redis_client.get_connection()?; - // Delete the global counter - let _: () = conn.del(self.global_key())?; + let instance_heartbeat_pattern = format!("{}:instance:*:heartbeat", self.key_prefix); + let instance_heartbeats: Vec<String> = conn.keys(instance_heartbeat_pattern)?; + + let active_instance_ids: Vec<String> = instance_heartbeats + .iter() + .filter_map(|key| key.split(':').nth(2).map(String::from)) + .collect(); - // Find and delete all IP-specific counters with this prefix - let pattern = format!("{}:ip:*:connections", self.key_prefix); - let keys: Vec<String> = conn.keys(pattern)?; + debug!( + message = "Active instances with heartbeats", + instance_count = active_instance_ids.len(), + current_instance = self.instance_id + ); - if !keys.is_empty() { - let _: () = conn.del(keys)?; + let ip_instance_pattern = format!("{}:ip:*:instance:*:connections", self.key_prefix); + let ip_instance_keys: Vec<String> = conn.keys(ip_instance_pattern)?; + + let mut instance_ids_with_connections = std::collections::HashSet::new(); + for key in &ip_instance_keys { + if let Some(instance_id) = key.split(':').nth(4) { + instance_ids_with_connections.insert(instance_id.to_string()); + } } - info!( - message = "Reset all Redis rate limit counters", - prefix = self.key_prefix + debug!( + message = "Checking for stale instances", + instances_with_connections = instance_ids_with_connections.len(), + current_instance = self.instance_id ); + for instance_id in instance_ids_with_connections { + if instance_id == self.instance_id { + debug!( + message = "Skipping current instance", + instance_id = instance_id + ); + continue; + } + + if !active_instance_ids.contains(&instance_id) { + debug!( + message = "Found stale instance", + instance_id = instance_id, + reason = "Heartbeat key not found" + ); + self.cleanup_instance(&mut conn, &instance_id)?; + } + } + + debug!(message = "Completed stale instance cleanup"); + Ok(()) } + + fn cleanup_instance( + &self, + conn: &mut redis::Connection, + instance_id: &str, + ) -> Result<(), RedisError> { + let ip_instance_pattern = format!( + "{}:ip:*:instance:{}:connections", + self.key_prefix, instance_id + ); + let ip_instance_keys: Vec<String> = conn.keys(ip_instance_pattern)?; + + debug!( + message = "Cleaning up instance", + instance_id = instance_id, + ip_key_count = ip_instance_keys.len() + ); + + for key in ip_instance_keys { + conn.del::<_, ()>(&key)?; + debug!(message = "Deleted IP instance key", key = key); + } + + Ok(()) + } + + fn ip_instance_key(&self, addr: &IpAddr) -> String { + format!( + "{}:ip:{}:instance:{}:connections", + self.key_prefix, addr, self.instance_id + ) + } + + fn instance_heartbeat_key(&self) -> String { + format!( + "{}:instance:{}:heartbeat", + self.key_prefix, self.instance_id + ) + } } impl RateLimit for RedisRateLimit { fn try_acquire(self: Arc<Self>, addr: IpAddr) -> Result<Ticket, RateLimitError> { - // Get Redis connection first to check current counts + self.clone().start_background_tasks(); + let mut conn = match self.redis_client.get_connection() { Ok(conn) => conn, Err(e) => { @@ -205,13 +347,30 @@ impl RateLimit for RedisRateLimit { } }; - // Check global count BEFORE incrementing - let global_connections: usize = conn.get(self.global_key()).unwrap_or(0); + let ip_instance_pattern = format!("{}:ip:*:instance:*:connections", self.key_prefix); + let ip_instance_keys: Vec<String> = match conn.keys(ip_instance_pattern) { + Ok(keys) => keys, + Err(e) => { + error!( + message = "Failed to get IP instance keys from Redis",
Redis", + error = e.to_string() + ); + return Err(RateLimitError::Limit { + reason: "Redis operation failed".to_string(), + }); + } + }; + + let mut total_global_connections: usize = 0; + for key in &ip_instance_keys { + let count: usize = conn.get(key).unwrap_or(0); + total_global_connections += count; + } - if global_connections >= self.global_limit { + if total_global_connections >= self.global_limit { debug!( message = "Global limit reached", - global_connections = global_connections, + global_connections = total_global_connections, global_limit = self.global_limit ); return Err(RateLimitError::Limit { @@ -219,54 +378,46 @@ impl RateLimit for RedisRateLimit { }); } - // Check IP count BEFORE incrementing - let ip_connections: usize = conn.get(self.ip_key(&addr)).unwrap_or(0); - if ip_connections >= self.per_ip_limit { + let ip_keys_pattern = format!("{}:ip:{}:instance:*:connections", self.key_prefix, addr); + let ip_keys: Vec = match conn.keys(ip_keys_pattern) { + Ok(keys) => keys, + Err(e) => { + error!( + message = "Failed to get IP instance keys from Redis", + error = e.to_string() + ); + return Err(RateLimitError::Limit { + reason: "Redis operation failed".to_string(), + }); + } + }; + + let mut total_ip_connections: usize = 0; + for key in &ip_keys { + let count: usize = conn.get(key).unwrap_or(0); + total_ip_connections += count; + } + + if total_ip_connections >= self.per_ip_limit { return Err(RateLimitError::Limit { reason: format!("Per-IP connection limit reached for {}", addr), }); } - // Now try to get the local semaphore permit let permit = match self.semaphore.clone().try_acquire_owned() { Ok(permit) => permit, Err(_) => { return Err(RateLimitError::Limit { - reason: "Global connection limit reached on this instance".to_string(), + reason: "Maximum connection limit reached for this server instance".to_string(), }); } }; - // Increment IP counter - let ip_connections: usize = match conn.incr(self.ip_key(&addr), 1) { + let ip_instance_connections: usize = match conn.incr(self.ip_instance_key(&addr), 1) { Ok(count) => count, Err(e) => { error!( - message = "Failed to increment IP counter in Redis", - error = e.to_string() - ); - return Err(RateLimitError::Limit { - reason: "Redis operation failed".to_string(), - }); - } - }; - - // Increment global counter - let global_connections: usize = match conn.incr(self.global_key(), 1) { - Ok(count) => { - debug!( - message = "Incremented global counter", - global_connections = count, - global_limit = self.global_limit - ); - count - } - Err(e) => { - // Roll back IP counter increment - let _: Result<(), _> = conn.decr(self.ip_key(&addr), 1); - - error!( - message = "Failed to increment global counter in Redis", + message = "Failed to increment per-instance IP counter in Redis", error = e.to_string() ); return Err(RateLimitError::Limit { @@ -278,8 +429,10 @@ impl RateLimit for RedisRateLimit { debug!( message = "Connection established", ip = addr.to_string(), - ip_connections = ip_connections, - global_connections = global_connections + ip_instance_connections = ip_instance_connections, + total_ip_connections = total_ip_connections + 1, + total_global_connections = total_global_connections + 1, + instance_id = self.instance_id ); Ok(Ticket { @@ -292,28 +445,12 @@ impl RateLimit for RedisRateLimit { fn release(&self, addr: IpAddr) { match self.redis_client.get_connection() { Ok(mut conn) => { - // Decrement IP counter - let ip_connections: Result = conn.decr(self.ip_key(&addr), 1); + let ip_instance_connections: Result = + 
conn.decr(self.ip_instance_key(&addr), 1); - if let Err(ref e) = ip_connections { + if let Err(ref e) = ip_instance_connections { error!( - message = "Failed to decrement IP counter in Redis", + message = "Failed to decrement per-instance IP counter in Redis", error = e.to_string() ); } - - // Decrement global counter - let global_connections: Result<usize, RedisError> = conn.decr(self.global_key(), 1); - - if let Ok(count) = global_connections { - debug!( - message = "Decremented global counter on release", - global_connections = count, - global_limit = self.global_limit - ); - } else if let Err(ref e) = global_connections { - error!( - message = "Failed to decrement global counter in Redis", error = e.to_string() ); } @@ -321,8 +458,8 @@ debug!( message = "Connection released", ip = addr.to_string(), - ip_connections = ip_connections.unwrap_or(0), - global_connections = global_connections.unwrap_or(0) + ip_instance_connections = ip_instance_connections.unwrap_or(0), + instance_id = self.instance_id ); } Err(e) => { @@ -458,129 +595,151 @@ mod tests { "Rate Limit Reached: IP limit exceeded" ); - // While the first IP is limited, the second isn't let c4 = rate_limiter.clone().try_acquire(user_2); assert!(c4.is_ok()); } #[tokio::test] #[cfg(all(feature = "integration", test))] - async fn test_redis_rate_limits_with_mock() { + async fn test_instance_tracking_and_cleanup() { use redis_test::server::RedisServer; + use std::time::Duration; - // Start a mock Redis server let server = RedisServer::new(); let client_addr = format!("redis://{}", server.client_addr()); - let client = redis::Client::open(client_addr.as_str()).unwrap(); - // Wait for the server to start - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + tokio::time::sleep(Duration::from_millis(100)).await; - let rate_limiter = Arc::new(RedisRateLimit::new(&client_addr, 5, 1, "test").unwrap()); - - // Test IP addresses let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); - // Test 1: Acquire a connection successfully - let ticket1 = rate_limiter.clone().try_acquire(user_1).unwrap(); + let redis_client = Client::open(client_addr.as_str()).unwrap(); - // Verify Redis state manually { - let mut conn = client.get_connection().unwrap(); - let global_count: usize = redis::cmd("GET") - .arg("test:global:connections") - .query(&mut conn) - .unwrap(); - let ip_count: usize = redis::cmd("GET") - .arg("test:ip:127.0.0.1:connections") - .query(&mut conn) - .unwrap(); - - assert_eq!(global_count, 1); - assert_eq!(ip_count, 1); - } + let rate_limiter1 = Arc::new(RedisRateLimit { + redis_client: Client::open(client_addr.as_str()).unwrap(), + global_limit: 10, + per_ip_limit: 5, + semaphore: Arc::new(Semaphore::new(10)), + key_prefix: "test".to_string(), + instance_id: "instance1".to_string(), + heartbeat_interval: Duration::from_millis(200), + heartbeat_ttl: Duration::from_secs(1), + background_tasks_started: AtomicBool::new(true), + }); - // Test 3: Third connection for same IP should fail (exceeds per-IP limit) - let result = rate_limiter.clone().try_acquire(user_1); - assert!(result.is_err()); - if let Err(RateLimitError::Limit { reason }) = result { - assert!(reason.contains("Per-IP connection limit")); - } else { - panic!("Expected a RateLimitError::Limit"); - } + rate_limiter1.register_instance().unwrap(); + let _ticket1 = rate_limiter1.clone().try_acquire(user_1).unwrap(); + let _ticket2 = rate_limiter1.clone().try_acquire(user_2).unwrap(); + // no 
drop on release (exit of block) + std::mem::forget(_ticket1); + std::mem::forget(_ticket2); + + { + let mut conn = redis_client.get_connection().unwrap(); + + let exists: bool = redis::cmd("EXISTS") + .arg(format!("test:instance:instance1:heartbeat")) + .query(&mut conn) + .unwrap(); + assert!(exists, "Instance1 heartbeat should exist initially"); + + let ip1_instance1_count: usize = redis::cmd("GET") + .arg("test:ip:127.0.0.1:instance:instance1:connections") + .query(&mut conn) + .unwrap(); + let ip2_instance1_count: usize = redis::cmd("GET") + .arg("test:ip:127.0.0.2:instance:instance1:connections") + .query(&mut conn) + .unwrap(); + + assert_eq!(ip1_instance1_count, 1, "IP1 count should be 1 initially"); + assert_eq!(ip2_instance1_count, 1, "IP2 count should be 1 initially"); + } + }; - // Test 4: Different IP should work - let ticket2 = rate_limiter.clone().try_acquire(user_2).unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; - // Verify counts after multiple operations { - let mut conn = client.get_connection().unwrap(); - let global_count: usize = redis::cmd("GET") - .arg("test:global:connections") + let mut conn = redis_client.get_connection().unwrap(); + + let exists: bool = redis::cmd("EXISTS") + .arg(format!("test:instance:instance1:heartbeat")) .query(&mut conn) .unwrap(); - let ip1_count: usize = redis::cmd("GET") - .arg("test:ip:127.0.0.1:connections") + assert!( + !exists, + "Instance1 heartbeat should be gone after TTL expiration" + ); + + let ip1_instance1_count: usize = redis::cmd("GET") + .arg("test:ip:127.0.0.1:instance:instance1:connections") .query(&mut conn) .unwrap(); - let ip2_count: usize = redis::cmd("GET") - .arg("test:ip:127.0.0.2:connections") + let ip2_instance1_count: usize = redis::cmd("GET") + .arg("test:ip:127.0.0.2:instance:instance1:connections") .query(&mut conn) .unwrap(); - assert_eq!(global_count, 2); - assert_eq!(ip1_count, 1); - assert_eq!(ip2_count, 1); + assert_eq!( + ip1_instance1_count, 1, + "IP1 instance1 count should still be 1 after instance1 crash" + ); + assert_eq!( + ip2_instance1_count, 1, + "IP2 instance1 count should still be 1 after crash" + ); } - // Test 5: Test release by dropping tickets - drop(ticket1); + let rate_limiter2 = Arc::new(RedisRateLimit { + redis_client: Client::open(client_addr.as_str()).unwrap(), + global_limit: 10, + per_ip_limit: 5, + semaphore: Arc::new(Semaphore::new(10)), + key_prefix: "test".to_string(), + instance_id: "instance2".to_string(), + heartbeat_interval: Duration::from_millis(200), + heartbeat_ttl: Duration::from_secs(2), + background_tasks_started: AtomicBool::new(false), + }); + + rate_limiter2.register_instance().unwrap(); + rate_limiter2.cleanup_stale_instances().unwrap(); + + tokio::time::sleep(Duration::from_secs(1)).await; - // Verify that counters were decremented { - let mut conn = client.get_connection().unwrap(); - let global_count: usize = redis::cmd("GET") - .arg("test:global:connections") + let mut conn = redis_client.get_connection().unwrap(); + + let ip1_instance1_exists: bool = redis::cmd("EXISTS") + .arg("test:ip:127.0.0.1:instance:instance1:connections") .query(&mut conn) .unwrap(); - let ip1_count: usize = redis::cmd("GET") - .arg("test:ip:127.0.0.1:connections") + let ip2_instance1_exists: bool = redis::cmd("EXISTS") + .arg("test:ip:127.0.0.2:instance:instance1:connections") .query(&mut conn) .unwrap(); - assert_eq!(global_count, 1); - assert_eq!(ip1_count, 0); + assert!( + !ip1_instance1_exists, + "IP1 instance1 counter should be gone after cleanup" + ); + assert!( 
!ip2_instance1_exists, "IP2 instance1 counter should be gone after cleanup" ); } - // Test 6: Now we should be able to acquire another connection for user_1 - let _ = rate_limiter.clone().try_acquire(user_1).unwrap(); - - // Test 7: Test global limit - let rate_limiter_small = Arc::new( - RedisRateLimit::new( - &client_addr, - 3, // smaller global limit - 5, // large per-IP limit - "test2", - ) - .unwrap(), - ); + let _ticket3 = rate_limiter2.clone().try_acquire(user_1).unwrap(); - // Acquire connections up to the global limit - let _t1 = rate_limiter_small.clone().try_acquire(user_1).unwrap(); - let _t2 = rate_limiter_small.clone().try_acquire(user_1).unwrap(); - let _t3 = rate_limiter_small.clone().try_acquire(user_1).unwrap(); + { + let mut conn = redis_client.get_connection().unwrap(); + let ip1_instance2_count: usize = redis::cmd("GET") + .arg("test:ip:127.0.0.1:instance:instance2:connections") + .query(&mut conn) + .unwrap(); - // This should fail due to global limit - let result = rate_limiter_small.clone().try_acquire(user_1); - assert!(result.is_err()); - if let Err(RateLimitError::Limit { reason }) = result { - assert!(reason.contains("Global connection limit")); - } else { - panic!("Expected a RateLimitError::Limit"); + assert_eq!(ip1_instance2_count, 1, "IP1 instance2 count should be 1"); } - - drop(ticket2); } }
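Patch 2 replaces the shared counters with per-instance keys plus a TTL'd heartbeat, and the cleanup pass recovers instance IDs purely by position in the key string. A self-contained sketch of that key grammar and the `split(':')` indices the cleanup relies on (the prefix and IDs here are illustrative):

```rust
// Key grammar from this patch, with the positional parsing used by
// cleanup_stale_instances(): instance ids sit at index 2 of heartbeat
// keys and at index 4 of per-instance connection keys.
fn heartbeat_key(prefix: &str, instance_id: &str) -> String {
    format!("{}:instance:{}:heartbeat", prefix, instance_id)
}

fn ip_instance_key(prefix: &str, ip: &str, instance_id: &str) -> String {
    format!("{}:ip:{}:instance:{}:connections", prefix, ip, instance_id)
}

fn main() {
    let hb = heartbeat_key("test", "instance1");
    assert_eq!(hb.split(':').nth(2), Some("instance1"));

    let conn = ip_instance_key("test", "127.0.0.1", "instance1");
    assert_eq!(conn.split(':').nth(4), Some("instance1"));
}
```

Note that the positional parse assumes the address itself contains no ':'; an IPv6 literal such as `::1` would shift the indices.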
From 9fbb627efec4b46a24756f41feebb4ba9e6453a7 Mon Sep 17 00:00:00 2001 From: Joe Parks <26990067+jowparks@users.noreply.github.com> Date: Mon, 5 May 2025 11:19:17 -0700 Subject: [PATCH 3/3] pr comment updates, add tests, move semaphore --- src/rate_limit.rs | 122 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 113 insertions(+), 9 deletions(-) diff --git a/src/rate_limit.rs b/src/rate_limit.rs index da1f763..dc7a14a 100644 --- a/src/rate_limit.rs +++ b/src/rate_limit.rs @@ -334,6 +334,15 @@ impl RateLimit for RedisRateLimit { fn try_acquire(self: Arc<Self>, addr: IpAddr) -> Result<Ticket, RateLimitError> { self.clone().start_background_tasks(); + let permit = match self.semaphore.clone().try_acquire_owned() { + Ok(permit) => permit, + Err(_) => { + return Err(RateLimitError::Limit { + reason: "Maximum connection limit reached for this server instance".to_string(), + }); + } + }; + let mut conn = match self.redis_client.get_connection() { Ok(conn) => conn, Err(e) => { @@ -404,15 +413,6 @@ } - let permit = match self.semaphore.clone().try_acquire_owned() { - Ok(permit) => permit, - Err(_) => { - return Err(RateLimitError::Limit { - reason: "Maximum connection limit reached for this server instance".to_string(), - }); - } - }; - let ip_instance_connections: usize = match conn.incr(self.ip_instance_key(&addr), 1) { Ok(count) => count, Err(e) => { @@ -599,6 +599,110 @@ assert!(c4.is_ok()); } + #[tokio::test] + async fn test_global_limits_with_multiple_ips() { + let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); + let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); + let user_3 = IpAddr::from_str("127.0.0.3").unwrap(); + + let rate_limiter = Arc::new(InMemoryRateLimit::new(4, 3)); + + let ticket_1_1 = rate_limiter.clone().try_acquire(user_1).unwrap(); + let ticket_1_2 = rate_limiter.clone().try_acquire(user_1).unwrap(); + + let ticket_2_1 = rate_limiter.clone().try_acquire(user_2).unwrap(); + let ticket_2_2 = rate_limiter.clone().try_acquire(user_2).unwrap(); + + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .semaphore + .available_permits(), + 0 + ); + + // Try user_3 - should fail due to global limit + let result = rate_limiter.clone().try_acquire(user_3); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "Rate Limit Reached: Global limit" + ); + + drop(ticket_1_1); + + let ticket_3_1 = rate_limiter.clone().try_acquire(user_3).unwrap(); + + drop(ticket_1_2); + drop(ticket_2_1); + drop(ticket_2_2); + drop(ticket_3_1); + + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .semaphore + .available_permits(), + 4 + ); + assert_eq!( + rate_limiter.inner.lock().unwrap().active_connections.len(), + 0 + ); + } + + #[tokio::test] + async fn test_per_ip_limits_remain_enforced() { + let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); + let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); + + let rate_limiter = Arc::new(InMemoryRateLimit::new(5, 2)); + + let ticket_1_1 = rate_limiter.clone().try_acquire(user_1).unwrap(); + let ticket_1_2 = rate_limiter.clone().try_acquire(user_1).unwrap(); + + let result = rate_limiter.clone().try_acquire(user_1); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "Rate Limit Reached: IP limit exceeded" + ); + + let ticket_2_1 = rate_limiter.clone().try_acquire(user_2).unwrap(); + drop(ticket_1_1); + + let ticket_1_3 = rate_limiter.clone().try_acquire(user_1).unwrap(); + + let result = rate_limiter.clone().try_acquire(user_1); + assert!(result.is_err()); + assert_eq!( + result.err().unwrap().to_string(), + "Rate Limit Reached: IP limit exceeded" + ); + + drop(ticket_1_2); + drop(ticket_1_3); + drop(ticket_2_1); + + assert_eq!( + rate_limiter + .inner + .lock() + .unwrap() + .semaphore + .available_permits(), + 5 + ); + assert_eq!( + rate_limiter.inner.lock().unwrap().active_connections.len(), + 0 + ); + } + #[tokio::test] #[cfg(all(feature = "integration", test))] async fn test_instance_tracking_and_cleanup() {
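Patch 3's main behavioral change is ordering: the local semaphore permit is taken before any Redis traffic, so a saturated instance fails fast without network round-trips, and an early return on a failed Redis check drops the permit and frees the local slot automatically. A minimal sketch of that pattern with tokio's owned permits (the boolean stands in for the distributed checks; error strings are placeholders):

```rust
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

// Permit-first ordering as in the patch: acquire the local slot, then run
// the (stand-in) distributed checks; any early return releases the permit.
fn try_acquire(sem: Arc<Semaphore>, redis_ok: bool) -> Result<OwnedSemaphorePermit, String> {
    let permit = sem
        .try_acquire_owned()
        .map_err(|_| "Maximum connection limit reached for this server instance".to_string())?;

    if !redis_ok {
        // `permit` is dropped here, returning the slot to the semaphore.
        return Err("Redis operation failed".to_string());
    }

    Ok(permit) // held until the caller's Ticket is dropped
}

#[tokio::main]
async fn main() {
    let sem = Arc::new(Semaphore::new(1));
    let held = try_acquire(sem.clone(), true).unwrap();
    assert!(try_acquire(sem.clone(), true).is_err()); // instance saturated
    drop(held);
    assert!(try_acquire(sem, true).is_ok()); // slot released on drop
}
```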