Skip to content

An implementation of the extract-as-argument pattern in Rust.

License

Notifications You must be signed in to change notification settings

matteopolak/extract-as-argument

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Extract as Argument

This is a blog post about the extract-as-argument pattern used in Axum, and how you can write it in 100 lines of code. Read the blog post here.

Blog

Axum is a web framework built on top of Tokio and Tower. It is designed to be ergonomic and easy to use, while still being fast and scalable. One of the features that makes Axum so easy to use is its extract-as-argument pattern.

Before we get started, here's what the pattern looks like in action:

fn handler(
  State(state): State<u8>,
  Json(json): Json<Data>,
) -> Response {
  // ...
}

This pattern is extremely powerful by allowing you to extract pieces of data in a type-safe manner from anything at all, not just requests.

There are two different things going on here: the first is the State extractor, which clones the state passed into the handler, but does not consume the request. On the other hand, the Json extractor consumes the request and deserializes the body into the Data struct.

Let's start by implementing some basic reuqest & response structs and a simple trait for our non-consuming extractors:

/// A "part" of a request that can be taken out without consuming the request.
/// Note that this struct is Copy, so it can be taken out multiple times.
#[derive(Clone, Copy)]
struct RequestParts {
  count: u8,
}

/// On the other hand, this struct is not Copy, so the data stored in `expensive` can only be "taken out" once.
#[derive(Clone)]
struct Request {
  parts: RequestParts,
  expensive: Vec<u8>,
}

/// Finally, a simple response struct so we can see what's happening.
struct Response {
  content: String,
}

/// A trait for extractors that do not consume the request.
/// Note that we always pass in a mutable reference to the request parts,
/// and a state object.
trait FromRequestParts<S> {
  fn from_request_parts(parts: &mut RequestParts, state: S) -> Self;
}

Now that we have our structs and trait, let's implement the easiest extractor: State.

/// The state extractor clones the state passed into the handler.
struct State<S>(S);

impl<S> FromRequestParts<S> for State<S> {
  /// We just ignore the request entirely, extracting the state only.
  fn from_request_parts(_: &mut RequestParts, state: S) -> Self {
    Self(state)
  }
}

Okay, great! Now we need a trait that allows an extractor to consume the incoming request.

mod private {
  pub struct WithParts;
  pub struct WithRequest;
}

trait FromRequest<S, X = private::WithRequest> {
  fn from_request(req: Request, state: S) -> Self;
}

This trait is pretty similar to the previous one, but it takes an owned Request instead of a mutable reference to RequestParts. We also have a second generic parameter, X, which allows us to implement a blanket trait for all types that implement FromRequestParts without worrying about Rust complaining that someone else could do it downstream.

impl<T, S> FromRequest<S, private::WithParts> for T
where
  T: FromRequestParts<S>,
{
  fn from_request(mut req: Request, state: S) -> Self {
    T::from_request_parts(&mut req.parts, state)
  }
}

This blanket trait is pretty simple: if T implements FromRequestParts, then we can implement FromRequest for it, since FromRequestParts only needs a mutable reference to a part of the owned Request that we're given.

With that done, let's implement an extractor that returns the expensive Vec<u8> from the request.

struct Expensive(Vec<u8>);

impl<S> FromRequest<S> for Expensive {
  fn from_request(req: Request, _: S) -> Self {
    Self(req.expensive)
  }
}

Nothing too special, now let's try the Json extractor.

struct Json<T>(T);

impl<S, T> FromRequest<S> for Json<T>
where
 T: serde::de::DeserializeOwned,
{
  fn from_request(req: Request, _: S) -> Self {
    Self(serde_json::from_slice(&req.expensive).expect("expected valid json"))
  }
}

We don't really care about errors in this example, so we'll just panic if the input isn't valid JSON that conforms to the data required by T.

This extractor is where it gets interesting. We can now use the Json extractor to easily deserialize the request body into any JSON that we want!

We still need to actually complete the handler portion, which looks almost magical when you see it for the first time. In order to do this, we need another trait: Handler.

trait Handler<T, S> {
  fn call(self, req: Request, state: S) -> Response;
}

Again, the trait definition is pretty simple. With just a few blanket implementations of common functions, we can make this trait extremely powerful for any state S and arguments T. However, it's going to get a bit ugly. Ideally, we would implement these using a macro to avoid repetition, but we're only going to implement the first two.

impl<S, F, M, T1> Handler<(M, T1), S> for F
where
  F: FnOnce(T1) -> Response,
  T1: FromRequest<S, M>,
{
  fn call(self, req: Request, state: S) -> Response {
    let t1 = T1::from_request(req, state);

    self(t1)
  }
}

