Skip to content

Commit

Permalink
Add TryIntoStruct trait for ergonomic result parsing into a struct.
Browse files Browse the repository at this point in the history
  • Loading branch information
obi1kenobi committed May 2, 2023
1 parent a9d8d78 commit 6bfa645
Show file tree
Hide file tree
Showing 5 changed files with 571 additions and 0 deletions.
3 changes: 3 additions & 0 deletions trustfall_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ pub mod graphql_query;
pub mod interpreter;
pub mod ir;
pub mod schema;
mod serialization;
mod util;

pub use serialization::TryIntoStruct;

#[cfg(test)]
mod numbers_interpreter;

Expand Down
1 change: 1 addition & 0 deletions trustfall_core/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod ir;
mod nullables_interpreter;
mod numbers_interpreter;
mod schema;
mod serialization;
mod util;

use std::{
Expand Down
267 changes: 267 additions & 0 deletions trustfall_core/src/serialization/deserializers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
use std::{collections::BTreeMap, sync::Arc};

use serde::de::{self, IntoDeserializer};

use crate::ir::FieldValue;

#[derive(Debug, Clone)]
pub(super) struct QueryResultDeserializer {
query_result: BTreeMap<Arc<str>, FieldValue>,
}

impl QueryResultDeserializer {
pub(super) fn new(query_result: BTreeMap<Arc<str>, FieldValue>) -> Self {
Self { query_result }
}
}

#[derive(Debug, Clone)]
struct QueryResultMapDeserializer<I: Iterator<Item = (Arc<str>, FieldValue)>> {
iter: I,
next_value: Option<FieldValue>,
}

impl<I: Iterator<Item = (Arc<str>, FieldValue)>> QueryResultMapDeserializer<I> {
fn new(iter: I) -> Self {
Self {
iter,
next_value: Default::default(),
}
}
}

#[derive(Debug, Clone, thiserror::Error)]
pub enum Error {
#[error("error from deserialize: {0}")]
Custom(String),
}

impl de::Error for Error {
fn custom<T>(msg: T) -> Self
where
T: std::fmt::Display,
{
Self::Custom(msg.to_string())
}
}

impl<'de> de::Deserializer<'de> for QueryResultDeserializer {
type Error = Error;

fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_map(QueryResultMapDeserializer::new(
self.query_result.into_iter(),
))
}

serde::forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf option unit unit_struct newtype_struct seq tuple
tuple_struct map struct enum identifier ignored_any
}
}

impl<'de, I: Iterator<Item = (Arc<str>, FieldValue)>> de::MapAccess<'de>
for QueryResultMapDeserializer<I>
{
type Error = Error;

fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
where
K: de::DeserializeSeed<'de>,
{
self.iter
.next()
.map(|(key, value)| {
self.next_value = Some(value);
seed.deserialize(key.into_deserializer())
})
.transpose()
}

fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
where
V: de::DeserializeSeed<'de>,
{
seed.deserialize(
self.next_value
.take()
.expect("called next_value_seed out of order")
.into_deserializer(),
)
}
}

pub struct FieldValueDeserializer {
value: FieldValue,
}

impl<'de> de::IntoDeserializer<'de, Error> for FieldValue {
type Deserializer = FieldValueDeserializer;

fn into_deserializer(self) -> Self::Deserializer {
FieldValueDeserializer { value: self }
}
}

impl<'de> de::Deserializer<'de> for FieldValueDeserializer {
type Error = Error;

fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_i8(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_i8(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_i16(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_i16(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_i32(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_i32(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_u8(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_u8(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_u16(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_u16(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_u32(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_u32(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Float64(v) => visitor.visit_f32(v as f32),
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
if let FieldValue::List(v) = &self.value {
if len != v.len() {
return Err(Self::Error::Custom(format!(
"cannot deserialize {} length list into {len} sized tuple",
v.len()
)));
}
}
self.deserialize_any(visitor)
}

fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match &self.value {
&FieldValue::Null => visitor.visit_none(),
_ => visitor.visit_some(self),
}
}

fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_none()
}

fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Null => visitor.visit_none(),
FieldValue::Int64(v) => visitor.visit_i64(v),
FieldValue::Uint64(v) => visitor.visit_u64(v),
FieldValue::Float64(v) => visitor.visit_f64(v),
FieldValue::String(v) => visitor.visit_string(v),
FieldValue::Boolean(v) => visitor.visit_bool(v),
FieldValue::DateTimeUtc(_) => todo!(),
FieldValue::Enum(_) => todo!(),
FieldValue::List(v) => visitor.visit_seq(v.into_deserializer()),
}
}

serde::forward_to_deserialize_any! {
bool i64 i128 u64 u128 f64 char str string seq
bytes byte_buf unit unit_struct newtype_struct
tuple_struct map enum struct identifier
}
}
64 changes: 64 additions & 0 deletions trustfall_core/src/serialization/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use std::{collections::BTreeMap, sync::Arc};

use serde::de;

use crate::ir::FieldValue;

mod deserializers;

#[cfg(test)]
mod tests;

/// Deserialize Trustfall query results into a Rust struct.
///
/// ```rust
/// # use std::{collections::BTreeMap, sync::Arc};
/// # use maplit::btreemap;
/// # use trustfall_core::ir::FieldValue;
/// #
/// # fn run_query() -> Result<Box<dyn Iterator<Item = BTreeMap<Arc<str>, FieldValue>>>, ()> {
/// # Ok(Box::new(vec![
/// # btreemap! {
/// # Arc::from("number") => FieldValue::Int64(42),
/// # Arc::from("text") => FieldValue::String("the answer to everything".to_string()),
/// # }
/// # ].into_iter()))
/// # }
///
/// use trustfall_core::TryIntoStruct;
///
/// #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
/// struct Output {
/// number: i64,
/// text: String,
/// }
///
/// let results: Vec<_> = run_query()
/// .expect("bad query arguments")
/// .map(|v| v.try_into_struct().expect("struct definition did not match query result shape"))
/// .collect();
///
/// assert_eq!(
/// vec![
/// Output {
/// number: 42,
/// text: "the answer to everything".to_string(),
/// },
/// ],
/// results,
/// );
/// ```
pub trait TryIntoStruct {
type Error;

fn try_into_struct<S: for<'de> de::Deserialize<'de>>(self) -> Result<S, Self::Error>;
}

impl TryIntoStruct for BTreeMap<Arc<str>, FieldValue> {
type Error = deserializers::Error;

fn try_into_struct<S: for<'de> de::Deserialize<'de>>(self) -> Result<S, deserializers::Error> {
let deserializer = deserializers::QueryResultDeserializer::new(self);
S::deserialize(deserializer)
}
}
Loading

0 comments on commit 6bfa645

Please sign in to comment.