Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 109 additions & 7 deletions fluree-db-core/src/nonempty.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,50 @@
//! `NonEmpty<T>`: a sequence with a type-level guarantee of at least one element.

/// Sequence with a type-level guarantee of at least one element. The
/// invariant is structural — `head` is always present — so downstream code
/// can rely on `first`, `iter`, etc. without empty-checks. Constructed only
/// at validation boundaries.
/// Build a `NonEmpty<T>` from one or more comma-separated expressions.
///
/// The first expression becomes the head and any remaining expressions
/// become the tail, in order. A trailing comma is accepted. Because every
/// arm requires at least one expression, `nonempty![]` fails to compile
/// instead of panicking at runtime — handy in tests where spelling out
/// `try_from_vec(...).unwrap()` would only add noise.
///
/// ```
/// use fluree_db_core::{nonempty, NonEmpty};
/// let xs: NonEmpty<i32> = nonempty![1, 2, 3];
/// assert_eq!(xs.len(), 3);
/// ```
#[macro_export]
macro_rules! nonempty {
    // Exactly one element: the tail is an empty `Vec`.
    ($only:expr $(,)?) => {
        $crate::NonEmpty::from_head_tail($only, ::std::vec::Vec::new())
    };
    // Two or more elements: everything after the first goes to the tail.
    ($first:expr, $($rest:expr),+ $(,)?) => {
        $crate::NonEmpty::from_head_tail($first, ::std::vec![$($rest),+])
    };
}

/// A sequence that always holds at least one element.
///
/// The guarantee is structural rather than checked: the first element lives
/// in `head`, which is a plain `T` and therefore always present, while any
/// further elements live in `tail`. Code receiving a `NonEmpty<T>` can use
/// `first`, `iter`, and friends without testing for emptiness.
#[derive(Debug, Clone)]
pub struct NonEmpty<T> {
    /// The first element; present in every value of this type.
    pub head: T,
    /// Every element after the first, in order. May be empty.
    pub tail: Vec<T>,
}