impl<S, F, M, T1, T2> Handler<(M, T1, T2), S> for F
where
  F: FnOnce(T1, T2) -> Response,
  S: Clone,
  T1: FromRequestParts<S>,
  T2: FromRequest<S, M>,
{
  fn call(self, mut req: Request, state: S) -> Response {
    let t1 = T1::from_request_parts(&mut req.parts, state.clone());
    let t2 = T2::from_request(req, state);

    self(t1, t2)
  }
}

Okay, okay, that's a lot of spaghetti. Let's break it down. The FnOnce trait is implemented for functions that can be called once, and the definition looks like the following (simplified):

pub trait FnOnce<Args>
where
  Args: Tuple,
{
  type Output;

  fn call_once(self, args: Args) -> Self::Output;
}

Most importantly, it takes a tuple of arguments that we can generalize over to accept any number of arguments to our handler.

For our implementation of one argument, we take a function F that takes a single argument T1 and returns a Response. The argument T1 must implement FromRequest<S> (which includes anything that implements FromRequestParts<S> from our blanket implementation earlier), and the state S must be the same as the state passed into the handler.

If we were to implement this trait for n arguments, arguments T1 through Tn-1 would need to implement FromRequestParts<S>, and argument Tn would need to implement FromRequest<S, M>, since it consumes the request.

The actual body of our implementation does exactly that:

fn call(self, mut req: Request, state: S) -> Response {
  // T1 implements FromRequestParts<S>, so we can lend a mutable reference to the request parts.
  let t1 = T1::from_request_parts(&mut req.parts, state.clone());
  // We can still use the request here since the parts were not consumed.
  let t2 = T2::from_request(req, state);

  // After this point, the request has been consumed, so we can't use it anymore.
  // However, we don't need it anymore!

  // We just need to call the function (since F implements FnOnce) with the arguments
  // we just constructed for it.
  self(t1, t2)
}

There's also the option of having a handler with no arguments, which just calls the function:

impl<S, F> Handler<(), S> for F
where
  F: Fn() -> Response,
{
  fn call(self, _: Request, _: S) -> Response {
    self()
  }
}

Now that that's done, let's implement one more extractor to remove the count field from the request parts.

struct Count(u8);

impl<S> FromRequestParts<S> for Count {
  fn from_request_parts(parts: &mut RequestParts, _: S) -> Self {
    Self(parts.count)
  }
}

Let's make a few routes:

/// A simple route that returns "Hello, world!".
fn simple() -> Response {
  Response {
    content: "Hello, world!".to_string(),
  }
}

/// A route that returns the count and state from the request parts.
fn with_count_and_state(State(state): State<u8>, Count(count): Count) -> Response {
  Response {
    content: format!("state: {state}, count: {count}"),
  }
}

/// A route that returns the count from the request parts and the expensive data from the request.
fn with_state_and_expensive(State(state): State<u8>, Expensive(expensive): Expensive) -> Response {
  Response {
    content: format!("state: {state}, expensive: {}", expensive.len()),
  }
}

#[derive(serde::Deserialize)]
struct Body {
  repeat: usize,
  text: String,
}

/// A route that extracts the request body as JSON and does something with it.
fn with_json(Json(body): Json<Body>) -> Response {
  Response {
    content: body.text.repeat(body.repeat),
  }
}

Before we can actually use these, we need a function that takes in any handler and returns a function that can process a request into a response.

fn get<S, H, T>(handler: H) -> impl Fn(Request, S) -> Response
where
  H: Handler<T, S> + Copy,
{
  move |req, state| handler.call(req, state)
}

For any state S, handler H, and arguments T, this function returns a function that takes in a request and state and returns a response. This is the function that we will use to actually process requests.

Note that we need to add an additional bound to H to make sure that it implements Copy, because we want to be able to call it multiple times (since we're returning an Fn, not FnOnce).

The final step is testing it out. Let's first make a fake request and state object:

let state = 42;
let request = Request {
  parts: RequestParts { count: 10 },
  expensive: br#"{
    "repeat": 6,
    "text": "hi"
  }"#
  .to_vec(),
};

Now we can test out our routes:

let route = get(simple);
let response = route(request.clone(), state);

assert_eq!(response.content, "Hello, world!");

let route = get(with_count_and_state);
let response = route(request.clone(), state);

assert_eq!(response.content, "state: 42, count: 10");

let route = get(with_state_and_expensive);
let response = route(request.clone(), state);

assert_eq!(response.content, "state: 42, expensive: 37");

let route = get(with_json);
let response = route(request.clone(), state);

assert_eq!(response.content, "hihihihihihi");

And that's it! We've implemented Axum's extract-as-argument pattern in just 100 lines of code.

About

An implementation of the extract-as-argument pattern in Rust.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages