From 6bc95d513f6aef3d1543b3bf951750f3922d0d3f Mon Sep 17 00:00:00 2001 From: LJ Date: Fri, 7 Mar 2025 16:59:08 -0800 Subject: [PATCH] Serializing/deserializing struct to JSON object for Postgres. --- src/base/value.rs | 183 ++++++++++++++++++++++++++++------- src/ops/py_factory.rs | 1 - src/ops/storages/postgres.rs | 7 +- 3 files changed, 155 insertions(+), 36 deletions(-) diff --git a/src/base/value.rs b/src/base/value.rs index 3e379f9d..0717208e 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -5,7 +5,7 @@ use anyhow::Result; use base64::prelude::*; use serde::{ de::{SeqAccess, Visitor}, - ser::{SerializeSeq, SerializeTuple}, + ser::{SerializeMap, SerializeSeq, SerializeTuple}, Deserialize, Serialize, }; use std::{collections::BTreeMap, ops::Deref, sync::Arc}; @@ -577,30 +577,58 @@ where Self { fields } } - pub fn from_json_values( - values: impl Iterator, - schema: &[FieldSchema], + fn from_json_values<'a>( + fields: impl Iterator, ) -> Result { - let fields = values - .zip(schema) - .map(|(v, s)| { - let value = Value::::from_json(v, &s.value_type.typ)?; - if value.is_null() && !s.value_type.nullable { - api_bail!("expected non-null value for `{}`", s.name); - } - Ok(value) - }) - .collect::>>()?; - Ok(Self { fields }) + Ok(Self { + fields: fields + .map(|(s, v)| { + let value = Value::::from_json(v, &s.value_type.typ)?; + if value.is_null() && !s.value_type.nullable { + api_bail!("expected non-null value for `{}`", s.name); + } + Ok(value) + }) + .collect::>>()?, + }) + } + + fn from_json_object<'a>( + values: serde_json::Map, + fields_schema: impl Iterator, + ) -> Result { + let mut values = values; + Ok(Self { + fields: fields_schema + .map(|field| { + let value = match values.get_mut(&field.name) { + Some(v) => { + Value::::from_json(std::mem::take(v), &field.value_type.typ)? + } + None => Value::::default(), + }; + if value.is_null() && !field.value_type.nullable { + api_bail!("expected non-null value for `{}`", field.name); + } + Ok(value) + }) + .collect::>>()?, + }) } - pub fn from_json(value: serde_json::Value, schema: &[FieldSchema]) -> Result { + pub fn from_json<'a>(value: serde_json::Value, fields_schema: &[FieldSchema]) -> Result { match value { serde_json::Value::Array(v) => { - if v.len() != schema.len() { + if v.len() != fields_schema.len() { api_bail!("unmatched value length"); } - Self::from_json_values(v.into_iter(), &schema) + Self::from_json_values(fields_schema.iter().zip(v.into_iter())) + } + serde_json::Value::Object(v) => { + if v.len() != fields_schema.len() { + api_bail!("unmatched value length"); + } + Self::from_json_object(v, fields_schema.iter()) } _ => api_bail!("invalid value type"), } @@ -738,22 +766,45 @@ where CollectionKind::Table => { let rows = v .into_iter() - .map(|v| match v { - serde_json::Value::Array(v) => { - let mut fields_iter = v.into_iter(); - let key = Self::from_json( - fields_iter - .next() - .ok_or_else(|| api_error!("Empty struct field values"))?, - &s.row.fields[0].value_type.typ, - )? - .to_key()?; - let values = - FieldValues::from_json_values(fields_iter, &s.row.fields[1..])? - .into(); - Ok((key, values)) + .map(|v| { + let mut fields_iter = s.row.fields.iter(); + let key_field = fields_iter + .next() + .ok_or_else(|| api_error!("Empty struct field values"))?; + + match v { + serde_json::Value::Array(v) => { + let mut field_vals_iter = v.into_iter(); + let key = Self::from_json( + field_vals_iter.next().ok_or_else(|| { + api_error!("Empty struct field values") + })?, + &key_field.value_type.typ, + )? + .to_key()?; + let values = FieldValues::from_json_values( + fields_iter.zip(field_vals_iter), + )?; + Ok((key, values.into())) + } + serde_json::Value::Object(mut v) => { + let key = Self::from_json( + std::mem::take(v.get_mut(&key_field.name).ok_or_else( + || { + api_error!( + "key field `{}` doesn't exist in value", + key_field.name + ) + }, + )?), + &key_field.value_type.typ, + )? + .to_key()?; + let values = FieldValues::from_json_object(v, fields_iter)?; + Ok((key, values.into())) + } + _ => api_bail!("Table value must be a JSON array or object"), } - _ => api_bail!("Table value must be a JSON array"), }) .collect::>>()?; Value::Table(rows) @@ -773,3 +824,69 @@ where Ok(result) } } + +#[derive(Debug, Clone, Copy)] +pub struct TypedValue<'a> { + pub t: &'a ValueType, + pub v: &'a Value, +} + +impl<'a> Serialize for TypedValue<'a> { + fn serialize(&self, serializer: S) -> Result { + match (self.t, self.v) { + (ValueType::Basic(_), v) => v.serialize(serializer), + (ValueType::Struct(s), Value::Struct(field_values)) => TypedFieldsValue { + schema: s, + values_iter: field_values.fields.iter(), + } + .serialize(serializer), + (ValueType::Collection(c), Value::Collection(rows) | Value::List(rows)) => { + let mut seq = serializer.serialize_seq(Some(rows.len()))?; + for row in rows { + seq.serialize_element(&TypedFieldsValue { + schema: &c.row, + values_iter: row.fields.iter(), + })?; + } + seq.end() + } + (ValueType::Collection(c), Value::Table(rows)) => { + let mut seq = serializer.serialize_seq(Some(rows.len()))?; + for (k, v) in rows { + seq.serialize_element(&TypedFieldsValue { + schema: &c.row, + values_iter: std::iter::once(&Value::from(k.clone())) + .chain(v.fields.iter()), + })?; + } + seq.end() + } + _ => Err(serde::ser::Error::custom(format!( + "Incompatible value type: {:?} {:?}", + self.t, self.v + ))), + } + } +} + +pub struct TypedFieldsValue<'a, I: Iterator + Clone> { + schema: &'a StructSchema, + values_iter: I, +} + +impl<'a, I: Iterator + Clone> Serialize for TypedFieldsValue<'a, I> { + fn serialize(&self, serializer: S) -> Result { + let mut map = serializer.serialize_map(Some(self.schema.fields.len()))?; + let values_iter = self.values_iter.clone(); + for (field, value) in self.schema.fields.iter().zip(values_iter) { + map.serialize_entry( + &field.name, + &TypedValue { + t: &field.value_type.typ, + v: value, + }, + )?; + } + map.end() + } +} diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs index 5905640f..1a5efca5 100644 --- a/src/ops/py_factory.rs +++ b/src/ops/py_factory.rs @@ -3,7 +3,6 @@ use std::{collections::BTreeMap, sync::Arc}; use axum::async_trait; use blocking::unblock; use futures::FutureExt; -use log::warn; use pyo3::{ exceptions::PyException, pyclass, pymethods, diff --git a/src/ops/storages/postgres.rs b/src/ops/storages/postgres.rs index fe4be466..f1b6e8d0 100644 --- a/src/ops/storages/postgres.rs +++ b/src/ops/storages/postgres.rs @@ -89,7 +89,7 @@ fn bind_key_field<'arg>( fn bind_value_field<'arg>( builder: &mut sqlx::QueryBuilder<'arg, sqlx::Postgres>, - field_schema: &FieldSchema, + field_schema: &'arg FieldSchema, value: &'arg Value, ) -> Result<()> { match &value { @@ -145,7 +145,10 @@ fn bind_value_field<'arg>( builder.push("NULL"); } v => { - builder.push_bind(sqlx::types::Json(*v)); + builder.push_bind(sqlx::types::Json(TypedValue { + t: &field_schema.value_type.typ, + v, + })); } }; Ok(())