diff --git a/Cargo.toml b/Cargo.toml index e554fa6..5a3452a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,9 +11,12 @@ repository = "https://github.com/maxcountryman/axum-sessions" documentation = "https://docs.rs/axum-sessions" [dependencies] -async-session = "3.0.0" +async-session = { git = "https://github.com/http-rs/async-session", rev = "35cb0998f91b81b133c3314414adae4019a62741", default-features = false } +base64 = "0.21.0" futures = "0.3.21" +hmac = { version = "0.12.1", features = ["std"] } http-body = "0.4.5" +sha2 = "0.10.6" tower = "0.4.12" tracing = "0.1" @@ -34,6 +37,7 @@ features = ["sync"] http = "0.2.8" hyper = "0.14.19" serde = "1.0.147" +serde_json = "1.0.106" [dev-dependencies.rand] version = "0.8.5" @@ -43,3 +47,7 @@ features = ["min_const_gen"] version = "1.20.1" default-features = false features = ["macros", "rt-multi-thread"] + +[dev-dependencies.async-session-memory-store] +git = "https://github.com/http-rs/async-session" +rev = "35cb0998f91b81b133c3314414adae4019a62741" diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 6696429..3e38e3f 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -2,3 +2,7 @@ members = ["*"] exclude = ["target"] resolver = "2" + +[workspace.dependencies.async-session-memory-store] +git = "https://github.com/http-rs/async-session" +rev = "35cb0998f91b81b133c3314414adae4019a62741" diff --git a/examples/async-sqlx-session/Cargo.toml b/examples/async-sqlx-session/Cargo.toml index 9f052be..3a2efc3 100644 --- a/examples/async-sqlx-session/Cargo.toml +++ b/examples/async-sqlx-session/Cargo.toml @@ -9,7 +9,9 @@ axum = "0.6.0" axum-sessions = { path = "../../" } [dependencies.async-sqlx-session] -version = "0.4.0" +# version = "0.4.0" +git = "https://github.com/sbihel/async-sqlx-session" +rev = "9ed73b641cb5e5c0c08b9d29dd723b3c8e57993d" default-features = false features = ["sqlite"] @@ -18,8 +20,7 @@ version = "0.8.5" features = ["min_const_gen"] [dependencies.sqlx] -version = "0.5.13" -default-features = false +version = "0.7.1" features = ["runtime-tokio-rustls", "sqlite"] [dependencies.tokio] diff --git a/examples/async-sqlx-session/src/main.rs b/examples/async-sqlx-session/src/main.rs index 3e10433..3adb122 100644 --- a/examples/async-sqlx-session/src/main.rs +++ b/examples/async-sqlx-session/src/main.rs @@ -6,10 +6,7 @@ use async_sqlx_session::SqliteSessionStore; use axum::{routing::get, Router}; -use axum_sessions::{ - extractors::{ReadableSession, WritableSession}, - SessionLayer, -}; +use axum_sessions::{extractors::Session, SessionLayer}; use rand::Rng; #[tokio::main] @@ -24,14 +21,14 @@ async fn main() { let secret = rand::thread_rng().gen::<[u8; 128]>(); let session_layer = SessionLayer::new(store, &secret); - async fn increment_count_handler(mut session: WritableSession) { + async fn increment_count_handler(mut session: Session) { let previous: usize = session.get("counter").unwrap_or_default(); session .insert("counter", previous + 1) .expect("Could not store counter."); } - async fn handler(session: ReadableSession) -> String { + async fn handler(session: Session) -> String { format!( "Counter: {}", session.get::("counter").unwrap_or_default() diff --git a/examples/counter/Cargo.toml b/examples/counter/Cargo.toml index 93624e3..4d34bed 100644 --- a/examples/counter/Cargo.toml +++ b/examples/counter/Cargo.toml @@ -7,6 +7,7 @@ publish = false [dependencies] axum = "0.6.0" axum-sessions = { path = "../../" } +async-session-memory-store = { workspace = true } [dependencies.rand] version = "0.8.5" diff --git a/examples/counter/src/main.rs b/examples/counter/src/main.rs index 41d7749..e26412e 100644 --- a/examples/counter/src/main.rs +++ b/examples/counter/src/main.rs @@ -4,12 +4,9 @@ //! cd examples && cargo run -p example-counter //! ``` +use async_session_memory_store::MemoryStore; use axum::{response::IntoResponse, routing::get, Router}; -use axum_sessions::{ - async_session::MemoryStore, - extractors::{ReadableSession, WritableSession}, - SessionLayer, -}; +use axum_sessions::{extractors::Session, SessionLayer}; use rand::Rng; #[tokio::main] @@ -18,7 +15,7 @@ async fn main() { let secret = rand::thread_rng().gen::<[u8; 128]>(); let session_layer = SessionLayer::new(store, &secret).with_secure(false); - async fn display_handler(session: ReadableSession) -> impl IntoResponse { + async fn display_handler(session: Session) -> impl IntoResponse { let mut count = 0; count = session.get("count").unwrap_or(count); format!( @@ -27,14 +24,14 @@ async fn main() { ) } - async fn increment_handler(mut session: WritableSession) -> impl IntoResponse { + async fn increment_handler(mut session: Session) -> impl IntoResponse { let mut count = 1; count = session.get("count").map(|n: i32| n + 1).unwrap_or(count); session.insert("count", count).unwrap(); format!("Count is: {}", count) } - async fn reset_handler(mut session: WritableSession) -> impl IntoResponse { + async fn reset_handler(mut session: Session) -> impl IntoResponse { session.destroy(); "Count reset" } diff --git a/examples/regenerate/Cargo.toml b/examples/regenerate/Cargo.toml index 91f3ca1..1736f24 100644 --- a/examples/regenerate/Cargo.toml +++ b/examples/regenerate/Cargo.toml @@ -7,6 +7,7 @@ publish = false [dependencies] axum = "0.6.0" axum-sessions = { path = "../../" } +async-session-memory-store = { workspace = true } [dependencies.rand] version = "0.8.5" diff --git a/examples/regenerate/src/main.rs b/examples/regenerate/src/main.rs index 3306624..c0e08c7 100644 --- a/examples/regenerate/src/main.rs +++ b/examples/regenerate/src/main.rs @@ -4,12 +4,9 @@ //! cd examples && cargo run -p example-regenerate //! ``` +use async_session_memory_store::MemoryStore; use axum::{routing::get, Router}; -use axum_sessions::{ - async_session::MemoryStore, - extractors::{ReadableSession, WritableSession}, - SessionLayer, -}; +use axum_sessions::{extractors::Session, SessionLayer}; use rand::Rng; #[tokio::main] @@ -18,19 +15,19 @@ async fn main() { let secret = rand::thread_rng().gen::<[u8; 128]>(); let session_layer = SessionLayer::new(store, &secret); - async fn regenerate_handler(mut session: WritableSession) { + async fn regenerate_handler(mut session: Session) { // NB: This DOES NOT update the store, meaning that both sessions will still be // found. session.regenerate(); } - async fn insert_handler(mut session: WritableSession) { + async fn insert_handler(mut session: Session) { session .insert("foo", 42) .expect("Could not store the answer."); } - async fn handler(session: ReadableSession) -> String { + async fn handler(session: Session) -> String { session .get::("foo") .map(|answer| format!("{}", answer)) diff --git a/examples/signin/Cargo.toml b/examples/signin/Cargo.toml index 8387aee..a1b6e5b 100644 --- a/examples/signin/Cargo.toml +++ b/examples/signin/Cargo.toml @@ -7,6 +7,7 @@ publish = false [dependencies] axum = "0.6.0" axum-sessions = { path = "../../" } +async-session-memory-store = { workspace = true } [dependencies.rand] version = "0.8.5" diff --git a/examples/signin/src/main.rs b/examples/signin/src/main.rs index 5d0d00a..15fbd49 100644 --- a/examples/signin/src/main.rs +++ b/examples/signin/src/main.rs @@ -4,12 +4,9 @@ //! cd examples && cargo run -p example-signin //! ``` +use async_session_memory_store::MemoryStore; use axum::{routing::get, Router}; -use axum_sessions::{ - async_session::MemoryStore, - extractors::{ReadableSession, WritableSession}, - SessionLayer, -}; +use axum_sessions::{extractors::Session, SessionLayer}; use rand::Rng; #[tokio::main] @@ -18,17 +15,17 @@ async fn main() { let secret = rand::thread_rng().gen::<[u8; 128]>(); let session_layer = SessionLayer::new(store, &secret); - async fn signin_handler(mut session: WritableSession) { + async fn signin_handler(mut session: Session) { session .insert("signed_in", true) .expect("Could not sign in."); } - async fn signout_handler(mut session: WritableSession) { + async fn signout_handler(mut session: Session) { session.destroy(); } - async fn protected_handler(session: ReadableSession) -> &'static str { + async fn protected_handler(session: Session) -> &'static str { if session.get::("signed_in").unwrap_or(false) { "Shh, it's secret!" } else { diff --git a/src/extractors.rs b/src/extractors.rs index 9b9a547..85fe0ca 100644 --- a/src/extractors.rs +++ b/src/extractors.rs @@ -3,78 +3,46 @@ use std::ops::{Deref, DerefMut}; use axum::{async_trait, extract::FromRequestParts, http::request::Parts, Extension}; -use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard}; +use tokio::sync::OwnedMutexGuard; use crate::SessionHandle; /// An extractor which provides a readable session. Sessions may have many /// readers. #[derive(Debug)] -pub struct ReadableSession { - session: OwnedRwLockReadGuard, +pub struct Session { + session_guard: OwnedMutexGuard, } -impl Deref for ReadableSession { - type Target = OwnedRwLockReadGuard; +impl Deref for Session { + type Target = OwnedMutexGuard; fn deref(&self) -> &Self::Target { - &self.session + &self.session_guard } } -#[async_trait] -impl FromRequestParts for ReadableSession -where - S: Send + Sync, -{ - type Rejection = std::convert::Infallible; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let Extension(session_handle): Extension = - Extension::from_request_parts(parts, state) - .await - .expect("Session extension missing. Is the session layer installed?"); - let session = session_handle.read_owned().await; - - Ok(Self { session }) - } -} - -/// An extractor which provides a writable session. Sessions may have only one -/// writer. -#[derive(Debug)] -pub struct WritableSession { - session: OwnedRwLockWriteGuard, -} - -impl Deref for WritableSession { - type Target = OwnedRwLockWriteGuard; - - fn deref(&self) -> &Self::Target { - &self.session - } -} - -impl DerefMut for WritableSession { +impl DerefMut for Session { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.session + &mut self.session_guard } } #[async_trait] -impl FromRequestParts for WritableSession +impl FromRequestParts for Session where - S: Send + Sync, + S: Send + Sync + Clone, { type Rejection = std::convert::Infallible; - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let Extension(session_handle): Extension = - Extension::from_request_parts(parts, state) - .await - .expect("Session extension missing. Is the session layer installed?"); - let session = session_handle.write_owned().await; + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + use axum::RequestPartsExt; + let Extension(session) = parts + .extract::>() + .await + .expect("Session extension missing. Is the session layer installed?"); - Ok(Self { session }) + let session_guard = session.lock_owned().await; + Ok(Self { session_guard }) } } diff --git a/src/lib.rs b/src/lib.rs index ac3285a..b79a533 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,19 +6,17 @@ //! when they're not found or are otherwise invalid. When a valid, known cookie //! is received in a request, the session is hydrated from this cookie. The //! middleware provides sessions via [`SessionHandle`]. Handlers use the -//! [`ReadableSession`](crate::extractors::ReadableSession) and -//! [`WritableSession`](crate::extractors::WritableSession) extractors to read -//! from and write to sessions respectively. +//! [`Session`](crate::extractors::Session) extractor to read from and write to +//! sessions respectively. //! //! # Example //! //! Using the middleware with axum is straightforward: //! //! ```rust,no_run +//! use async_session_memory_store::MemoryStore; //! use axum::{routing::get, Router}; -//! use axum_sessions::{ -//! async_session::MemoryStore, extractors::WritableSession, PersistencePolicy, SessionLayer, -//! }; +//! use axum_sessions::{extractors::Session, PersistencePolicy, SessionLayer}; //! //! #[tokio::main] //! async fn main() { @@ -26,7 +24,7 @@ //! let secret = b"..."; // MUST be at least 64 bytes! //! let session_layer = SessionLayer::new(store, secret); //! -//! async fn handler(mut session: WritableSession) { +//! async fn handler(mut session: Session) { //! session //! .insert("foo", 42) //! .expect("Could not store the answer."); @@ -47,8 +45,9 @@ //! ```rust //! use std::convert::Infallible; //! +//! use async_session_memory_store::MemoryStore; //! use axum::http::header::SET_COOKIE; -//! use axum_sessions::{async_session::MemoryStore, SessionHandle, SessionLayer}; +//! use axum_sessions::{SessionHandle, SessionLayer}; //! use http::{Request, Response}; //! use hyper::Body; //! use rand::Rng; @@ -56,7 +55,7 @@ //! //! async fn handle(request: Request) -> Result, Infallible> { //! let session_handle = request.extensions().get::().unwrap(); -//! let session = session_handle.read().await; +//! let session = session_handle.lock().await; //! // Use the session as you'd like. //! //! Ok(Response::new(Body::empty())) diff --git a/src/session.rs b/src/session.rs index ac48999..941e2cb 100644 --- a/src/session.rs +++ b/src/session.rs @@ -7,12 +7,7 @@ use std::{ time::Duration, }; -use async_session::{ - base64, - hmac::{Hmac, Mac, NewMac}, - sha2::Sha256, - SessionStore, -}; +use async_session::SessionStore; use axum::{ http::{ header::{HeaderValue, COOKIE, SET_COOKIE}, @@ -21,8 +16,11 @@ use axum::{ response::Response, }; use axum_extra::extract::cookie::{Cookie, Key, SameSite}; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; use futures::future::BoxFuture; -use tokio::sync::RwLock; +use hmac::{Hmac, Mac}; +use sha2::{digest::generic_array::GenericArray, Sha256}; +use tokio::sync::Mutex; use tower::{Layer, Service}; const BASE64_DIGEST_LEN: usize = 44; @@ -30,13 +28,11 @@ const BASE64_DIGEST_LEN: usize = 44; /// A type alias which provides a handle to the underlying session. /// /// This is provided via [`http::Extensions`](axum::http::Extensions). Most -/// applications will use the -/// [`ReadableSession`](crate::extractors::ReadableSession) and -/// [`WritableSession`](crate::extractors::WritableSession) extractors rather -/// than using the handle directly. A notable exception is when using this -/// library as a generic Tower middleware: such use cases will consume the -/// handle directly. -pub type SessionHandle = Arc>; +/// applications will use the [`Session`](crate::extractors::Session) +/// extractor rather than using the handle directly. A notable exception is +/// when using this library as a generic Tower middleware: such use cases will +/// consume the handle directly. +pub type SessionHandle = Arc>; /// Controls how the session data is persisted and created. #[derive(Clone)] @@ -66,7 +62,10 @@ pub struct SessionLayer { key: Key, } -impl SessionLayer { +impl SessionLayer +where + Store: SessionStore + Clone + Send + Sync + 'static, +{ /// Creates a layer which will attach a [`SessionHandle`] to requests via an /// extension. This session is derived from a cryptographically signed /// cookie. When the client sends a valid, known cookie then the session is @@ -86,7 +85,8 @@ impl SessionLayer { /// of your application: /// /// ```rust - /// # use axum_sessions::{PersistencePolicy, SessionLayer, async_session::MemoryStore, SameSite}; + /// # use axum_sessions::{PersistencePolicy, SessionLayer, SameSite}; + /// # use async_session_memory_store::MemoryStore; /// # use std::time::Duration; /// SessionLayer::new( /// MemoryStore::new(), @@ -183,11 +183,11 @@ impl SessionLayer { async fn load_or_create(&self, cookie_value: Option) -> SessionHandle { let session = match cookie_value { - Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(), + Some(cookie_value) => self.store.load_session(&cookie_value).await.ok().flatten(), None => None, }; - Arc::new(RwLock::new( + Arc::new(Mutex::new( session .and_then(async_session::Session::validate) .unwrap_or_default(), @@ -243,7 +243,7 @@ impl SessionLayer { mac.update(cookie.value().as_bytes()); // Cookie's new value is [MAC | original-value]. - let mut new_value = base64::encode(mac.finalize().into_bytes()); + let mut new_value = BASE64.encode(mac.finalize().into_bytes()); new_value.push_str(cookie.value()); cookie.set_value(new_value); } @@ -260,18 +260,21 @@ impl SessionLayer { // Split [MAC | original-value] into its two parts. let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN); - let digest = base64::decode(digest_str).map_err(|_| "bad base64 digest")?; + let digest = BASE64.decode(digest_str).map_err(|_| "bad base64 digest")?; // Perform the verification. let mut mac = Hmac::::new_from_slice(self.key.signing()).expect("good key"); mac.update(value.as_bytes()); - mac.verify(&digest) + mac.verify(GenericArray::from_slice(&digest)) .map(|_| value.to_string()) .map_err(|_| "value did not verify") } } -impl Layer for SessionLayer { +impl Layer for SessionLayer +where + Store: SessionStore + Clone + Send + Sync + 'static, +{ type Service = Session; fn layer(&self, inner: Inner) -> Self::Service { @@ -289,13 +292,13 @@ pub struct Session { layer: SessionLayer, } -impl Service> - for Session +impl Service> for Session where Inner: Service, Response = Response> + Clone + Send + 'static, ResBody: Send + 'static, ReqBody: Send + 'static, Inner::Future: Send + 'static, + Store: SessionStore + Clone + Send + Sync + 'static, { type Response = Inner::Response; type Error = Inner::Error; @@ -326,28 +329,23 @@ where let mut inner = std::mem::replace(&mut self.inner, inner); Box::pin(async move { let session_handle = session_layer.load_or_create(cookie_value.clone()).await; + let mut session = session_handle.lock().await; - let mut session = session_handle.write().await; if let Some(ttl) = session_layer.session_ttl { (*session).expire_in(ttl); } drop(session); request.extensions_mut().insert(session_handle.clone()); + let mut response = inner.call(request).await?; - let session = session_handle.read().await; + let mut session = session_handle.lock().await; let (session_is_destroyed, session_data_changed) = (session.is_destroyed(), session.data_changed()); - drop(session); - // Pull out the session so we can pass it to the store without `Clone` blowing - // away the `cookie_value`. - let session = RwLock::into_inner( - Arc::try_unwrap(session_handle).expect("Session handle still has owners."), - ); if session_is_destroyed { - if let Err(e) = session_layer.store.destroy_session(session).await { + if let Err(e) = session_layer.store.destroy_session(&mut session).await { tracing::error!("Failed to destroy session: {:?}", e); *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; } @@ -366,7 +364,7 @@ where // - If we use the `ChangedOnly` policy, only // `session.data_changed()` should trigger this branch. } else if session_layer.should_store(&cookie_value, session_data_changed) { - match session_layer.store.store_session(session).await { + match session_layer.store.store_session(&mut session).await { Ok(Some(cookie_value)) => { let cookie = session_layer.build_cookie(cookie_value); response.headers_mut().append( @@ -391,10 +389,7 @@ where #[cfg(test)] mod tests { - use async_session::{ - serde::{Deserialize, Serialize}, - serde_json, - }; + use async_session_memory_store::MemoryStore; use axum::http::{Request, Response}; use http::{ header::{COOKIE, SET_COOKIE}, @@ -402,10 +397,11 @@ mod tests { }; use hyper::Body; use rand::Rng; + use serde::{Deserialize, Serialize}; use tower::{BoxError, Service, ServiceBuilder, ServiceExt}; use super::PersistencePolicy; - use crate::{async_session::MemoryStore, SessionHandle, SessionLayer}; + use crate::{SessionHandle, SessionLayer}; #[derive(Deserialize, Serialize, PartialEq, Debug)] struct Counter { @@ -702,7 +698,7 @@ mod tests { async fn echo_read_session(req: Request) -> Result, BoxError> { { let session_handle = req.extensions().get::().unwrap(); - let session = session_handle.write().await; + let session = session_handle.lock().await; let _ = session.get::("signed_in").unwrap_or_default(); } Ok(Response::new(req.into_body())) @@ -711,7 +707,7 @@ mod tests { async fn echo_with_session_change(req: Request) -> Result, BoxError> { { let session_handle = req.extensions().get::().unwrap(); - let mut session = session_handle.write().await; + let mut session = session_handle.lock().await; session.insert("signed_in", true).unwrap(); } Ok(Response::new(req.into_body())) @@ -721,7 +717,7 @@ mod tests { // Destroy the session if we received a session cookie. if req.headers().get(COOKIE).is_some() { let session_handle = req.extensions().get::().unwrap(); - let mut session = session_handle.write().await; + let mut session = session_handle.lock().await; session.destroy(); } @@ -733,7 +729,7 @@ mod tests { { let session_handle = req.extensions().get::().unwrap(); - let mut session = session_handle.write().await; + let mut session = session_handle.lock().await; counter = session .get("counter") .map(|count: i32| count + 1)