From a10c0a6bbc2f437449634e7c7dd85909b7aded54 Mon Sep 17 00:00:00 2001 From: Joe Parks <26990067+jowparks@users.noreply.github.com> Date: Tue, 18 Mar 2025 15:59:14 -0700 Subject: [PATCH 1/2] feat(PROTO-945): implement Redis-based rate limiting and update dependencies - Added RedisRateLimit for distributed rate limiting, allowing connection tracking across multiple instances. - Updated Cargo.toml to include new dependencies for Redis and related packages. - Enhanced README with Redis integration instructions and usage examples. - Modified main.rs to support Redis configuration via command-line arguments. - Updated Cargo.lock with new package versions and dependencies. --- Cargo.lock | 140 +++++++++++++++++---- Cargo.toml | 3 + README.md | 29 ++++- src/main.rs | 56 ++++++++- src/rate_limit.rs | 306 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 502 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8c815b8..5a2dbce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -88,6 +88,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "atomic-waker" version = "1.1.2" @@ -354,6 +360,16 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -495,6 +511,8 @@ dependencies = [ "metrics", "metrics-derive", "metrics-exporter-prometheus", + "redis 0.24.0", + "redis-test", "reqwest", "ring", "serde_json", @@ -850,7 +868,7 @@ dependencies = [ "http-body", "hyper", "pin-project-lite", - "socket2", + "socket2 0.5.8", "tokio", "tower-service", "tracing", @@ -1094,12 +1112,6 @@ version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" -[[package]] -name = "linux-raw-sys" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" - [[package]] name = "litemap" version = "0.7.5" @@ -1271,6 +1283,34 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "object" version = "0.36.7" @@ -1519,6 +1559,50 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redis" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd" +dependencies = [ + "combine", + "itoa", + "percent-encoding", + "ryu", + "sha1_smol", + "socket2 0.4.10", + "url", +] + +[[package]] +name = "redis" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8034fb926579ff49d3fe58d288d5dcb580bf11e9bccd33224b45adebf0fd0c23" +dependencies = [ + "arc-swap", + "combine", + "itoa", + "num-bigint", + "percent-encoding", + "ryu", + "sha1_smol", + "socket2 0.5.8", + "url", +] + +[[package]] +name = "redis-test" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27f3dafa5ba24e7ad63516ee55f80418b9fb60b82333c90d0a1df5305cb82066" +dependencies = [ + "rand 0.9.0", + "redis 0.29.1", + "socket2 0.5.8", + "tempfile", +] + [[package]] name = "redox_syscall" version = "0.5.10" @@ -1647,20 +1731,7 @@ dependencies = [ "bitflags", "errno", "libc", - "linux-raw-sys 0.4.15", - "windows-sys 0.59.0", -] - -[[package]] -name = "rustix" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7178faa4b75a30e269c71e61c353ce2748cf3d76f0c44c393f4e60abf49b825" -dependencies = [ - "bitflags", - "errno", - "libc", - "linux-raw-sys 0.9.3", + "linux-raw-sys", "windows-sys 0.59.0", ] @@ -1845,6 +1916,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1890,6 +1967,16 @@ version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +[[package]] +name = "socket2" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "socket2" version = "0.5.8" @@ -1951,14 +2038,15 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.19.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488960f40a3fd53d72c2a29a58722561dee8afdd175bd88e3db4677d7b2ba600" +checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ + "cfg-if", "fastrand", "getrandom 0.3.1", "once_cell", - "rustix 1.0.2", + "rustix", "windows-sys 0.59.0", ] @@ -2035,7 +2123,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.5.8", "tokio-macros", "windows-sys 0.52.0", ] @@ -2410,7 +2498,7 @@ dependencies = [ "either", "home", "once_cell", - "rustix 0.38.44", + "rustix", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 6db846c..8537535 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,9 @@ metrics-derive = "0.1" thiserror = "2.0.11" serde_json = "1.0.138" hostname = "0.4.0" +redis = "0.24.0" +redis-test = "0.9.0" + [dependencies.ring] version = "0.17.12" diff --git a/README.md b/README.md index bd8decc..796ad18 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ You can build and test the project using [Cargo](https://doc.rust-lang.org/cargo # Build the project cargo build -# Run all the tests +# Run all the tests (requires local version of redis to be installed) cargo test --all-features ``` @@ -35,3 +35,30 @@ You can see a full list of parameters by running: `docker run ghcr.io/base/flashblocks-websocket-proxy:master --help` +### Redis Integration + +The proxy supports distributed rate limiting with Redis. This is useful when running multiple instances of the proxy behind a load balancer, as it allows rate limits to be enforced across all instances. + +To enable Redis integration, use the following parameters: + +- `--redis-url` - Redis connection URL (e.g., `redis://localhost:6379`) +- `--redis-key-prefix` - Prefix for Redis keys (default: `flashblocks`) + +Example: + +```bash +docker run ghcr.io/base/flashblocks-websocket-proxy:master \ + --upstream-ws wss://your-sequencer-endpoint \ + --redis-url redis://redis:6379 \ + --global-connections-limit 1000 \ + --per-ip-connections-limit 10 +``` + +When Redis is enabled, the following features are available: + +- Distributed rate limiting across multiple proxy instances +- Connection tracking persists even if the proxy instance restarts +- More accurate global connection limiting in multi-instance deployments + +If the Redis connection fails, the proxy will automatically fall back to in-memory rate limiting. + diff --git a/src/main.rs b/src/main.rs index 222dbdc..0691739 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ mod server; mod subscriber; use crate::metrics::Metrics; -use crate::rate_limit::InMemoryRateLimit; +use crate::rate_limit::{InMemoryRateLimit, RateLimit}; use crate::registry::Registry; use crate::server::Server; use crate::subscriber::WebsocketSubscriber; @@ -16,6 +16,7 @@ use axum::http::Uri; use clap::Parser; use dotenvy::dotenv; use metrics_exporter_prometheus::PrometheusBuilder; +use rate_limit::RedisRateLimit; use std::net::SocketAddr; use std::sync::Arc; use tokio::signal::unix::{signal, SignalKind}; @@ -96,6 +97,21 @@ struct Args { /// Maximum backoff allowed for upstream connections #[arg(long, env, default_value = "20")] subscriber_max_interval: u64, + + #[arg( + long, + env, + help = "Redis URL for distributed rate limiting (e.g., redis://localhost:6379). If not provided, in-memory rate limiting will be used." + )] + redis_url: Option, + + #[arg( + long, + env, + default_value = "flashblocks", + help = "Prefix for Redis keys" + )] + redis_key_prefix: String, } #[tokio::main] @@ -203,10 +219,40 @@ async fn main() { let registry = Registry::new(sender, metrics.clone()); - let rate_limiter = Arc::new(InMemoryRateLimit::new( - args.global_connections_limit, - args.per_ip_connections_limit, - )); + let rate_limiter = match &args.redis_url { + Some(redis_url) => { + info!(message = "Using Redis rate limiter", redis_url = redis_url); + match RedisRateLimit::new( + redis_url, + args.global_connections_limit, + args.per_ip_connections_limit, + &args.redis_key_prefix, + ) { + Ok(limiter) => { + info!(message = "Connected to Redis successfully"); + Arc::new(limiter) as Arc + } + Err(e) => { + error!( + message = + "Failed to connect to Redis, falling back to in-memory rate limiting", + error = e.to_string() + ); + Arc::new(InMemoryRateLimit::new( + args.global_connections_limit, + args.per_ip_connections_limit, + )) as Arc + } + } + } + None => { + info!(message = "Using in-memory rate limiter"); + Arc::new(InMemoryRateLimit::new( + args.global_connections_limit, + args.per_ip_connections_limit, + )) as Arc + } + }; let server = Server::new( args.listen_addr, diff --git a/src/rate_limit.rs b/src/rate_limit.rs index bce754f..6cc6020 100644 --- a/src/rate_limit.rs +++ b/src/rate_limit.rs @@ -6,6 +6,9 @@ use tracing::{debug, warn}; use thiserror::Error; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use redis::{Client, Commands, RedisError}; +use tracing::error; + #[derive(Error, Debug)] pub enum RateLimitError { #[error("Rate Limit Reached: {reason}")] @@ -118,6 +121,189 @@ impl RateLimit for InMemoryRateLimit { } } +pub struct RedisRateLimit { + redis_client: Client, + global_limit: usize, + per_ip_limit: usize, + semaphore: Arc, + key_prefix: String, +} + +impl RedisRateLimit { + pub fn new( + redis_url: &str, + global_limit: usize, + per_ip_limit: usize, + key_prefix: &str, + ) -> Result { + let client = Client::open(redis_url)?; + + Ok(Self { + redis_client: client, + global_limit, + per_ip_limit, + semaphore: Arc::new(Semaphore::new(global_limit)), + key_prefix: key_prefix.to_string(), + }) + } + + /// Get Redis key for tracking global connections + fn global_key(&self) -> String { + format!("{}:global:connections", self.key_prefix) + } + + /// Get Redis key for tracking connections per IP + fn ip_key(&self, addr: &IpAddr) -> String { + format!("{}:ip:{}:connections", self.key_prefix, addr) + } +} + +impl RateLimit for RedisRateLimit { + fn try_acquire(self: Arc, addr: IpAddr) -> Result { + // Get Redis connection first to check current counts + let mut conn = match self.redis_client.get_connection() { + Ok(conn) => conn, + Err(e) => { + error!( + message = "Failed to connect to Redis", + error = e.to_string() + ); + return Err(RateLimitError::Limit { + reason: "Redis connection failed".to_string(), + }); + } + }; + + // Check global count BEFORE incrementing + let global_connections: usize = conn.get(self.global_key()).unwrap_or(0); + + if global_connections >= self.global_limit { + debug!( + message = "Global limit reached", + global_connections = global_connections, + global_limit = self.global_limit + ); + return Err(RateLimitError::Limit { + reason: "Global connection limit reached".to_string(), + }); + } + + // Check IP count BEFORE incrementing + let ip_connections: usize = conn.get(self.ip_key(&addr)).unwrap_or(0); + if ip_connections >= self.per_ip_limit { + return Err(RateLimitError::Limit { + reason: format!("Per-IP connection limit reached for {}", addr), + }); + } + + // Now try to get the local semaphore permit + let permit = match self.semaphore.clone().try_acquire_owned() { + Ok(permit) => permit, + Err(_) => { + return Err(RateLimitError::Limit { + reason: "Global connection limit reached on this instance".to_string(), + }); + } + }; + + // Increment IP counter + let ip_connections: usize = match conn.incr(self.ip_key(&addr), 1) { + Ok(count) => count, + Err(e) => { + error!( + message = "Failed to increment IP counter in Redis", + error = e.to_string() + ); + return Err(RateLimitError::Limit { + reason: "Redis operation failed".to_string(), + }); + } + }; + + // Increment global counter + let global_connections: usize = match conn.incr(self.global_key(), 1) { + Ok(count) => { + debug!( + message = "Incremented global counter", + global_connections = count, + global_limit = self.global_limit + ); + count + } + Err(e) => { + // Roll back IP counter increment + let _: Result<(), _> = conn.decr(self.ip_key(&addr), 1); + + error!( + message = "Failed to increment global counter in Redis", + error = e.to_string() + ); + return Err(RateLimitError::Limit { + reason: "Redis operation failed".to_string(), + }); + } + }; + + debug!( + message = "Connection established", + ip = addr.to_string(), + ip_connections = ip_connections, + global_connections = global_connections + ); + + Ok(Ticket { + addr, + _permit: permit, + rate_limiter: self, + }) + } + + fn release(&self, addr: IpAddr) { + match self.redis_client.get_connection() { + Ok(mut conn) => { + // Decrement IP counter + let ip_connections: Result = conn.decr(self.ip_key(&addr), 1); + + if let Err(ref e) = ip_connections { + error!( + message = "Failed to decrement IP counter in Redis", + error = e.to_string() + ); + } + + // Decrement global counter + let global_connections: Result = conn.decr(self.global_key(), 1); + + if let Ok(count) = global_connections { + debug!( + message = "Decremented global counter on release", + global_connections = count, + global_limit = self.global_limit + ); + } else if let Err(ref e) = global_connections { + error!( + message = "Failed to decrement global counter in Redis", + error = e.to_string() + ); + } + + debug!( + message = "Connection released", + ip = addr.to_string(), + ip_connections = ip_connections.unwrap_or(0), + global_connections = global_connections.unwrap_or(0) + ); + } + Err(e) => { + error!( + message = "Failed to connect to Redis for release", + error = e.to_string() + ); + } + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -245,4 +431,124 @@ mod tests { let c4 = rate_limiter.clone().try_acquire(user_2); assert!(c4.is_ok()); } + + #[tokio::test] + async fn test_redis_rate_limits_with_mock() { + use redis_test::server::RedisServer; + + // Start a mock Redis server + let server = RedisServer::new(); + let client_addr = format!("redis://{}", server.client_addr()); + let client = redis::Client::open(client_addr.as_str()).unwrap(); + + // Wait for the server to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let rate_limiter = Arc::new(RedisRateLimit::new(&client_addr, 5, 1, "test").unwrap()); + + // Test IP addresses + let user_1 = IpAddr::from_str("127.0.0.1").unwrap(); + let user_2 = IpAddr::from_str("127.0.0.2").unwrap(); + + // Test 1: Acquire a connection successfully + let ticket1 = rate_limiter.clone().try_acquire(user_1).unwrap(); + + // Verify Redis state manually + { + let mut conn = client.get_connection().unwrap(); + let global_count: usize = redis::cmd("GET") + .arg("test:global:connections") + .query(&mut conn) + .unwrap(); + let ip_count: usize = redis::cmd("GET") + .arg("test:ip:127.0.0.1:connections") + .query(&mut conn) + .unwrap(); + + assert_eq!(global_count, 1); + assert_eq!(ip_count, 1); + } + + // Test 3: Third connection for same IP should fail (exceeds per-IP limit) + let result = rate_limiter.clone().try_acquire(user_1); + assert!(result.is_err()); + if let Err(RateLimitError::Limit { reason }) = result { + assert!(reason.contains("Per-IP connection limit")); + } else { + panic!("Expected a RateLimitError::Limit"); + } + + // Test 4: Different IP should work + let ticket2 = rate_limiter.clone().try_acquire(user_2).unwrap(); + + // Verify counts after multiple operations + { + let mut conn = client.get_connection().unwrap(); + let global_count: usize = redis::cmd("GET") + .arg("test:global:connections") + .query(&mut conn) + .unwrap(); + let ip1_count: usize = redis::cmd("GET") + .arg("test:ip:127.0.0.1:connections") + .query(&mut conn) + .unwrap(); + let ip2_count: usize = redis::cmd("GET") + .arg("test:ip:127.0.0.2:connections") + .query(&mut conn) + .unwrap(); + + assert_eq!(global_count, 2); + assert_eq!(ip1_count, 1); + assert_eq!(ip2_count, 1); + } + + // Test 5: Test release by dropping tickets + drop(ticket1); + + // Verify that counters were decremented + { + let mut conn = client.get_connection().unwrap(); + let global_count: usize = redis::cmd("GET") + .arg("test:global:connections") + .query(&mut conn) + .unwrap(); + let ip1_count: usize = redis::cmd("GET") + .arg("test:ip:127.0.0.1:connections") + .query(&mut conn) + .unwrap(); + + assert_eq!(global_count, 1); + assert_eq!(ip1_count, 0); + } + + // Test 6: Now we should be able to acquire another connection for user_1 + let _ = rate_limiter.clone().try_acquire(user_1).unwrap(); + + // Test 7: Test global limit + let rate_limiter_small = Arc::new( + RedisRateLimit::new( + &client_addr, + 3, // smaller global limit + 5, // large per-IP limit + "test2", + ) + .unwrap(), + ); + + // Acquire connections up to the global limit + let _t1 = rate_limiter_small.clone().try_acquire(user_1).unwrap(); + let _t2 = rate_limiter_small.clone().try_acquire(user_1).unwrap(); + let _t3 = rate_limiter_small.clone().try_acquire(user_1).unwrap(); + + // This should fail due to global limit + let result = rate_limiter_small.clone().try_acquire(user_1); + assert!(result.is_err()); + if let Err(RateLimitError::Limit { reason }) = result { + assert!(reason.contains("Global connection limit")); + } else { + panic!("Expected a RateLimitError::Limit"); + } + + drop(ticket2); + } } From e57c5408e89ea3622f55909606c330b8a5c687c3 Mon Sep 17 00:00:00 2001 From: Joe Parks <26990067+jowparks@users.noreply.github.com> Date: Wed, 19 Mar 2025 09:41:54 -0700 Subject: [PATCH 2/2] Add redis-tools to github action for test runs --- .github/workflows/ci.yaml | 5 +++++ Cargo.toml | 4 ++-- src/rate_limit.rs | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8739f83..4a87a2a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -16,6 +16,11 @@ jobs: steps: - uses: actions/checkout@v4 + - name: Install Redis for tests + run: | + sudo apt-get update + sudo apt-get install -y redis + - name: Install Rust toolchain uses: dtolnay/rust-toolchain@master with: diff --git a/Cargo.toml b/Cargo.toml index 8537535..720c407 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,11 +26,11 @@ thiserror = "2.0.11" serde_json = "1.0.138" hostname = "0.4.0" redis = "0.24.0" -redis-test = "0.9.0" +redis-test = { version = "0.9.0", optional = true } [dependencies.ring] version = "0.17.12" [features] -integration = [] \ No newline at end of file +integration = ["redis-test"] diff --git a/src/rate_limit.rs b/src/rate_limit.rs index 6cc6020..06bbf66 100644 --- a/src/rate_limit.rs +++ b/src/rate_limit.rs @@ -433,6 +433,7 @@ mod tests { } #[tokio::test] + #[cfg(all(feature = "integration", test))] async fn test_redis_rate_limits_with_mock() { use redis_test::server::RedisServer;