impl<T> NonEmpty<T> {
/// Build a one-element sequence: `head` is the only element and the tail
/// starts out empty.
pub fn singleton(head: T) -> Self {
    let tail = Vec::new();
    Self { head, tail }
}

/// Construct from a head plus a tail of arbitrary length.
pub fn from_head_tail(head: T, tail: Vec<T>) -> Self {
Self { head, tail }
}

/// Construct from a `Vec`, returning `None` if the input is empty.
pub fn try_from_vec(v: Vec<T>) -> Option<Self> {
let mut iter = v.into_iter();
Expand All @@ -26,6 +60,21 @@ impl<T> NonEmpty<T> {
std::iter::once(&self.head).chain(self.tail.iter())
}

/// Iterate mutably over all elements in order, starting with the head.
pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut T> {
std::iter::once(&mut self.head).chain(self.tail.iter_mut())
}

/// Transform every element with `f`, producing a `NonEmpty<U>` of the same
/// length. The head is mapped first, then the tail elements in order, so
/// the result has a head by construction and is itself non-empty.
pub fn map<U, F: FnMut(T) -> U>(self, mut f: F) -> NonEmpty<U> {
    let NonEmpty { head, tail } = self;
    let head = f(head);
    let tail = tail.into_iter().map(f).collect();
    NonEmpty { head, tail }
}

/// Convert into a (necessarily non-empty) `Vec`.
pub fn into_vec(self) -> Vec<T> {
let mut v = Vec::with_capacity(1 + self.tail.len());
Expand All @@ -49,13 +98,66 @@ impl<T> NonEmpty<T> {
pub fn first(&self) -> &T {
&self.head
}

/// Final element of the sequence: the last tail element when the tail is
/// non-empty, otherwise the head. Never fails.
pub fn last(&self) -> &T {
    match self.tail.last() {
        Some(tail_end) => tail_end,
        None => &self.head,
    }
}

/// Add `value` at the end of the sequence. Only the tail grows; the head
/// is never touched, so the sequence stays non-empty.
pub fn push(&mut self, value: T) {
    let tail = &mut self.tail;
    tail.push(value);
}

/// Append every item produced by `iter` to the end of the sequence. Only
/// the tail grows; the head is never touched, so the sequence stays
/// non-empty.
pub fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
    let tail = &mut self.tail;
    tail.extend(iter);
}
}

impl<T> From<T> for NonEmpty<T> {
fn from(t: T) -> Self {
Self {
head: t,
tail: Vec::new(),
Self::singleton(t)
}
}

/// Positional read access. Index 0 is the head and always succeeds; index
/// `n > 0` refers to tail element `n - 1` and panics when out of range,
/// exactly as `Vec`/slice indexing does.
impl<T> std::ops::Index<usize> for NonEmpty<T> {
    type Output = T;

    fn index(&self, idx: usize) -> &T {
        // `checked_sub(1)` is `None` only for index 0, i.e. the head.
        match idx.checked_sub(1) {
            None => &self.head,
            Some(tail_idx) => &self.tail[tail_idx],
        }
    }
}

impl<'a, T> IntoIterator for &'a NonEmpty<T> {
type Item = &'a T;
type IntoIter = std::iter::Chain<std::iter::Once<&'a T>, std::slice::Iter<'a, T>>;
fn into_iter(self) -> Self::IntoIter {
std::iter::once(&self.head).chain(self.tail.iter())
}
}

impl<'a, T> IntoIterator for &'a mut NonEmpty<T> {
type Item = &'a mut T;
type IntoIter = std::iter::Chain<std::iter::Once<&'a mut T>, std::slice::IterMut<'a, T>>;
fn into_iter(self) -> Self::IntoIter {
std::iter::once(&mut self.head).chain(self.tail.iter_mut())
}
}

impl<T> IntoIterator for NonEmpty<T> {
type Item = T;
type IntoIter = std::iter::Chain<std::iter::Once<T>, std::vec::IntoIter<T>>;
fn into_iter(self) -> Self::IntoIter {
std::iter::once(self.head).chain(self.tail)
}
}
101 changes: 61 additions & 40 deletions fluree-db-query/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl AggregateOperator {
let mut group_size_col: Option<usize> = None;

for (agg_idx, spec) in aggregates.iter().enumerate() {
match spec.input_var {
match spec.function.input_var() {
Some(input_var) => {
// Regular aggregate with input variable
if let Some(col_idx) = child_schema.iter().position(|v| *v == input_var) {
Expand Down Expand Up @@ -198,7 +198,7 @@ impl Operator for AggregateOperator {
Some(agg_idx) => {
// This column needs aggregation
let spec = &self.aggregates[agg_idx];
apply_aggregate(&spec.function, input_binding, spec.distinct)
spec.function.apply(input_binding)
}
None => {
// Pass through unchanged
Expand All @@ -221,7 +221,7 @@ impl Operator for AggregateOperator {
Some(col_idx) => (0..batch.len())
.map(|row_idx| {
let input_binding = batch.get_by_col(row_idx, *col_idx);
apply_aggregate(&spec.function, input_binding, spec.distinct)
spec.function.apply(input_binding)
})
.collect(),
None => {
Expand Down Expand Up @@ -274,43 +274,54 @@ impl Operator for AggregateOperator {
}
}

/// Apply an aggregate function to a binding
///
/// If the binding is `Grouped(values)`, compute the aggregate.
/// Otherwise, pass through unchanged (shouldn't happen in normal usage).
/// When `distinct` is true, deduplicates values before aggregation.
pub fn apply_aggregate(func: &AggregateFn, binding: &Binding, distinct: bool) -> Binding {
match binding {
Binding::Grouped(values) => {
if distinct {
let mut seen = HashSet::with_capacity(values.len());
let deduped: Vec<Binding> =
values.iter().filter(|b| seen.insert(*b)).cloned().collect();
compute_aggregate(func, &deduped)
} else {
compute_aggregate(func, values)
}
impl AggregateFn {
/// Apply this aggregate function to a binding.
///
/// If the binding is `Grouped(values)`, compute the aggregate;
/// non-grouped values pass through (e.g. group-key columns).
/// Variants that need an upstream dedup pass (see
/// [`Self::needs_input_dedup`]) get one before being handed to
/// [`Self::compute`].
pub fn apply(&self, binding: &Binding) -> Binding {
let Binding::Grouped(values) = binding else {
return binding.clone();
};
if self.needs_input_dedup() {
let mut seen = HashSet::with_capacity(values.len());
let deduped: Vec<Binding> =
values.iter().filter(|b| seen.insert(*b)).cloned().collect();
self.compute(&deduped)
} else {
self.compute(values)
}
// Non-grouped values pass through (e.g., group key columns)
other => other.clone(),
}
}

/// Compute aggregate over a list of bindings
fn compute_aggregate(func: &AggregateFn, values: &[Binding]) -> Binding {
match func {
AggregateFn::Count => agg_count(values),
AggregateFn::CountAll => agg_count_all(values),
AggregateFn::CountDistinct => agg_count_distinct(values),
AggregateFn::Sum => agg_sum(values),
AggregateFn::Avg => agg_avg(values),
AggregateFn::Min => agg_min(values),
AggregateFn::Max => agg_max(values),
AggregateFn::Median => agg_median(values),
AggregateFn::Variance => agg_variance(values),
AggregateFn::Stddev => agg_stddev(values),
AggregateFn::GroupConcat { separator } => agg_group_concat(values, separator),
AggregateFn::Sample => agg_sample(values),
/// Decide whether [`Self::apply`] has to deduplicate the grouped values
/// before reducing them.
///
/// Returns `true` only when [`Self::is_distinct`] holds and the variant is
/// not [`Self::CountDistinct`]. `CountDistinct` is excluded because it
/// already counts distinct values on its own, so a dedup pre-pass would be
/// redundant work.
fn needs_input_dedup(&self) -> bool {
    if !self.is_distinct() {
        return false;
    }
    !matches!(self, AggregateFn::CountDistinct(_))
}

/// Reduce `values` to a single [`Binding`] by dispatching on the aggregate
/// variant to the matching `agg_*` helper.
///
/// `values` is used as given: no deduplication happens here. Callers that
/// need set semantics must deduplicate first (as [`Self::apply`] does when
/// [`Self::needs_input_dedup`] is true).
fn compute(&self, values: &[Binding]) -> Binding {
// Only the variant (and, for `GroupConcat`, its separator) influences
// the dispatch; other payload fields are ignored here.
match self {
Self::Count(_) => agg_count(values),
Self::CountAll => agg_count_all(values),
Self::CountDistinct(_) => agg_count_distinct(values),
Self::Sum { .. } => agg_sum(values),
Self::Avg { .. } => agg_avg(values),
Self::Min(_) => agg_min(values),
Self::Max(_) => agg_max(values),
Self::Median { .. } => agg_median(values),
Self::Variance { .. } => agg_variance(values),
Self::Stddev { .. } => agg_stddev(values),
Self::GroupConcat { separator, .. } => agg_group_concat(values, separator),
Self::Sample(_) => agg_sample(values),
}
}
}

Expand Down Expand Up @@ -716,6 +727,7 @@ fn extract_numbers(values: &[Binding]) -> Vec<f64> {
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::InputSemantics;

#[test]
fn test_agg_count() {
Expand Down Expand Up @@ -1021,11 +1033,20 @@ mod tests {
assert!(matches!(result, Binding::Unbound));
}

/// Test helper: build a `Sum` aggregate over `input`, using set semantics
/// when `distinct` is true and list semantics otherwise.
fn sum_of(input: VarId, distinct: bool) -> AggregateFn {
    if distinct {
        return AggregateFn::Sum(input, InputSemantics::Set);
    }
    AggregateFn::Sum(input, InputSemantics::List)
}

#[test]
fn test_apply_aggregate_non_grouped() {
    // A binding that is not `Grouped` must come back from `apply` unchanged.
    let binding = Binding::lit(FlakeValue::Long(42), xsd_integer());
    let result = sum_of(VarId(0), false).apply(&binding);
    assert_eq!(result, binding);
}

Expand All @@ -1037,7 +1058,7 @@ mod tests {
Binding::lit(FlakeValue::Long(3), xsd_integer()),
]);

let result = apply_aggregate(&AggregateFn::Sum, &grouped, false);
let result = sum_of(VarId(0), false).apply(&grouped);
let (val, _) = result.as_lit().unwrap();
assert_eq!(*val, FlakeValue::Long(6));
}
Expand All @@ -1053,7 +1074,7 @@ mod tests {
Binding::lit(FlakeValue::Long(2), xsd_integer()), // duplicate
]);

let result = apply_aggregate(&AggregateFn::Sum, &grouped, true);
let result = sum_of(VarId(0), true).apply(&grouped);
let (val, _) = result.as_lit().unwrap();
assert_eq!(*val, FlakeValue::Long(6)); // 1+2+3 = 6, not 1+2+1+3+2 = 9
}
Expand Down
2 changes: 0 additions & 2 deletions fluree-db-query/src/count_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -958,9 +958,7 @@ mod tests {
aggregates: fluree_db_core::NonEmpty::try_from_vec(vec![
crate::ir::AggregateSpec {
function: AggregateFn::CountAll,
input_var: None,
output_var: out_var,
distinct: false,
},
])
.expect("non-empty"),
Expand Down
24 changes: 7 additions & 17 deletions fluree-db-query/src/execute/dependency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ pub fn compute_variable_deps(query: &Query) -> Option<VariableDeps> {
// Aggregates: replace output vars with input vars.
for spec in query.grouping.iter().flat_map(Grouping::aggregates) {
if deps.remove(&spec.output_var) {
if let Some(input_var) = spec.input_var {
if let Some(input_var) = spec.function.input_var() {
deps.insert(input_var);
}
}
Expand Down Expand Up @@ -118,7 +118,7 @@ mod tests {
use crate::ir::triple::{Ref, Term, TriplePattern};
use crate::ir::{
AggregateFn, AggregateSpec, Aggregation, ConstructTemplate, Expression, FlakeValue,
Pattern, Query, QueryOutput, ReasoningConfig,
InputSemantics, Pattern, Query, QueryOutput, ReasoningConfig,
};
use crate::parse::SelectMode;
use crate::sort::SortSpec;
Expand Down Expand Up @@ -215,10 +215,8 @@ mod tests {
group_by: fluree_db_core::NonEmpty::try_from_vec(vec![VarId(2)]).unwrap(),
aggregation: Some(Aggregation {
aggregates: fluree_db_core::NonEmpty::try_from_vec(vec![AggregateSpec {
function: AggregateFn::Avg,
input_var: Some(VarId(1)),
function: AggregateFn::Avg(VarId(1), InputSemantics::List),
output_var: VarId(3),
distinct: false,
}])
.unwrap(),
binds: Vec::new(),
Expand All @@ -244,10 +242,8 @@ mod tests {
group_by: fluree_db_core::NonEmpty::try_from_vec(vec![VarId(0)]).unwrap(),
aggregation: Some(Aggregation {
aggregates: fluree_db_core::NonEmpty::try_from_vec(vec![AggregateSpec {
function: AggregateFn::Avg,
input_var: Some(VarId(1)),
function: AggregateFn::Avg(VarId(1), InputSemantics::List),
output_var: VarId(2),
distinct: false,
}])
.unwrap(),
binds: vec![(
Expand Down Expand Up @@ -442,10 +438,8 @@ mod tests {
group_by: fluree_db_core::NonEmpty::try_from_vec(vec![VarId(0)]).unwrap(),
aggregation: Some(Aggregation {
aggregates: fluree_db_core::NonEmpty::try_from_vec(vec![AggregateSpec {
function: AggregateFn::Avg,
input_var: Some(VarId(1)),
function: AggregateFn::Avg(VarId(1), InputSemantics::List),
output_var: VarId(2),
distinct: false,
}])
.unwrap(),
binds: Vec::new(),
Expand Down Expand Up @@ -480,10 +474,8 @@ mod tests {
group_by: fluree_db_core::NonEmpty::try_from_vec(vec![VarId(0)]).unwrap(),
aggregation: Some(Aggregation {
aggregates: fluree_db_core::NonEmpty::try_from_vec(vec![AggregateSpec {
function: AggregateFn::Avg,
input_var: Some(VarId(1)),
function: AggregateFn::Avg(VarId(1), InputSemantics::List),
output_var: VarId(2),
distinct: false,
}])
.unwrap(),
binds: vec![(
Expand Down Expand Up @@ -526,10 +518,8 @@ mod tests {
group_by: fluree_db_core::NonEmpty::try_from_vec(vec![VarId(0)]).unwrap(),
aggregation: Some(Aggregation {
aggregates: fluree_db_core::NonEmpty::try_from_vec(vec![AggregateSpec {
function: AggregateFn::Count,
input_var: Some(VarId(1)),
function: AggregateFn::Count(VarId(1)),
output_var: VarId(2),
distinct: false,
}])
.unwrap(),
binds: Vec::new(),
Expand Down
Loading
Loading