diff --git a/server/src/main.rs b/server/src/main.rs index 4ded95e..c0c656d 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -41,9 +41,17 @@ struct Options { #[arg(short, long, default_value_t = Duration::from_secs(60).into())] ttl: humantime::Duration, + /// Maximum number of entries to store + #[arg(short, long, default_value_t = 10000)] + capacity: usize, + /// Maximum payload size, in bytes #[arg(short, long, default_value = "4KiB")] max_bytes: ByteSize, + + /// Set this flag to test how much memory the server might use with a sessions map fully loaded + #[arg(long)] + mem_check: bool, } #[tokio::main] @@ -59,9 +67,25 @@ async fn main() { .try_into() .expect("Max bytes size too large"); + let sessions = matrix_http_rendezvous::Sessions::new(ttl, options.capacity); + + if options.mem_check { + tracing::info!( + "Filling cache with {capacity} entries of {max_bytes}", + capacity = options.capacity, + max_bytes = options.max_bytes.to_string_as(true) + ); + sessions.fill_for_mem_check(max_bytes).await; + tracing::info!("Done filling, waiting 60 seconds"); + tokio::time::sleep(Duration::from_secs(60)).await; + return; + } + + tokio::spawn(sessions.eviction_task(Duration::from_secs(60))); + let addr = SocketAddr::from((options.address, options.port)); - let service = matrix_http_rendezvous::router(&prefix, ttl, max_bytes); + let service = matrix_http_rendezvous::router(&prefix, sessions, max_bytes); tracing::info!("Listening on http://{addr}"); tracing::info!( diff --git a/src/lib.rs b/src/lib.rs index b7d4592..2d19142 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,8 +18,8 @@ #![allow(clippy::trait_duplication_in_bounds)] use std::{ - collections::HashMap, - ops::Deref, + collections::BTreeMap, + future::Future, sync::Arc, time::{Duration, SystemTime}, }; @@ -112,16 +112,54 @@ impl Session { } #[derive(Clone, Default)] -struct Sessions { - // TODO: is that global lock alright? - inner: Arc>>, - ttl: Duration, +pub struct Sessions { + inner: Arc>>, generator: Arc>, + capacity: usize, + hard_capacity: usize, + ttl: Duration, +} + +fn evict(sessions: &mut BTreeMap, capacity: usize) { + // NOTE: eviction is based on the fact that ULIDs are monotonically increasing, by evictin the + // keys at the head of the map + + // List of keys to evict + let keys: Vec = sessions + .keys() + .take(sessions.len() - capacity) + .copied() + .collect(); + + // Now evict the keys + for key in keys { + sessions.remove(&key); + } } impl Sessions { + #[must_use] + pub fn new(ttl: Duration, capacity: usize) -> Self { + Self { + inner: Arc::new(RwLock::new(BTreeMap::new())), + generator: Arc::new(Mutex::new(ulid::Generator::new())), + ttl, + capacity, + hard_capacity: capacity * 2, + } + } + async fn insert(self, id: Ulid, session: Session, ttl: Duration) { - self.inner.write().await.insert(id, session); + { + let mut sessions = self.inner.write().await; + sessions.insert(id, session); + // When inserting, we check if we hit the 'hard' capacity, so that we never go over + // that capacity + if sessions.len() >= self.hard_capacity { + evict(&mut sessions, self.capacity); + } + } + // TODO: cancel this task when an item gets deleted tokio::task::spawn(async move { tokio::time::sleep(ttl).await; @@ -138,13 +176,58 @@ impl Sessions { // millisecond, which is very unlikely .expect("Failed to generate random ID") } -} -impl Deref for Sessions { - type Target = RwLock>; + /// A loop which evicts keys if the capacity is reached + pub fn eviction_task( + &self, + interval: Duration, + ) -> impl Future + Send + Sync + 'static { + let this = self.clone(); + async move { + let mut interval = tokio::time::interval(interval); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + interval.tick().await; + this.evict().await; + } + } + } + + async fn evict(&self) { + if self.inner.read().await.len() > self.capacity { + let mut sessions = self.inner.write().await; + evict(&mut sessions, self.capacity); + } + } - fn deref(&self) -> &Self::Target { - &self.inner + /// Fill the sessions storage to check how much memory it might use on max capacity + /// + /// # Panics + /// + /// It panics if the session storage is not empty + pub async fn fill_for_mem_check(&self, entry_size: usize) { + let mut sessions = self.inner.write().await; + let mut generator = self.generator.lock().await; + assert!(sessions.is_empty()); + + let data: Vec = std::iter::repeat(42).take(entry_size).collect(); + sessions.extend((0..self.capacity).map(|_| { + let data = Bytes::from(data.clone()); + let id = generator.generate().unwrap(); + let session = Session::new(data, mime::APPLICATION_OCTET_STREAM, self.ttl); + (id, session) + })); + + // Start the deletion tasks for all the sessions + let ttl = self.ttl; + for &key in sessions.keys() { + let inner = self.inner.clone(); + tokio::task::spawn(async move { + tokio::time::sleep(ttl).await; + inner.write().await.remove(&key); + }); + } } } @@ -185,7 +268,7 @@ async fn new_session( } async fn delete_session(State(sessions): State, Path(id): Path) -> StatusCode { - if sessions.write().await.remove(&id).is_some() { + if sessions.inner.write().await.remove(&id).is_some() { StatusCode::NO_CONTENT } else { StatusCode::NOT_FOUND @@ -199,7 +282,7 @@ async fn update_session( if_match: Option>, payload: Bytes, ) -> Response { - if let Some(session) = sessions.write().await.get_mut(&id) { + if let Some(session) = sessions.inner.write().await.get_mut(&id) { if let Some(TypedHeader(if_match)) = if_match { if !if_match.precondition_passes(&session.etag()) { return (StatusCode::PRECONDITION_FAILED, session.typed_headers()).into_response(); @@ -221,7 +304,7 @@ async fn get_session( Path(id): Path, if_none_match: Option>, ) -> Response { - let sessions = sessions.read().await; + let sessions = sessions.inner.read().await; let session = if let Some(session) = sessions.get(&id) { session } else { @@ -244,18 +327,12 @@ async fn get_session( } #[must_use] -pub fn router(prefix: &str, ttl: Duration, max_bytes: usize) -> Router<(), B> +pub fn router(prefix: &str, sessions: Sessions, max_bytes: usize) -> Router<(), B> where B: HttpBody + Send + 'static, ::Data: Send, ::Error: std::error::Error + Send + Sync, { - let sessions = Sessions { - inner: Arc::default(), - ttl, - generator: Arc::default(), - }; - let state = AppState::new(sessions); let router = Router::with_state(state) .route("/", post(new_session)) @@ -360,7 +437,8 @@ mod tests { #[tokio::test] async fn test_post_and_get() { let ttl = Duration::from_secs(60); - let app = router("/", ttl, 4096); + let sessions = Sessions::new(ttl, 1024); + let app = router("/", sessions, 4096); let body = r#"{"hello": "world"}"#.to_string(); let request = Request::post("/") @@ -399,7 +477,8 @@ mod tests { #[tokio::test] async fn test_monotonically_increasing() { let ttl = Duration::from_secs(60); - let app = router("/", ttl, 4096); + let sessions = Sessions::new(ttl, 1024); + let app = router("/", sessions.clone(), 4096); // Prepare a thousand requests let mut requests = Vec::with_capacity(1000); @@ -433,6 +512,7 @@ mod tests { #[tokio::test] async fn test_post_max_bytes() { let ttl = Duration::from_secs(60); + let sessions = Sessions::new(ttl, 1024); let body = br#"{"hello": "world"}"#; @@ -442,7 +522,10 @@ mod tests { .header(CONTENT_TYPE, "application/json") .body(slow_body) .unwrap(); - let response = router("/", ttl, 8).oneshot(request).await.unwrap(); + let response = router("/", sessions.clone(), 8) + .oneshot(request) + .await + .unwrap(); assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); // It works with exactly the right size @@ -451,7 +534,10 @@ mod tests { .header(CONTENT_TYPE, "application/json") .body(slow_body) .unwrap(); - let response = router("/", ttl, body.len()).oneshot(request).await.unwrap(); + let response = router("/", sessions.clone(), body.len()) + .oneshot(request) + .await + .unwrap(); assert_eq!(response.status(), StatusCode::CREATED); // It doesn't work even if the size is one too short @@ -460,7 +546,7 @@ mod tests { .header(CONTENT_TYPE, "application/json") .body(slow_body) .unwrap(); - let response = router("/", ttl, body.len() - 1) + let response = router("/", sessions.clone(), body.len() - 1) .oneshot(request) .await .unwrap(); @@ -470,7 +556,7 @@ mod tests { let body = vec![42; 4 * 1024 * 1024].into_boxed_slice(); let slow_body = SlowBody::from_bytes(Bytes::from(body)).with_chunk_size(128); let request = Request::post("/").body(slow_body).unwrap(); - let response = router("/", ttl, 4 * 1024 * 1024) + let response = router("/", sessions.clone(), 4 * 1024 * 1024) .oneshot(request) .await .unwrap(); @@ -480,7 +566,7 @@ mod tests { let body = vec![42; 4 * 1024 * 1024 + 1].into_boxed_slice(); let slow_body = SlowBody::from_bytes(Bytes::from(body)).with_chunk_size(128); let request = Request::post("/").body(slow_body).unwrap(); - let response = router("/", ttl, 4 * 1024 * 1024) + let response = router("/", sessions.clone(), 4 * 1024 * 1024) .oneshot(request) .await .unwrap(); @@ -490,7 +576,8 @@ mod tests { #[tokio::test] async fn test_post_and_get_if_none_match() { let ttl = Duration::from_secs(60); - let app = router("/", ttl, 4096); + let sessions = Sessions::new(ttl, 1024); + let app = router("/", sessions, 4096); let body = r#"{"hello": "world"}"#.to_string(); let request = Request::post("/") @@ -517,7 +604,8 @@ mod tests { #[tokio::test] async fn test_post_and_put() { let ttl = Duration::from_secs(60); - let app = router("/", ttl, 4096); + let sessions = Sessions::new(ttl, 1024); + let app = router("/", sessions, 4096); let body = r#"{"hello": "world"}"#.to_string(); let request = Request::post("/") @@ -544,7 +632,8 @@ mod tests { #[tokio::test] async fn test_post_and_put_if_match() { let ttl = Duration::from_secs(60); - let app = router("/", ttl, 4096); + let sessions = Sessions::new(ttl, 1024); + let app = router("/", sessions, 4096); let body = r#"{"hello": "world"}"#.to_string(); let request = Request::post("/") @@ -582,7 +671,8 @@ mod tests { #[tokio::test] async fn test_post_delete_and_get() { let ttl = Duration::from_secs(60); - let app = router("/", ttl, 4096); + let sessions = Sessions::new(ttl, 1024); + let app = router("/", sessions, 4096); let body = r#"{"hello": "world"}"#.to_string(); let request = Request::post("/") @@ -608,4 +698,70 @@ mod tests { let response = app.oneshot(request).await.unwrap(); assert_eq!(response.status(), StatusCode::NOT_FOUND); } + + #[tokio::test] + async fn test_eviction() { + let ttl = Duration::from_secs(60); + let sessions = Sessions::new(ttl, 2); + let app = router("/", sessions.clone(), 4096); + + let request = Request::post("/").body(String::new()).unwrap(); + let response = app.clone().oneshot(request).await.unwrap(); + let first_location = response.headers().get(LOCATION).unwrap().to_str().unwrap(); + + let request = Request::post("/").body(String::new()).unwrap(); + let response = app.clone().oneshot(request).await.unwrap(); + let second_location = response.headers().get(LOCATION).unwrap().to_str().unwrap(); + + sessions.evict().await; + + // Both entries are still there + let url = format!("/{first_location}"); + let request = Request::get(&url).body(String::new()).unwrap(); + let response = app.clone().oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let url = format!("/{second_location}"); + let request = Request::get(&url).body(String::new()).unwrap(); + let response = app.clone().oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + // Sending a third request + let request = Request::post("/").body(String::new()).unwrap(); + app.clone().oneshot(request).await.unwrap(); + + // First entry should still be there, there was no eviction yet because we didn't hit hard + // capacity + let url = format!("/{first_location}"); + let request = Request::get(&url).body(String::new()).unwrap(); + let response = app.clone().oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + sessions.evict().await; + + // First entry should be gone because of the eviction + let url = format!("/{first_location}"); + let request = Request::get(&url).body(String::new()).unwrap(); + let response = app.clone().oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + // Second entry should still be there + let url = format!("/{second_location}"); + let request = Request::get(&url).body(String::new()).unwrap(); + let response = app.clone().oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + // Sending two other requests, so we hit hard capacity + let request = Request::post("/").body(String::new()).unwrap(); + app.clone().oneshot(request).await.unwrap(); + let request = Request::post("/").body(String::new()).unwrap(); + app.clone().oneshot(request).await.unwrap(); + + // Second entry should be gone, because we hit hard capacity, even though we didn't had the + // eviction triggered + let url = format!("/{second_location}"); + let request = Request::get(&url).body(String::new()).unwrap(); + let response = app.clone().oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } } diff --git a/synapse/Cargo.toml b/synapse/Cargo.toml index b0f3ab8..dbd854a 100644 --- a/synapse/Cargo.toml +++ b/synapse/Cargo.toml @@ -23,6 +23,7 @@ pyo3 = { version = "0.17.2", features = ["extension-module", "abi3-py37", "anyho pyo3-log = "0.7.0" pyo3-matrix-synapse-module = "0.1.1" serde = { version = "1.0.145", features = ["derive"] } +tokio = "1.21.2" tower = { version = "0.4.13", features = ["util"] } tracing = { version = "0.1.37", features = ["log", "log-always"] } diff --git a/synapse/README.rst b/synapse/README.rst index 886c058..3c071aa 100644 --- a/synapse/README.rst +++ b/synapse/README.rst @@ -33,10 +33,11 @@ Usage Configuration options --------------------- -Apart from the `prefix` the following config options are available: +Apart from the ``prefix`` the following config options are available: -- `ttl`: The time-to-live of the rendezvous session. Defaults to 60s. -- `max_bytes`: The maximum number of bytes that can be sent in a single request. Defaults to 4096 bytes. +- ``ttl``: The time-to-live of the rendezvous session. Defaults to 60s. +- ``max_bytes``: The maximum number of bytes that can be sent in a single request. Defaults to 4096 bytes. +- ``max_entries``: The maximum number of entries to keep. Defaults to 10 000. An example configuration setting these and a custom prefix would like:: @@ -46,6 +47,14 @@ An example configuration setting these and a custom prefix would like:: prefix: /rendezvous ttl: 15s max_bytes: 10KiB + max_entries: 50000 experimental_features: msc3886_endpoint: /rendezvous # this should match above + +^^^^^^^^^^^^ +Memory usage +^^^^^^^^^^^^ + +``max_entries`` and ``max_bytes`` allow to tune how much memory the module may take. +There is a constant overhead of approximately 1KiB per entry, so with the default config (``max_bytes = 4KiB``, ``max_entries = 10000``), the maximum theorical memory footprint of the module is ``(4KiB + ~1KiB) * 10000 ~= 50MiB``. diff --git a/synapse/src/lib.rs b/synapse/src/lib.rs index 50aa539..80679fd 100644 --- a/synapse/src/lib.rs +++ b/synapse/src/lib.rs @@ -36,6 +36,10 @@ fn default_max_bytes() -> ByteSize { ByteSize::kib(4) } +fn default_max_entries() -> usize { + 10_000 +} + #[pyclass] #[derive(Deserialize)] struct Config { @@ -46,6 +50,9 @@ struct Config { #[serde(default = "default_max_bytes")] max_bytes: ByteSize, + + #[serde(default = "default_max_entries")] + max_entries: usize, } #[pyclass] @@ -68,7 +75,10 @@ impl SynapseRendezvousModule { .try_into() .context("Could not convert max_bytes from config")?; - let service = matrix_http_rendezvous::router(&config.prefix, config.ttl, max_bytes) + let sessions = matrix_http_rendezvous::Sessions::new(config.ttl, config.max_entries); + tokio::spawn(sessions.eviction_task(Duration::from_secs(60))); + + let service = matrix_http_rendezvous::router(&config.prefix, sessions, max_bytes) .map_response(|res| res.map(|b| b.map_err(|e| anyhow!(e)))); module_api.register_web_service(&config.prefix, service)?; Ok(Self)