From 0588187d90607428a81b5c762eb749b4c8f23e00 Mon Sep 17 00:00:00 2001 From: LJ Date: Mon, 3 Mar 2025 23:31:14 -0800 Subject: [PATCH] Implement a generic fingerprinter - for computing cache key. --- src/execution/fingerprint.rs | 409 +++++++++++++++++++++++++++++++++++ src/execution/indexer.rs | 16 +- src/execution/mod.rs | 1 + 3 files changed, 414 insertions(+), 12 deletions(-) create mode 100644 src/execution/fingerprint.rs diff --git a/src/execution/fingerprint.rs b/src/execution/fingerprint.rs new file mode 100644 index 00000000..de52713a --- /dev/null +++ b/src/execution/fingerprint.rs @@ -0,0 +1,409 @@ +use base64::prelude::*; +use blake2::digest::typenum; +use blake2::{Blake2b, Digest}; +use serde::ser::{ + Serialize, SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, + SerializeTupleStruct, SerializeTupleVariant, Serializer, +}; + +#[derive(Clone, Default)] +pub struct Fingerprinter { + hasher: Blake2b, +} + +impl Fingerprinter { + pub fn to_bytes(self) -> Vec { + self.hasher.finalize().to_vec() + } + + pub fn to_base64(self) -> String { + BASE64_STANDARD.encode(self.to_bytes()) + } + + fn write_type_tag(&mut self, tag: &str) { + self.hasher.update(tag.as_bytes()); + self.hasher.update(b";"); + } + + fn write_end_tag(&mut self) { + self.hasher.update(b"."); + } + + fn write_varlen_bytes(&mut self, bytes: &[u8]) { + self.write_usize(bytes.len()); + self.hasher.update(bytes); + } + + fn write_usize(&mut self, value: usize) { + self.hasher.update((value as u32).to_le_bytes()); + } +} + +#[derive(Debug)] +pub struct FingerprinterError { + msg: String, +} + +impl std::fmt::Display for FingerprinterError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "FingerprinterError: {}", self.msg) + } +} +impl std::error::Error for FingerprinterError {} +impl serde::ser::Error for FingerprinterError { + fn custom(msg: T) -> Self + where + T: std::fmt::Display, + { + FingerprinterError { + msg: format!("{msg}"), + } + } +} + +impl Serializer for &mut Fingerprinter { + type Ok = (); + type Error = FingerprinterError; + + type SerializeSeq = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + type SerializeMap = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + + fn serialize_bool(self, v: bool) -> Result<(), Self::Error> { + self.write_type_tag(if v { "t" } else { "f" }); + Ok(()) + } + + fn serialize_i8(self, v: i8) -> Result<(), Self::Error> { + self.write_type_tag("i1"); + self.hasher.update(v.to_le_bytes()); + Ok(()) + } + + fn serialize_i16(self, v: i16) -> Result<(), Self::Error> { + self.write_type_tag("i2"); + self.hasher.update(&v.to_le_bytes()); + Ok(()) + } + + fn serialize_i32(self, v: i32) -> Result<(), Self::Error> { + self.write_type_tag("i4"); + self.hasher.update(&v.to_le_bytes()); + Ok(()) + } + + fn serialize_i64(self, v: i64) -> Result<(), Self::Error> { + self.write_type_tag("i8"); + self.hasher.update(v.to_le_bytes()); + Ok(()) + } + + fn serialize_u8(self, v: u8) -> Result<(), Self::Error> { + self.write_type_tag("u1"); + self.hasher.update(v.to_le_bytes()); + Ok(()) + } + + fn serialize_u16(self, v: u16) -> Result<(), Self::Error> { + self.write_type_tag("u2"); + self.hasher.update(v.to_le_bytes()); + Ok(()) + } + + fn serialize_u32(self, v: u32) -> Result<(), Self::Error> { + self.write_type_tag("u4"); + self.hasher.update(v.to_le_bytes()); + Ok(()) + } + + fn serialize_u64(self, v: u64) -> Result<(), Self::Error> { + self.write_type_tag("u8"); + self.hasher.update(v.to_le_bytes()); + Ok(()) + } + + fn serialize_f32(self, v: f32) -> Result<(), Self::Error> { + self.write_type_tag("f4"); + self.hasher.update(v.to_le_bytes()); + Ok(()) + } + + fn serialize_f64(self, v: f64) -> Result<(), Self::Error> { + self.write_type_tag("f8"); + self.hasher.update(v.to_le_bytes()); + Ok(()) + } + + fn serialize_char(self, v: char) -> Result<(), Self::Error> { + self.write_type_tag("c"); + self.write_usize(v as usize); + Ok(()) + } + + fn serialize_str(self, v: &str) -> Result<(), Self::Error> { + self.write_type_tag("s"); + self.write_varlen_bytes(v.as_bytes()); + Ok(()) + } + + fn serialize_bytes(self, v: &[u8]) -> Result<(), Self::Error> { + self.write_type_tag("b"); + self.write_varlen_bytes(v); + Ok(()) + } + + fn serialize_none(self) -> Result<(), Self::Error> { + self.write_type_tag(""); + Ok(()) + } + + fn serialize_some(self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + fn serialize_unit(self) -> Result<(), Self::Error> { + self.write_type_tag("()"); + Ok(()) + } + + fn serialize_unit_struct(self, name: &'static str) -> Result<(), Self::Error> { + self.write_type_tag("US"); + self.write_varlen_bytes(name.as_bytes()); + Ok(()) + } + + fn serialize_unit_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result<(), Self::Error> { + self.write_type_tag("UV"); + self.write_varlen_bytes(name.as_bytes()); + self.write_varlen_bytes(variant.as_bytes()); + Ok(()) + } + + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.write_type_tag("NS"); + self.write_varlen_bytes(name.as_bytes()); + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.write_type_tag("NV"); + self.write_varlen_bytes(name.as_bytes()); + self.write_varlen_bytes(variant.as_bytes()); + value.serialize(self) + } + + fn serialize_seq(self, _len: Option) -> Result { + self.write_type_tag("L"); + Ok(self) + } + + fn serialize_tuple(self, _len: usize) -> Result { + self.write_type_tag("T"); + Ok(self) + } + + fn serialize_tuple_struct( + self, + name: &'static str, + _len: usize, + ) -> Result { + self.write_type_tag("TS"); + self.write_varlen_bytes(name.as_bytes()); + Ok(self) + } + + fn serialize_tuple_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + self.write_type_tag("TV"); + self.write_varlen_bytes(name.as_bytes()); + self.write_varlen_bytes(variant.as_bytes()); + Ok(self) + } + + fn serialize_map(self, _len: Option) -> Result { + self.write_type_tag("M"); + Ok(self) + } + + fn serialize_struct( + self, + name: &'static str, + _len: usize, + ) -> Result { + self.write_type_tag("S"); + self.write_varlen_bytes(name.as_bytes()); + Ok(self) + } + + fn serialize_struct_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + self.write_type_tag("SV"); + self.write_varlen_bytes(name.as_bytes()); + self.write_varlen_bytes(variant.as_bytes()); + Ok(self) + } +} + +impl SerializeSeq for &mut Fingerprinter { + type Ok = (); + type Error = FingerprinterError; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut **self) + } + + fn end(self) -> Result<(), Self::Error> { + self.write_end_tag(); + Ok(()) + } +} + +impl SerializeTuple for &mut Fingerprinter { + type Ok = (); + type Error = FingerprinterError; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut **self) + } + + fn end(self) -> Result<(), Self::Error> { + self.write_end_tag(); + Ok(()) + } +} + +impl SerializeTupleStruct for &mut Fingerprinter { + type Ok = (); + type Error = FingerprinterError; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut **self) + } + + fn end(self) -> Result<(), Self::Error> { + self.write_end_tag(); + Ok(()) + } +} + +impl SerializeTupleVariant for &mut Fingerprinter { + type Ok = (); + type Error = FingerprinterError; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut **self) + } + + fn end(self) -> Result<(), Self::Error> { + self.write_end_tag(); + Ok(()) + } +} + +impl SerializeMap for &mut Fingerprinter { + type Ok = (); + type Error = FingerprinterError; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + key.serialize(&mut **self) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut **self) + } + + fn end(self) -> Result<(), Self::Error> { + self.write_end_tag(); + Ok(()) + } +} + +impl SerializeStruct for &mut Fingerprinter { + type Ok = (); + type Error = FingerprinterError; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.hasher.update(key.as_bytes()); + self.hasher.update(b"\n"); + value.serialize(&mut **self) + } + + fn end(self) -> Result<(), Self::Error> { + self.write_end_tag(); + Ok(()) + } +} + +impl SerializeStructVariant for &mut Fingerprinter { + type Ok = (); + type Error = FingerprinterError; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.hasher.update(key.as_bytes()); + self.hasher.update(b"\n"); + value.serialize(&mut **self) + } + + fn end(self) -> Result<(), Self::Error> { + self.write_end_tag(); + Ok(()) + } +} diff --git a/src/execution/indexer.rs b/src/execution/indexer.rs index 0fa9e20e..c3de7eb2 100644 --- a/src/execution/indexer.rs +++ b/src/execution/indexer.rs @@ -1,8 +1,6 @@ use std::collections::{HashMap, HashSet}; use anyhow::Result; -use blake2::digest::typenum; -use blake2::{Blake2b, Digest}; use futures::future::{join_all, try_join, try_join_all}; use log::error; use serde::Serialize; @@ -10,6 +8,7 @@ use sqlx::PgPool; use super::db_tracking::{self, read_source_tracking_info}; use super::db_tracking_setup; +use super::fingerprint::Fingerprinter; use super::memoization::{EvaluationCache, MemoizationInfo}; use crate::base::schema; use crate::base::spec::FlowInstanceSpec; @@ -85,15 +84,6 @@ fn make_primary_key( Ok(key) } -fn fingerprint(values: &FieldValues) -> Result { - let mut hasher = Blake2b::::new(); - for field_value in values.fields.iter() { - hasher.update(serde_json::to_string(field_value)?.as_bytes()); - hasher.update(b"\n"); - } - Ok(format!("{:x}", hasher.finalize())) -} - enum WithApplyStatus { Normal(T), Collapsed, @@ -225,7 +215,9 @@ async fn precommit_source_tracking_info( .fields .push(value.fields[*field as usize].clone()); } - let curr_fp = Some(fingerprint(&field_values)?); + let mut fingerprinter = Fingerprinter::default(); + field_values.serialize(&mut fingerprinter)?; + let curr_fp = Some(fingerprinter.to_base64()); let existing_target_keys = target_info.existing_keys_info.remove(&primary_key_json); let existing_staging_target_keys = target_info diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 277d718a..c0c85c0c 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -5,4 +5,5 @@ pub mod query; mod db_tracking; pub mod db_tracking_setup; +mod fingerprint; mod memoization;