diff --git a/CHANGELOG.md b/CHANGELOG.md index 5280d68..be8a741 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Unreleased +# 0.2.0 + +**Breaking Changes** + +- Rework crate into a middleware + +This changes the structure of the crate such that it is now a middleware in addition to being an extractor. Doing so allows us to improve the ergonomics of the API such that calling `save` and awaiting a future is no longer needed. + +Now applications will need to install the `MeessagesManagerLayer` after `tower-sessions` has been installed (either directly or via a middleware that wraps it). + +Also note that the iterator impplementation has been updated to use `Message` directly. Fields of `Message` have been made public as well. + # 0.1.0 - Initial release :tada: diff --git a/Cargo.toml b/Cargo.toml index 2996ee0..f38cb90 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,12 @@ [package] name = "axum-messages" -version = "0.1.0" +version = "0.2.0" edition = "2021" authors = ["Max Countryman "] categories = ["asynchronous", "network-programming", "web-programming"] description = "🛎️ One-time notification messages for Axum." homepage = "https://github.com/maxcountryman/axum-messages" -keywords = ["axum", "flash", "message", "messages"] +keywords = ["axum", "flash", "message", "messages", "notification"] license = "MIT" readme = "README.md" repository = "https://github.com/maxcountryman/axum-messages" @@ -15,7 +15,9 @@ repository = "https://github.com/maxcountryman/axum-messages" async-trait = "0.1.77" axum-core = "0.4.3" http = "1.0.0" +parking_lot = "0.12.1" serde = { version = "1.0.195", features = ["derive"] } +tower = "0.4" tower-sessions-core = "0.9.1" [dev-dependencies] diff --git a/README.md b/README.md index da412f2..9cab507 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ To use the crate in your project, add the following to your `Cargo.toml` file: ```toml [dependencies] -axum-messages = "0.1.0" +axum-messages = "0.2.0" ``` ## 🤸 Usage @@ -49,7 +49,7 @@ use axum::{ routing::get, Router, }; -use axum_messages::Messages; +use axum_messages::{Messages, MessagesManagerLayer}; use time::Duration; use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; @@ -63,6 +63,7 @@ async fn main() { let app = Router::new() .route("/", get(set_messages_handler)) .route("/read-messages", get(read_messages_handler)) + .layer(MessagesManagerLayer) .layer(session_layer); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -75,7 +76,7 @@ async fn main() { async fn read_messages_handler(messages: Messages) -> impl IntoResponse { let messages = messages .into_iter() - .map(|(level, message)| format!("{:?}: {}", level, message)) + .map(|message| format!("{:?}: {}", message.level, message)) .collect::>() .join(", "); @@ -89,10 +90,7 @@ async fn read_messages_handler(messages: Messages) -> impl IntoResponse { async fn set_messages_handler(messages: Messages) -> impl IntoResponse { messages .info("Hello, world!") - .debug("This is a debug message.") - .save() - .await - .unwrap(); + .debug("This is a debug message."); Redirect::to("/read-messages") } diff --git a/examples/basic.rs b/examples/basic.rs index eb06934..66a089f 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -5,7 +5,7 @@ use axum::{ routing::get, Router, }; -use axum_messages::Messages; +use axum_messages::{Messages, MessagesManagerLayer}; use time::Duration; use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; @@ -19,6 +19,7 @@ async fn main() { let app = Router::new() .route("/", get(set_messages_handler)) .route("/read-messages", get(read_messages_handler)) + .layer(MessagesManagerLayer) .layer(session_layer); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -31,7 +32,7 @@ async fn main() { async fn read_messages_handler(messages: Messages) -> impl IntoResponse { let messages = messages .into_iter() - .map(|(level, message)| format!("{:?}: {}", level, message)) + .map(|message| format!("{:?}: {}", message.level, message)) .collect::>() .join(", "); @@ -45,10 +46,7 @@ async fn read_messages_handler(messages: Messages) -> impl IntoResponse { async fn set_messages_handler(messages: Messages) -> impl IntoResponse { messages .info("Hello, world!") - .debug("This is a debug message.") - .save() - .await - .unwrap(); + .debug("This is a debug message."); Redirect::to("/read-messages") } diff --git a/src/lib.rs b/src/lib.rs index 8fcedc3..81166d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ //! routing::get, //! Router, //! }; -//! use axum_messages::Messages; +//! use axum_messages::{Messages, MessagesManagerLayer}; //! use time::Duration; //! use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; //! @@ -25,6 +25,7 @@ //! let app = Router::new() //! .route("/", get(set_messages_handler)) //! .route("/read-messages", get(read_messages_handler)) +//! .layer(MessagesManagerLayer) //! .layer(session_layer); //! //! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -37,7 +38,7 @@ //! async fn read_messages_handler(messages: Messages) -> impl IntoResponse { //! let messages = messages //! .into_iter() -//! .map(|(level, message)| format!("{:?}: {}", level, message)) +//! .map(|message| format!("{:?}: {}", message.level, message)) //! .collect::>() //! .join(", "); //! @@ -51,10 +52,7 @@ //! async fn set_messages_handler(messages: Messages) -> impl IntoResponse { //! messages //! .info("Hello, world!") -//! .debug("This is a debug message.") -//! .save() -//! .await -//! .unwrap(); +//! .debug("This is a debug message."); //! //! Redirect::to("/read-messages") //! } @@ -68,12 +66,24 @@ #![deny(missing_docs)] #![forbid(unsafe_code)] -use std::collections::VecDeque; +use core::fmt; +use std::{ + collections::VecDeque, + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; use async_trait::async_trait; -use axum_core::extract::FromRequestParts; +use axum_core::{ + extract::{FromRequestParts, Request}, + response::Response, +}; use http::{request::Parts, StatusCode}; +use parking_lot::Mutex; use serde::{Deserialize, Serialize}; +use tower::{Layer, Service}; use tower_sessions_core::{session, Session}; // N.B.: Code structure directly borrowed from `axum-flash`: https://github.com/davidpdrsn/axum-flash/blob/5e8b2bded97fd10bb275d5bc66f4d020dec465b9/src/lib.rs @@ -81,11 +91,19 @@ use tower_sessions_core::{session, Session}; /// Container for a message which provides a level and message content. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { + /// Message level, i.e. `Level`. #[serde(rename = "l")] - level: Level, + pub level: Level, + /// The message itself. #[serde(rename = "m")] - message: String, + pub message: String, +} + +impl fmt::Display for Message { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.message) + } } type MessageQueue = VecDeque; @@ -126,70 +144,156 @@ struct Data { #[derive(Debug, Clone)] pub struct Messages { session: Session, - data: Data, + data: Arc>, } impl Messages { - const DATA_KEY: &'static str = "messages.data"; + const DATA_KEY: &'static str = "axum-messages.data"; + + fn new(session: Session, data: Data) -> Self { + Self { + session, + data: Arc::new(Mutex::new(data)), + } + } /// Push a `Debug` message. - #[must_use = "`save` must be called to persist messages in the session"] pub fn debug(self, message: impl Into) -> Self { self.push(Level::Debug, message) } /// Push an `Info` message. - #[must_use = "`save` must be called to persist messages in the session"] pub fn info(self, message: impl Into) -> Self { self.push(Level::Info, message) } /// Push a `Success` message. - #[must_use = "`save` must be called to persist messages in the session"] pub fn success(self, message: impl Into) -> Self { self.push(Level::Success, message) } /// Push a `Warning` message. - #[must_use = "`save` must be called to persist messages in the session"] pub fn warning(self, message: impl Into) -> Self { self.push(Level::Warning, message) } /// Push an `Error` message. - #[must_use = "`save` must be called to persist messages in the session"] pub fn error(self, message: impl Into) -> Self { self.push(Level::Error, message) } /// Push a message with the given level. - #[must_use = "`save` must be called to persist messages in the session"] - pub fn push(mut self, level: Level, message: impl Into) -> Self { - self.data.pending_messages.push_back(Message { - message: message.into(), - level, - }); + pub fn push(self, level: Level, message: impl Into) -> Self { + { + let mut data = self.data.lock(); + data.pending_messages.push_back(Message { + message: message.into(), + level, + }); + } self } - /// Save messages back to the session. - /// - /// Note that this must called or messages will not be persisted between - /// requests. - pub async fn save(self) -> Result { + async fn save(self) -> Result { self.session .insert(Self::DATA_KEY, self.data.clone()) .await?; Ok(self) } + + fn load(self) -> Self { + { + // Load messages by taking them from the pending queue. + let mut data = self.data.lock(); + data.messages = std::mem::take(&mut data.pending_messages); + } + self + } } impl Iterator for Messages { - type Item = (Level, String); + type Item = Message; fn next(&mut self) -> Option { - let message = self.data.messages.pop_front()?; - Some((message.level, message.message)) + let mut data = self.data.lock(); + data.messages.pop_front() + } +} + +/// MIddleware provider `Messages` as a request extension. +#[derive(Debug, Clone)] +pub struct MessagesManager { + inner: S, +} + +impl Service> for MessagesManager +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Send, + ReqBody: Send + 'static, + ResBody: Default + Send, +{ + type Response = S::Response; + type Error = S::Error; + type Future = Pin> + Send>>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + // Because the inner service can panic until ready, we need to ensure we only + // use the ready service. + // + // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services + let clone = self.inner.clone(); + let mut inner = std::mem::replace(&mut self.inner, clone); + + Box::pin(async move { + let Some(session) = req.extensions().get::().cloned() else { + let mut res = Response::default(); + *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR; + return Ok(res); + }; + + let data = match session.get::(Messages::DATA_KEY).await { + Ok(Some(data)) => data, + Ok(None) => Data::default(), + Err(_) => { + let mut res = Response::default(); + *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR; + return Ok(res); + } + }; + + let messages = Messages::new(session, data); + + req.extensions_mut().insert(messages.clone()); + + let res = inner.call(req).await; + + if messages.save().await.is_err() { + let mut res = Response::default(); + *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR; + return Ok(res); + }; + + res + }) + } +} + +/// Layer for `MessagesManager`. +#[derive(Debug, Clone)] +pub struct MessagesManagerLayer; + +impl Layer for MessagesManagerLayer { + type Service = MessagesManager; + + fn layer(&self, inner: S) -> Self::Service { + MessagesManager { inner } } } @@ -200,32 +304,16 @@ where { type Rejection = (StatusCode, &'static str); - async fn from_request_parts(req: &mut Parts, state: &S) -> Result { - let session = Session::from_request_parts(req, state).await?; - let mut data = match session.get::(Self::DATA_KEY).await { - Ok(Some(data)) => data, - Ok(None) => Data::default(), - Err(_) => { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - "Could not get from session", - )); - } - }; - - // Load messages by taking them from the pending queue. - data.messages = std::mem::take(&mut data.pending_messages); - - // Save back to the session to ensure future loads do not repeat loaded - // messages. - if session.insert(Self::DATA_KEY, data.clone()).await.is_err() { - return Err(( + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(( StatusCode::INTERNAL_SERVER_ERROR, - "Could not insert to session", - )); - }; - - Ok(Self { session, data }) + "Could not extract messages. Is `MessagesManagerLayer` installed?", + )) + .map(|messages| messages.load()) } } @@ -251,12 +339,13 @@ mod tests { let app = Router::new() .route("/", get(root)) .route("/set-message", get(set_message)) + .layer(MessagesManagerLayer) .layer(session_layer); async fn root(messages: Messages) -> impl IntoResponse { messages .into_iter() - .map(|(level, message)| format!("{:?}: {}", level, message)) + .map(|message| format!("{:?}: {}", message.level, message)) .collect::>() .join(", ") } @@ -265,10 +354,7 @@ mod tests { async fn set_message(messages: Messages) -> impl IntoResponse { messages .debug("Hello, world!") - .info("This is an info message.") - .save() - .await - .unwrap(); + .info("This is an info message."); Redirect::to("/") }