From 361e4401c650daa21a0e8822ce08efa87c8508b5 Mon Sep 17 00:00:00 2001 From: Andrew Vasilyev Date: Thu, 27 Jun 2024 12:24:51 +0000 Subject: [PATCH] feat: implement aggregates to support min/max --- Cargo.toml | 3 -- src/aggregate.rs | 130 +++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 9 +--- src/typeid.rs | 20 ++++++-- 4 files changed, 148 insertions(+), 14 deletions(-) create mode 100644 src/aggregate.rs diff --git a/Cargo.toml b/Cargo.toml index 5db76fc..82f0652 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,9 +3,6 @@ name = "typeid" version = "0.0.0" edition = "2021" -[build] -rustflags = ["-C", "target-feature=+aes"] - [lib] crate-type = ["cdylib", "lib"] diff --git a/src/aggregate.rs b/src/aggregate.rs new file mode 100644 index 0000000..f78e78d --- /dev/null +++ b/src/aggregate.rs @@ -0,0 +1,130 @@ +use pgrx::{aggregate::*, pg_aggregate, pg_sys}; + +use crate::typeid::TypeID; + +pub struct TypeIDMin; +pub struct TypeIDMax; + +#[pg_aggregate] +impl Aggregate for TypeIDMin { + const NAME: &'static str = "min"; + type Args = TypeID; + type State = Option; + + fn state( + current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match current { + None => Some(arg), + Some(current) => Some(if arg < current { arg } else { current }), + } + } +} + +#[pg_aggregate] +impl Aggregate for TypeIDMax { + const NAME: &'static str = "max"; + type Args = TypeID; + type State = Option; + + fn state( + current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match current { + None => Some(arg), + Some(current) => Some(if arg > current { arg } else { current }), + } + } +} + +#[cfg(any(test, feature = "pg_test"))] +#[pgrx::pg_schema] +mod tests { + use super::*; + use pgrx::prelude::*; + + #[pg_test] + fn test_typeid_min_max_aggregates() { + Spi::connect(|mut client| { + // Create a temporary table + client + .update("CREATE TEMPORARY TABLE test_typeid (id typeid)", None, None) + .unwrap(); + + // Insert some test data + client.update("INSERT INTO test_typeid VALUES (typeid_generate('user')), (typeid_generate('user')), (typeid_generate('user'))", None, None).unwrap(); + + // Test min aggregate + let result = client + .select("SELECT min(id) FROM test_typeid", None, None) + .unwrap(); + + assert_eq!(result.len(), 1); + let min_typeid: TypeID = result + .first() + .get_one() + .unwrap() + .expect("didnt get min typeid"); + + // Test max aggregate + let result = client + .select("SELECT max(id) FROM test_typeid", None, None) + .unwrap(); + assert_eq!(result.len(), 1); + let max_typeid: TypeID = result + .first() + .get_one() + .unwrap() + .expect("didnt get max typeid"); + + // Verify that max is greater than min + assert!(max_typeid > min_typeid); + + // Test with empty table + client.update("TRUNCATE test_typeid", None, None).unwrap(); + let result = client + .select("SELECT min(id), max(id) FROM test_typeid", None, None) + .unwrap(); + assert_eq!(result.len(), 1); + + let (min_typeid, max_typeid): (Option, Option) = + result.first().get_two().unwrap(); + assert_eq!(min_typeid, None); + assert_eq!(max_typeid, None); + + // Test with single value + client + .update( + "INSERT INTO test_typeid VALUES (typeid_generate('user'))", + None, + None, + ) + .unwrap(); + let result = client + .select("SELECT min(id), max(id) FROM test_typeid", None, None) + .unwrap(); + assert_eq!(result.len(), 1); + let (min_typeid, max_typeid): (Option, Option) = + result.first().get_two().unwrap(); + + assert_eq!(min_typeid.unwrap(), max_typeid.unwrap()); + + // Test with multiple prefixes + client.update("TRUNCATE test_typeid", None, None).unwrap(); + client.update("INSERT INTO test_typeid VALUES (typeid_generate('user')), (typeid_generate('post')), (typeid_generate('comment'))", None, None).unwrap(); + let result = client + .select("SELECT min(id), max(id) FROM test_typeid", None, None) + .unwrap(); + assert_eq!(result.len(), 1); + let (min_typeid, max_typeid): (Option, Option) = + result.first().get_two().unwrap(); + + assert!(min_typeid.unwrap().type_prefix() == "comment"); + assert!(max_typeid.unwrap().type_prefix() == "user"); + }) + } +} diff --git a/src/lib.rs b/src/lib.rs index 6eb83e8..c552028 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod aggregate; pub mod base32; pub mod typeid; @@ -32,10 +33,7 @@ fn uuid_to_typeid(prefix: &str, uuid: pgrx::Uuid) -> TypeID { #[pg_extern] fn typeid_cmp(a: TypeID, b: TypeID) -> i32 { - match a.type_prefix().cmp(b.type_prefix()) { - std::cmp::Ordering::Equal => a.uuid().cmp(b.uuid()) as i32, - other => other as i32, - } + a.cmp(&b) as i32 } #[pg_extern] @@ -198,7 +196,6 @@ mod tests { Spi::run("CREATE TABLE question (id typeid);").unwrap(); Spi::run("CREATE TABLE answer (id typeid, question typeid);").unwrap(); - println!("Creating tables"); // Generate and insert test data let typeid1 = typeid_generate("qual"); let typeid2 = typeid_generate("answer"); @@ -236,7 +233,6 @@ mod tests { .unwrap() .expect("expected to find oid"); - println!("Inserting into table: {:?}", oid); Spi::run_with_args( &query, Some(vec![ @@ -251,7 +247,6 @@ mod tests { let query = format!("INSERT INTO {} (id) VALUES ($1::typeid)", table_name); let oid = oid_for_type("typeid").unwrap(); - println!("Inserting into table: {:?}", oid.unwrap()); Spi::run_with_args( &query, Some(vec![( diff --git a/src/typeid.rs b/src/typeid.rs index eb100ba..da57286 100644 --- a/src/typeid.rs +++ b/src/typeid.rs @@ -1,5 +1,5 @@ use core::fmt; -use std::borrow::Cow; +use std::{borrow::Cow, cmp::Ordering}; use pgrx::prelude::*; use serde::{Deserialize, Serialize}; @@ -25,7 +25,7 @@ pub enum Error { InvalidData, } -#[derive(Serialize, Deserialize, PartialEq, Eq, Clone)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, PartialOrd)] pub struct TypeIDPrefix(String); impl TypeIDPrefix { @@ -73,7 +73,7 @@ impl TypeIDPrefix { } } -#[derive(Serialize, Deserialize, Clone, PostgresType, PartialEq, Eq)] +#[derive(Debug, Serialize, Deserialize, Clone, PostgresType, PartialOrd, PartialEq, Eq)] #[inoutfuncs] pub struct TypeID(TypeIDPrefix, Uuid); @@ -108,6 +108,15 @@ impl TypeID { } } +impl Ord for TypeID { + fn cmp(&self, b: &Self) -> Ordering { + match self.type_prefix().cmp(b.type_prefix()) { + std::cmp::Ordering::Equal => self.uuid().cmp(b.uuid()), + other => other, + } + } +} + impl Hash for TypeID { fn hash(&self, state: &mut H) { self.type_prefix().as_bytes().hash(state); @@ -135,7 +144,10 @@ impl InOutFuncs for TypeID { // Convert the input to a str and handle potential UTF-8 errors let str_input = input.to_str().expect("text input is not valid UTF8"); - TypeID::from_string(str_input).unwrap() + match TypeID::from_string(str_input) { + Ok(typeid) => typeid, + Err(err) => panic!("Failed to construct TypeId<{str_input}>: {err}"), + } } fn output(&self, buffer: &mut pgrx::StringInfo) {