diff --git a/Cargo.lock b/Cargo.lock index add311be4..e58932ba6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -22,7 +22,7 @@ dependencies = [ "memcached-rs 0.1.2 (git+https://github.com/jonhoo/memcached-rs.git?branch=expose-multi)", "mysql 12.0.3 (registry+https://github.com/rust-lang/crates.io-index)", "net2 0.2.31 (registry+https://github.com/rust-lang/crates.io-index)", - "nom_sql 0.0.1 (git+https://github.com/ms705/nom-sql.git?rev=7758136babe72f470b7ea9fea384fca4f47db647)", + "nom_sql 0.0.1 (git+https://github.com/ms705/nom-sql.git?rev=b01998fc34a5d473387987110724a395298fc6c0)", "petgraph 0.4.5 (git+https://github.com/fintelia/petgraph?branch=serde)", "rand 0.3.16 (registry+https://github.com/rust-lang/crates.io-index)", "regex 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", @@ -812,7 +812,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "nom_sql" version = "0.0.1" -source = "git+https://github.com/ms705/nom-sql.git?rev=7758136babe72f470b7ea9fea384fca4f47db647#7758136babe72f470b7ea9fea384fca4f47db647" +source = "git+https://github.com/ms705/nom-sql.git?rev=b01998fc34a5d473387987110724a395298fc6c0#b01998fc34a5d473387987110724a395298fc6c0" dependencies = [ "nom 1.2.4 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.15 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1621,7 +1621,7 @@ dependencies = [ "checksum nix 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a2c5afeb0198ec7be8569d666644b574345aad2e95a53baf3a532da3e0f3fb32" "checksum nodrop 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "52cd74cd09beba596430cc6e3091b74007169a56246e1262f0ba451ea95117b2" "checksum nom 1.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "a5b8c256fd9471521bcb84c3cdba98921497f1a331cbc15b8030fc63b82050ce" -"checksum nom_sql 0.0.1 (git+https://github.com/ms705/nom-sql.git?rev=7758136babe72f470b7ea9fea384fca4f47db647)" = "" +"checksum nom_sql 0.0.1 (git+https://github.com/ms705/nom-sql.git?rev=b01998fc34a5d473387987110724a395298fc6c0)" = "" "checksum num 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)" = "a311b77ebdc5dd4cf6449d81e4135d9f0e3b153839ac90e648a8ef538f923525" "checksum num-integer 0.1.35 (registry+https://github.com/rust-lang/crates.io-index)" = "d1452e8b06e448a07f0e6ebb0bb1d92b8890eea63288c0b627331d53514d0fba" "checksum num-iter 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)" = "7485fcc84f85b4ecd0ea527b14189281cf27d60e583ae65ebc9c088b13dffe01" diff --git a/Cargo.toml b/Cargo.toml index 30f334c9b..c24be5e46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ vec_map = { version = "0.8.0", features = ["eders"] } hurdles = "1.0.0" arrayvec = "0.4.0" -nom_sql = { git = "https://github.com/ms705/nom-sql.git", rev = "7758136babe72f470b7ea9fea384fca4f47db647"} +nom_sql = { git = "https://github.com/ms705/nom-sql.git", rev = "b01998fc34a5d473387987110724a395298fc6c0"} # for benchmarks # cli diff --git a/src/flow/core/data.rs b/src/flow/core/data.rs index feaaaf77a..4634ba1c1 100644 --- a/src/flow/core/data.rs +++ b/src/flow/core/data.rs @@ -8,9 +8,10 @@ use nom_sql::Literal; use serde_json::Value; use std::hash::{Hash, Hasher}; -use std::ops::{Deref, DerefMut}; +use std::ops::{Add, Deref, DerefMut, Div, Mul, Sub}; use std::fmt; +const FLOAT_PRECISION: f64 = 1000_000_000.0; const TINYTEXT_WIDTH: usize = 15; /// The main type used for user data throughout the codebase. @@ -128,7 +129,7 @@ impl From for DataType { } let mut i = f.trunc() as i32; - let mut frac = (f.fract() * 1000_000_000.0).round() as i32; + let mut frac = (f.fract() * FLOAT_PRECISION).round() as i32; if frac == 1000_000_000 { i += 1; frac = 0; @@ -202,6 +203,17 @@ impl Into for DataType { } } +impl Into for DataType { + fn into(self) -> f64 { + match self { + DataType::Real(i, f) => i as f64 + (f as f64) / FLOAT_PRECISION, + DataType::Int(i) => i as f64, + DataType::BigInt(i) => i as f64, + _ => unreachable!(), + } + } +} + impl From for DataType { fn from(s: String) -> Self { let len = s.as_bytes().len(); @@ -225,6 +237,67 @@ impl<'a> From<&'a str> for DataType { } } +// Performs an arithmetic operation on two numeric DataTypes, +// returning a new DataType as the result. +macro_rules! arithmetic_operation ( + ($op:tt, $first:ident, $second:ident) => ( + match ($first, $second) { + (DataType::Int(a), DataType::Int(b)) => (a $op b).into(), + (DataType::BigInt(a), DataType::BigInt(b)) => (a $op b).into(), + (DataType::Int(a), DataType::BigInt(b)) => ((a as i64) $op b).into(), + (DataType::BigInt(a), DataType::Int(b)) => (a $op (b as i64)).into(), + + (first @ DataType::Int(..), second @ DataType::Real(..)) | + (first @ DataType::Real(..), second @ DataType::Int(..)) | + (first @ DataType::Real(..), second @ DataType::Real(..)) => { + let a: f64 = first.into(); + let b: f64 = second.into(); + (a $op b).into() + } + (first, second) => panic!( + format!( + "can't {} a {:?} and {:?}", + stringify!($op), + first, + second, + ) + ), + } + ); +); + +impl Add for DataType { + type Output = DataType; + + fn add(self, other: DataType) -> DataType { + arithmetic_operation!(+, self, other) + } +} + +impl Sub for DataType { + type Output = DataType; + + fn sub(self, other: DataType) -> DataType { + arithmetic_operation!(-, self, other) + } +} + +impl Mul for DataType { + type Output = DataType; + + fn mul(self, other: DataType) -> DataType { + arithmetic_operation!(*, self, other) + } +} + +impl Div for DataType { + type Output = DataType; + + fn div(self, other: DataType) -> DataType { + arithmetic_operation!(/, self, other) + } +} + impl fmt::Debug for DataType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -442,6 +515,66 @@ mod tests { assert_eq!(c.to_json(), json!(-0.012345678)); } + #[test] + fn real_to_float() { + let original = 2.5; + let data_type: DataType = original.into(); + let converted: f64 = data_type.into(); + assert_eq!(original, converted); + } + + #[test] + fn add_data_types() { + assert_eq!(DataType::from(1) + DataType::from(2), 3.into()); + assert_eq!(DataType::from(1.5) + DataType::from(2), (3.5).into()); + assert_eq!(DataType::from(2) + DataType::from(1.5), (3.5).into()); + assert_eq!(DataType::from(1.5) + DataType::from(2.5), (4.0).into()); + assert_eq!(DataType::BigInt(1) + DataType::BigInt(2), 3.into()); + assert_eq!(DataType::from(1) + DataType::BigInt(2), 3.into()); + assert_eq!(DataType::BigInt(2) + DataType::from(1), 3.into()); + } + + #[test] + fn subtract_data_types() { + assert_eq!(DataType::from(2) - DataType::from(1), 1.into()); + assert_eq!(DataType::from(3.5) - DataType::from(2), (1.5).into()); + assert_eq!(DataType::from(2) - DataType::from(1.5), (0.5).into()); + assert_eq!(DataType::from(3.5) - DataType::from(2.0), (1.5).into()); + assert_eq!(DataType::BigInt(1) - DataType::BigInt(2), (-1).into()); + assert_eq!(DataType::from(1) - DataType::BigInt(2), (-1).into()); + assert_eq!(DataType::BigInt(2) - DataType::from(1), 1.into()); + } + + #[test] + fn multiply_data_types() { + assert_eq!(DataType::from(2) * DataType::from(1), 2.into()); + assert_eq!(DataType::from(3.5) * DataType::from(2), (7.0).into()); + assert_eq!(DataType::from(2) * DataType::from(1.5), (3.0).into()); + assert_eq!(DataType::from(3.5) * DataType::from(2.0), (7.0).into()); + assert_eq!(DataType::BigInt(1) * DataType::BigInt(2), 2.into()); + assert_eq!(DataType::from(1) * DataType::BigInt(2), 2.into()); + assert_eq!(DataType::BigInt(2) * DataType::from(1), 2.into()); + } + + #[test] + fn divide_data_types() { + assert_eq!(DataType::from(2) / DataType::from(1), 2.into()); + assert_eq!(DataType::from(7.5) / DataType::from(2), (3.75).into()); + assert_eq!(DataType::from(7) / DataType::from(2.5), (2.8).into()); + assert_eq!(DataType::from(3.5) / DataType::from(2.0), (1.75).into()); + assert_eq!(DataType::BigInt(4) / DataType::BigInt(2), 2.into()); + assert_eq!(DataType::from(4) / DataType::BigInt(2), 2.into()); + assert_eq!(DataType::BigInt(4) / DataType::from(2), 2.into()); + } + + #[test] + #[should_panic(expected = "can't + a TinyText(\"hi\") and Int(5)")] + fn add_invalid_types() { + let a: DataType = "hi".into(); + let b: DataType = 5.into(); + a + b; + } + #[test] fn data_type_debug() { let tiny_text: DataType = "hi".into(); diff --git a/src/mir/node.rs b/src/mir/node.rs index ef1643339..34b865ee8 100644 --- a/src/mir/node.rs +++ b/src/mir/node.rs @@ -1,4 +1,4 @@ -use nom_sql::{Column, ColumnSpecification, Operator, OrderType}; +use nom_sql::{ArithmeticExpression, Column, ColumnSpecification, Operator, OrderType}; use std::cell::RefCell; use std::fmt::{Debug, Display, Error, Formatter}; use std::rc::Rc; @@ -477,6 +477,7 @@ impl MirNode { MirNodeType::Project { ref emit, ref literals, + ref arithmetic, } => { assert_eq!(self.ancestors.len(), 1); let parent = self.ancestors[0].clone(); @@ -485,6 +486,7 @@ impl MirNode { parent, self.columns.as_slice(), emit, + arithmetic, literals, mig, ) @@ -598,6 +600,7 @@ pub enum MirNodeType { /// emit columns Project { emit: Vec, + arithmetic: Vec<(String, ArithmeticExpression)>, literals: Vec<(String, DataType)>, }, /// emit columns @@ -772,11 +775,13 @@ impl MirNodeType { MirNodeType::Project { emit: ref our_emit, literals: ref our_literals, + arithmetic: ref our_arithmetic, } => match *other { MirNodeType::Project { ref emit, ref literals, - } => our_emit == emit && our_literals == literals, + ref arithmetic, + } => our_emit == emit && our_literals == literals && our_arithmetic == arithmetic, _ => false, }, MirNodeType::Reuse { node: ref us } => { @@ -994,14 +999,29 @@ impl Debug for MirNodeType { MirNodeType::Project { ref emit, ref literals, + ref arithmetic, } => write!( f, - "π [{}{}]", + "π [{}{}{}]", emit.iter() .map(|c| c.name.as_str()) .collect::>() .join(", "), - if literals.len() > 0 { + if arithmetic.is_empty() { + format!("") + } else { + format!( + ", {}", + arithmetic + .iter() + .map(|&(ref n, ref e)| format!("{}: {:?}", n, e)) + .collect::>() + .join(", ") + ) + }, + if literals.is_empty() { + format!("") + } else { format!( ", lit: {}", literals @@ -1010,9 +1030,7 @@ impl Debug for MirNodeType { .collect::>() .join(", ") ) - } else { - format!("") - } + }, ), MirNodeType::Reuse { ref node } => write!( f, diff --git a/src/mir/reuse.rs b/src/mir/reuse.rs index ff3c83df6..07aafa626 100644 --- a/src/mir/reuse.rs +++ b/src/mir/reuse.rs @@ -314,6 +314,7 @@ mod tests { vec![Column::from("aa")], MirNodeType::Project { emit: vec![Column::from("aa")], + arithmetic: vec![], literals: vec![], }, vec![c.clone()], diff --git a/src/mir/to_flow.rs b/src/mir/to_flow.rs index 4b3d89629..f60d527ec 100644 --- a/src/mir/to_flow.rs +++ b/src/mir/to_flow.rs @@ -1,4 +1,5 @@ -use nom_sql::{Column, ColumnConstraint, ColumnSpecification, Operator, OrderType}; +use nom_sql::{ArithmeticBase, ArithmeticExpression, Column, ColumnConstraint, ColumnSpecification, + Operator, OrderType}; use std::collections::HashMap; use flow::Migration; @@ -9,7 +10,7 @@ use mir::node::GroupedNodeType; use ops; use ops::join::{Join, JoinType}; use ops::latest::Latest; -use ops::project::Project; +use ops::project::{Project, ProjectExpression, ProjectExpressionBase}; #[derive(Clone, Debug)] pub enum FlowNode { @@ -360,11 +361,26 @@ pub(crate) fn make_latest_node( FlowNode::New(na) } +// Converts a nom_sql::ArithmeticBase into a project::ProjectExpressionBase: +fn generate_projection_base(parent: &MirNodeRef, base: &ArithmeticBase) -> ProjectExpressionBase { + match *base { + ArithmeticBase::Column(ref column) => { + let column_id = parent.borrow().column_id_for_column(column); + ProjectExpressionBase::Column(column_id) + } + ArithmeticBase::Scalar(ref literal) => { + let data: DataType = literal.into(); + ProjectExpressionBase::Literal(data) + } + } +} + pub(crate) fn make_project_node( name: &str, parent: MirNodeRef, columns: &[Column], emit: &Vec, + arithmetic: &Vec<(String, ArithmeticExpression)>, literals: &Vec<(String, DataType)>, mig: &mut Migration, ) -> FlowNode { @@ -377,6 +393,17 @@ pub(crate) fn make_project_node( let (_, literal_values): (Vec<_>, Vec<_>) = literals.iter().cloned().unzip(); + let projected_arithmetic: Vec = arithmetic + .iter() + .map(|&(_, ref e)| { + ProjectExpression::new( + e.op.clone(), + generate_projection_base(&parent, &e.left), + generate_projection_base(&parent, &e.right), + ) + }) + .collect(); + let n = mig.add_ingredient( String::from(name), column_names.as_slice(), @@ -384,6 +411,7 @@ pub(crate) fn make_project_node( parent_na, projected_column_ids.as_slice(), Some(literal_values), + Some(projected_arithmetic), ), ); FlowNode::New(n) diff --git a/src/mir/visualize.rs b/src/mir/visualize.rs index ba82a7d36..5310c64f6 100644 --- a/src/mir/visualize.rs +++ b/src/mir/visualize.rs @@ -216,15 +216,30 @@ impl GraphViz for MirNodeType { MirNodeType::Project { ref emit, ref literals, + ref arithmetic, } => { write!( out, - "π: {}{}", + "π: {}{}{}", emit.iter() .map(|c| c.name.as_str()) .collect::>() .join(", "), - if literals.len() > 0 { + if arithmetic.is_empty() { + format!("") + } else { + format!( + ", {}", + arithmetic + .iter() + .map(|&(ref n, ref e)| format!("{}: {:?}", n, e)) + .collect::>() + .join(", ") + ) + }, + if literals.is_empty() { + format!("") + } else { format!( ", lit: {}", literals @@ -233,8 +248,6 @@ impl GraphViz for MirNodeType { .collect::>() .join(", ") ) - } else { - format!("") } )?; } diff --git a/src/ops/project.rs b/src/ops/project.rs index 5d9d5b28e..f737fc8db 100644 --- a/src/ops/project.rs +++ b/src/ops/project.rs @@ -1,23 +1,82 @@ +use nom_sql::ArithmeticOperator; +use std::fmt; use std::collections::HashMap; use flow::prelude::*; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ProjectExpressionBase { + Column(usize), + Literal(DataType), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProjectExpression { + op: ArithmeticOperator, + left: ProjectExpressionBase, + right: ProjectExpressionBase, +} + +impl ProjectExpression { + pub fn new( + op: ArithmeticOperator, + left: ProjectExpressionBase, + right: ProjectExpressionBase, + ) -> ProjectExpression { + ProjectExpression { + op: op, + left: left, + right: right, + } + } +} + +impl fmt::Display for ProjectExpressionBase { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ProjectExpressionBase::Column(u) => write!(f, "{}", u), + ProjectExpressionBase::Literal(ref l) => write!(f, "(lit: {})", l), + } + } +} + +impl fmt::Display for ProjectExpression { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let op = match self.op { + ArithmeticOperator::Add => "+", + ArithmeticOperator::Subtract => "-", + ArithmeticOperator::Divide => "/", + ArithmeticOperator::Multiply => "*", + }; + + write!(f, "{} {} {}", self.left, op, self.right) + } +} + + /// Permutes or omits columns from its source node, or adds additional literal value columns. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Project { us: Option, emit: Option>, additional: Option>, + expressions: Option>, src: IndexPair, cols: usize, } impl Project { /// Construct a new permuter operator. - pub fn new(src: NodeIndex, emit: &[usize], additional: Option>) -> Project { + pub fn new( + src: NodeIndex, + emit: &[usize], + additional: Option>, + expressions: Option>, + ) -> Project { Project { emit: Some(emit.into()), additional: additional, + expressions: expressions, src: src.into(), cols: 0, us: None, @@ -34,6 +93,25 @@ impl Project { self.emit.as_ref().map_or(col, |emit| emit[col]) } } + + fn eval_expression(&self, expression: &ProjectExpression, record: &Record) -> DataType { + let left = match expression.left { + ProjectExpressionBase::Column(i) => &record[i], + ProjectExpressionBase::Literal(ref data) => data, + }.clone(); + + let right = match expression.right { + ProjectExpressionBase::Column(i) => &record[i], + ProjectExpressionBase::Literal(ref data) => data, + }.clone(); + + match expression.op { + ArithmeticOperator::Add => left + right, + ArithmeticOperator::Subtract => left - right, + ArithmeticOperator::Multiply => left * right, + ArithmeticOperator::Divide => left / right, + } + } } impl Ingredient for Project { @@ -89,6 +167,12 @@ impl Ingredient for Project { for i in e { new_r.push(r[*i].clone()); } + match self.expressions { + Some(ref e) => for i in e { + new_r.push(self.eval_expression(i, r)); + }, + None => (), + } match self.additional { Some(ref a) => for i in a { new_r.push(i.clone()); @@ -113,21 +197,31 @@ impl Ingredient for Project { } fn description(&self) -> String { - let emit_cols = match self.emit.as_ref() { - None => "*".into(), - Some(emit) => match self.additional { - None => emit.iter() - .map(|e| e.to_string()) - .collect::>() - .join(", "), - Some(ref add) => emit.iter() - .map(|e| e.to_string()) - .chain(add.iter().map(|e| format!("lit: {}", e.to_string()))) - .collect::>() - .join(", "), - }, + let mut emit_cols = vec![]; + match self.emit.as_ref() { + None => emit_cols.push("*".to_string()), + Some(emit) => { + emit_cols.extend(emit.iter().map(|e| e.to_string()).collect::>()); + + if let Some(ref arithmetic) = self.expressions { + emit_cols.extend( + arithmetic + .iter() + .map(|e| format!("{}", e)) + .collect::>(), + ); + } + + if let Some(ref add) = self.additional { + emit_cols.extend( + add.iter() + .map(|e| format!("lit: {}", e.to_string())) + .collect::>(), + ); + } + } }; - format!("π[{}]", emit_cols) + format!("π[{}]", emit_cols.join(", ")) } fn parent_columns(&self, column: usize) -> Vec<(NodeIndex, Option)> { @@ -159,18 +253,53 @@ mod tests { g.set_op( "permute", &["x", "y", "z"], - Project::new(s.as_global(), &permutation[..], additional), + Project::new(s.as_global(), &permutation[..], additional, None), materialized, ); g } + fn setup_arithmetic(expression: ProjectExpression) -> ops::test::MockGraph { + let mut g = ops::test::MockGraph::new(); + let s = g.add_base("source", &["x", "y", "z"]); + + let permutation = vec![0, 1]; + g.set_op( + "permute", + &["x", "y", "z"], + Project::new( + s.as_global(), + &permutation[..], + None, + Some(vec![expression]), + ), + false, + ); + g + } + + fn setup_column_arithmetic(op: ArithmeticOperator) -> ops::test::MockGraph { + let expression = ProjectExpression { + left: ProjectExpressionBase::Column(0), + right: ProjectExpressionBase::Column(1), + op: op, + }; + + setup_arithmetic(expression) + } + #[test] fn it_describes() { let p = setup(false, false, true); assert_eq!(p.node().description(), "π[2, 0, lit: \"hello\", lit: 42]"); } + #[test] + fn it_describes_arithmetic() { + let p = setup_column_arithmetic(ArithmeticOperator::Add); + assert_eq!(p.node().description(), "π[0, 1, 0 + 1]"); + } + #[test] fn it_describes_all() { let p = setup(false, true, false); @@ -227,6 +356,81 @@ mod tests { ); } + #[test] + fn it_forwards_addition_arithmetic() { + let mut p = setup_column_arithmetic(ArithmeticOperator::Add); + let rec = vec![10.into(), 20.into()]; + assert_eq!( + p.narrow_one_row(rec, false), + vec![vec![10.into(), 20.into(), 30.into()]].into() + ); + } + + #[test] + fn it_forwards_subtraction_arithmetic() { + let mut p = setup_column_arithmetic(ArithmeticOperator::Subtract); + let rec = vec![10.into(), 20.into()]; + assert_eq!( + p.narrow_one_row(rec, false), + vec![vec![10.into(), 20.into(), (-10).into()]].into() + ); + } + + #[test] + fn it_forwards_multiplication_arithmetic() { + let mut p = setup_column_arithmetic(ArithmeticOperator::Multiply); + let rec = vec![10.into(), 20.into()]; + assert_eq!( + p.narrow_one_row(rec, false), + vec![vec![10.into(), 20.into(), 200.into()]].into() + ); + } + + #[test] + fn it_forwards_division_arithmetic() { + let mut p = setup_column_arithmetic(ArithmeticOperator::Divide); + let rec = vec![10.into(), 2.into()]; + assert_eq!( + p.narrow_one_row(rec, false), + vec![vec![10.into(), 2.into(), 5.into()]].into() + ); + } + + #[test] + fn it_forwards_arithmetic_w_literals() { + let number: DataType = 40.into(); + let expression = ProjectExpression { + left: ProjectExpressionBase::Column(0), + right: ProjectExpressionBase::Literal(number), + op: ArithmeticOperator::Multiply, + }; + + let mut p = setup_arithmetic(expression); + let rec = vec![10.into(), 0.into()]; + assert_eq!( + p.narrow_one_row(rec, false), + vec![vec![10.into(), 0.into(), 400.into()]].into() + ); + } + + #[test] + fn it_forwards_arithmetic_w_only_literals() { + let a: DataType = 80.into(); + let b: DataType = 40.into(); + let expression = ProjectExpression { + left: ProjectExpressionBase::Literal(a), + right: ProjectExpressionBase::Literal(b), + op: ArithmeticOperator::Divide, + }; + + let mut p = setup_arithmetic(expression); + let rec = vec![0.into(), 0.into()]; + assert_eq!( + p.narrow_one_row(rec, false), + vec![vec![0.into(), 0.into(), 2.into()]].into() + ); + } + #[test] fn it_suggests_indices() { let me = 1.into(); diff --git a/src/sql/mir.rs b/src/sql/mir.rs index c07db277c..bbce4521d 100644 --- a/src/sql/mir.rs +++ b/src/sql/mir.rs @@ -7,8 +7,8 @@ use mir::query::MirQuery; pub use mir::to_flow::FlowNode; use ops::join::JoinType; -use nom_sql::{Column, ColumnSpecification, ConditionBase, ConditionExpression, ConditionTree, - Literal, Operator, SqlQuery, TableKey}; +use nom_sql::{ArithmeticExpression, Column, ColumnSpecification, ConditionBase, + ConditionExpression, ConditionTree, Literal, Operator, SqlQuery, TableKey}; use nom_sql::{LimitClause, OrderClause, SelectStatement}; use sql::query_graph::{JoinRef, OutputColumn, QueryGraph, QueryGraphEdge}; @@ -168,6 +168,7 @@ impl SqlToMirConverter { MirNodeType::Project { emit: columns.clone(), literals: vec![], + arithmetic: vec![], }, vec![parent.clone()], vec![], @@ -718,6 +719,7 @@ impl SqlToMirConverter { name, parent, vec![fn_col], + vec![], vec![(String::from("grp"), DataType::from(0 as i32))], ) } @@ -727,11 +729,17 @@ impl SqlToMirConverter { name: &str, parent_node: MirNodeRef, proj_cols: Vec<&Column>, + arithmetic: Vec<(String, ArithmeticExpression)>, literals: Vec<(String, DataType)>, ) -> MirNodeRef { //assert!(proj_cols.iter().all(|c| c.table == parent_name)); - let literal_names: Vec = literals.iter().map(|&(ref n, _)| n.clone()).collect(); + let names: Vec = literals + .iter() + .map(|&(ref n, _)| n.clone()) + .chain(arithmetic.iter().map(|&(ref n, _)| n.clone())) + .collect(); + let fields = proj_cols .clone() .into_iter() @@ -744,7 +752,7 @@ impl SqlToMirConverter { }, None => c.clone(), }) - .chain(literal_names.into_iter().map(|n| { + .chain(names.into_iter().map(|n| { Column { name: n, alias: None, @@ -776,6 +784,7 @@ impl SqlToMirConverter { MirNodeType::Project { emit: emit_cols, literals: literals, + arithmetic: arithmetic, }, vec![parent_node.clone()], vec![], @@ -1320,6 +1329,7 @@ impl SqlToMirConverter { let mut projected_columns: Vec<&Column> = qg.columns .iter() .filter_map(|oc| match *oc { + OutputColumn::Arithmetic(_) => None, OutputColumn::Data(ref c) => Some(c), OutputColumn::Literal(_) => None, }) @@ -1329,9 +1339,20 @@ impl SqlToMirConverter { projected_columns.push(pc); } } + let projected_arithmetic: Vec<(String, ArithmeticExpression)> = qg.columns + .iter() + .filter_map(|oc| match *oc { + OutputColumn::Arithmetic(ref ac) => { + Some((ac.name.clone(), ac.expression.clone())) + } + OutputColumn::Data(_) => None, + OutputColumn::Literal(_) => None, + }) + .collect(); let projected_literals: Vec<(String, DataType)> = qg.columns .iter() .filter_map(|oc| match *oc { + OutputColumn::Arithmetic(_) => None, OutputColumn::Data(_) => None, OutputColumn::Literal(ref lc) => { Some((lc.name.clone(), DataType::from(&lc.value))) @@ -1340,8 +1361,13 @@ impl SqlToMirConverter { .collect(); let ident = format!("q_{:x}_n{}", qg.signature().hash, new_node_count); - let leaf_project_node = - self.make_project_node(&ident, final_node, projected_columns, projected_literals); + let leaf_project_node = self.make_project_node( + &ident, + final_node, + projected_columns, + projected_arithmetic, + projected_literals, + ); nodes_added.push(leaf_project_node.clone()); // We always materialize leaves of queries (at least currently), so add a diff --git a/src/sql/mod.rs b/src/sql/mod.rs index f6926a29d..a7030b997 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -9,7 +9,7 @@ use flow::Migration; use flow::prelude::NodeIndex; use mir::reuse as mir_reuse; use nom_sql::parser as sql_parser; -use nom_sql::{Column, SqlQuery}; +use nom_sql::{ArithmeticBase, Column, SqlQuery}; use nom_sql::SelectStatement; use self::mir::{MirNodeRef, SqlToMirConverter}; use mir::query::MirQuery; @@ -193,6 +193,18 @@ impl SqlIncorporator { // GROUP BY clause if qg.columns.iter().all(|c| match *c { OutputColumn::Literal(_) => true, + OutputColumn::Arithmetic(ref ac) => { + let mut is_function = false; + if let ArithmeticBase::Column(ref c) = ac.expression.left { + is_function = is_function || c.function.is_some(); + } + + if let ArithmeticBase::Column(ref c) = ac.expression.right { + is_function = is_function || c.function.is_some(); + } + + !is_function + } OutputColumn::Data(ref dc) => dc.function.is_none(), }) { // QGs are identical, except for parameters (or their order) @@ -708,7 +720,7 @@ mod tests { // Should have two nodes: source and "users" base table let ncount = mig.graph().node_count(); assert_eq!(ncount, 2); - assert_eq!(get_node(&inc, &*mig, "users").name(), "users"); + assert_eq!(get_node(&inc, mig, "users").name(), "users"); assert!( "SELECT users.id from users;" @@ -743,9 +755,9 @@ mod tests { ); // Should have source and "users" base table node assert_eq!(mig.graph().node_count(), 2); - assert_eq!(get_node(&inc, &*mig, "users").name(), "users"); - assert_eq!(get_node(&inc, &*mig, "users").fields(), &["id", "name"]); - assert_eq!(get_node(&inc, &*mig, "users").description(), "B"); + assert_eq!(get_node(&inc, mig, "users").name(), "users"); + assert_eq!(get_node(&inc, mig, "users").fields(), &["id", "name"]); + assert_eq!(get_node(&inc, mig, "users").description(), "B"); // Establish a base write type for "articles" assert!( @@ -757,12 +769,12 @@ mod tests { ); // Should have source and "users" base table node assert_eq!(mig.graph().node_count(), 3); - assert_eq!(get_node(&inc, &*mig, "articles").name(), "articles"); + assert_eq!(get_node(&inc, mig, "articles").name(), "articles"); assert_eq!( - get_node(&inc, &*mig, "articles").fields(), + get_node(&inc, mig, "articles").fields(), &["id", "author", "title"] ); - assert_eq!(get_node(&inc, &*mig, "articles").description(), "B"); + assert_eq!(get_node(&inc, mig, "articles").description(), "B"); // Try a simple equi-JOIN query let q = "SELECT users.name, articles.title \ @@ -776,13 +788,13 @@ mod tests { &[&Column::from("articles.title"), &Column::from("users.name")], ); // join node - let new_join_view = get_node(&inc, &*mig, &format!("q_{:x}_n0", qid)); + let new_join_view = get_node(&inc, mig, &format!("q_{:x}_n0", qid)); assert_eq!( new_join_view.fields(), &["id", "author", "title", "id", "name"] ); // leaf node - let new_leaf_view = get_node(&inc, &*mig, &q.unwrap().name); + let new_leaf_view = get_node(&inc, mig, &q.unwrap().name); assert_eq!(new_leaf_view.fields(), &["name", "title"]); assert_eq!(new_leaf_view.description(), format!("π[4, 2]")); }); @@ -801,9 +813,9 @@ mod tests { ); // Should have source and "users" base table node assert_eq!(mig.graph().node_count(), 2); - assert_eq!(get_node(&inc, &*mig, "users").name(), "users"); - assert_eq!(get_node(&inc, &*mig, "users").fields(), &["id", "name"]); - assert_eq!(get_node(&inc, &*mig, "users").description(), "B"); + assert_eq!(get_node(&inc, mig, "users").name(), "users"); + assert_eq!(get_node(&inc, mig, "users").fields(), &["id", "name"]); + assert_eq!(get_node(&inc, mig, "users").description(), "B"); // Try a simple query let res = inc.add_query( @@ -819,11 +831,11 @@ mod tests { &[&Column::from("users.name")], ); // filter node - let filter = get_node(&inc, &*mig, &format!("q_{:x}_n0_p0_f0", qid)); + let filter = get_node(&inc, mig, &format!("q_{:x}_n0_p0_f0", qid)); assert_eq!(filter.fields(), &["id", "name"]); assert_eq!(filter.description(), format!("σ[f0 = 42]")); // leaf view node - let edge = get_node(&inc, &*mig, &res.unwrap().name); + let edge = get_node(&inc, mig, &res.unwrap().name); assert_eq!(edge.fields(), &["name"]); assert_eq!(edge.description(), format!("π[1]")); }); @@ -842,9 +854,9 @@ mod tests { ); // Should have source and "users" base table node assert_eq!(mig.graph().node_count(), 2); - assert_eq!(get_node(&inc, &*mig, "votes").name(), "votes"); - assert_eq!(get_node(&inc, &*mig, "votes").fields(), &["aid", "userid"]); - assert_eq!(get_node(&inc, &*mig, "votes").description(), "B"); + assert_eq!(get_node(&inc, mig, "votes").name(), "votes"); + assert_eq!(get_node(&inc, mig, "votes").fields(), &["aid", "userid"]); + assert_eq!(get_node(&inc, mig, "votes").description(), "B"); // Try a simple COUNT function let res = inc.add_query( @@ -874,11 +886,11 @@ mod tests { }, ], ); - let agg_view = get_node(&inc, &*mig, &format!("q_{:x}_n0", qid)); + let agg_view = get_node(&inc, mig, &format!("q_{:x}_n0", qid)); assert_eq!(agg_view.fields(), &["aid", "votes"]); assert_eq!(agg_view.description(), format!("|*| γ[0]")); // check edge view - let edge_view = get_node(&inc, &*mig, &res.unwrap().name); + let edge_view = get_node(&inc, mig, &res.unwrap().name); assert_eq!(edge_view.fields(), &["votes"]); assert_eq!(edge_view.description(), format!("π[1]")); }); @@ -926,9 +938,9 @@ mod tests { ); // Should have source and "users" base table node assert_eq!(mig.graph().node_count(), 2); - assert_eq!(get_node(&inc, &*mig, "users").name(), "users"); - assert_eq!(get_node(&inc, &*mig, "users").fields(), &["id", "name"]); - assert_eq!(get_node(&inc, &*mig, "users").description(), "B"); + assert_eq!(get_node(&inc, mig, "users").name(), "users"); + assert_eq!(get_node(&inc, mig, "users").fields(), &["id", "name"]); + assert_eq!(get_node(&inc, mig, "users").description(), "B"); // Add a new query let res = inc.add_query("SELECT id, name FROM users WHERE users.id = 42;", None, mig); @@ -964,12 +976,12 @@ mod tests { ); // Should have source and "users" base table node assert_eq!(mig.graph().node_count(), 2); - assert_eq!(get_node(&inc, &*mig, "users").name(), "users"); + assert_eq!(get_node(&inc, mig, "users").name(), "users"); assert_eq!( - get_node(&inc, &*mig, "users").fields(), + get_node(&inc, mig, "users").fields(), &["id", "name", "address"] ); - assert_eq!(get_node(&inc, &*mig, "users").description(), "B"); + assert_eq!(get_node(&inc, mig, "users").description(), "B"); // Add a new query let res = inc.add_query("SELECT id, name FROM users WHERE users.id = ?;", None, mig); @@ -990,7 +1002,7 @@ mod tests { assert_eq!(mig.graph().node_count(), ncount + 2); // only the identity node is returned in the vector of new nodes assert_eq!(qfp.new_nodes.len(), 1); - assert_eq!(get_node(&inc, &*mig, &qfp.name).description(), "≡"); + assert_eq!(get_node(&inc, mig, &qfp.name).description(), "≡"); // we should be based off the identity as our leaf let id_node = qfp.new_nodes.iter().next().unwrap(); assert_eq!(qfp.query_leaf, *id_node); @@ -1009,10 +1021,7 @@ mod tests { assert_eq!(mig.graph().node_count(), ncount + 2); // only the projection node is returned in the vector of new nodes assert_eq!(qfp.new_nodes.len(), 1); - assert_eq!( - get_node(&inc, &*mig, &qfp.name).description(), - "π[0, 1, 2]" - ); + assert_eq!(get_node(&inc, mig, &qfp.name).description(), "π[0, 1, 2]"); // we should be based off the new projection as our leaf let id_node = qfp.new_nodes.iter().next().unwrap(); assert_eq!(qfp.query_leaf, *id_node); @@ -1032,9 +1041,9 @@ mod tests { ); // Should have source and "users" base table node assert_eq!(mig.graph().node_count(), 2); - assert_eq!(get_node(&inc, &*mig, "votes").name(), "votes"); - assert_eq!(get_node(&inc, &*mig, "votes").fields(), &["aid", "userid"]); - assert_eq!(get_node(&inc, &*mig, "votes").description(), "B"); + assert_eq!(get_node(&inc, mig, "votes").name(), "votes"); + assert_eq!(get_node(&inc, mig, "votes").fields(), &["aid", "userid"]); + assert_eq!(get_node(&inc, mig, "votes").description(), "B"); // Try a simple COUNT function without a GROUP BY clause let res = inc.add_query("SELECT COUNT(votes.userid) AS count FROM votes;", None, mig); assert!(res.is_ok()); @@ -1057,16 +1066,16 @@ mod tests { }, ], ); - let proj_helper_view = get_node(&inc, &*mig, &format!("q_{:x}_n0_prj_hlpr", qid)); + let proj_helper_view = get_node(&inc, mig, &format!("q_{:x}_n0_prj_hlpr", qid)); assert_eq!(proj_helper_view.fields(), &["userid", "grp"]); assert_eq!(proj_helper_view.description(), format!("π[1, lit: 0]")); // check aggregation view - let agg_view = get_node(&inc, &*mig, &format!("q_{:x}_n0", qid)); + let agg_view = get_node(&inc, mig, &format!("q_{:x}_n0", qid)); assert_eq!(agg_view.fields(), &["grp", "count"]); assert_eq!(agg_view.description(), format!("|*| γ[1]")); // check edge view -- note that it's not actually currently possible to read from // this for a lack of key (the value would be the key) - let edge_view = get_node(&inc, &*mig, &res.unwrap().name); + let edge_view = get_node(&inc, mig, &res.unwrap().name); assert_eq!(edge_view.fields(), &["count"]); assert_eq!(edge_view.description(), format!("π[1]")); }); @@ -1085,9 +1094,9 @@ mod tests { ); // Should have source and "users" base table node assert_eq!(mig.graph().node_count(), 2); - assert_eq!(get_node(&inc, &*mig, "votes").name(), "votes"); - assert_eq!(get_node(&inc, &*mig, "votes").fields(), &["userid", "aid"]); - assert_eq!(get_node(&inc, &*mig, "votes").description(), "B"); + assert_eq!(get_node(&inc, mig, "votes").name(), "votes"); + assert_eq!(get_node(&inc, mig, "votes").fields(), &["userid", "aid"]); + assert_eq!(get_node(&inc, mig, "votes").description(), "B"); // Try a simple COUNT function without a GROUP BY clause let res = inc.add_query( "SELECT COUNT(*) AS count FROM votes GROUP BY votes.userid;", @@ -1111,12 +1120,12 @@ mod tests { }, ], ); - let agg_view = get_node(&inc, &*mig, &format!("q_{:x}_n0", qid)); + let agg_view = get_node(&inc, mig, &format!("q_{:x}_n0", qid)); assert_eq!(agg_view.fields(), &["userid", "count"]); assert_eq!(agg_view.description(), format!("|*| γ[0]")); // check edge view -- note that it's not actually currently possible to read from // this for a lack of key (the value would be the key) - let edge_view = get_node(&inc, &*mig, &res.unwrap().name); + let edge_view = get_node(&inc, mig, &res.unwrap().name); assert_eq!(edge_view.fields(), &["count"]); assert_eq!(edge_view.description(), format!("π[1]")); }); @@ -1169,7 +1178,7 @@ mod tests { // XXX(malte): non-deterministic join ordering make it difficult to assert on the join // views // leaf view - let leaf_view = get_node(&inc, &*mig, "q_3"); + let leaf_view = get_node(&inc, mig, "q_3"); assert_eq!(leaf_view.fields(), &["name", "title", "uid"]); }); } @@ -1220,20 +1229,20 @@ mod tests { &Column::from("votes.uid"), ], ); - let join1_view = get_node(&inc, &*mig, &format!("q_{:x}_n0", qid)); + let join1_view = get_node(&inc, mig, &format!("q_{:x}_n0", qid)); // articles join votes assert_eq!( join1_view.fields(), &["aid", "title", "author", "id", "name"] ); - let join2_view = get_node(&inc, &*mig, &format!("q_{:x}_n1", qid)); + let join2_view = get_node(&inc, mig, &format!("q_{:x}_n1", qid)); // join1_view join users assert_eq!( join2_view.fields(), &["aid", "title", "author", "id", "name", "aid", "uid"] ); // leaf view - let leaf_view = get_node(&inc, &*mig, "q_3"); + let leaf_view = get_node(&inc, mig, "q_3"); assert_eq!(leaf_view.fields(), &["name", "title", "uid"]); }); } @@ -1253,12 +1262,33 @@ mod tests { assert!(res.is_ok()); // leaf view node - let edge = get_node(&inc, &*mig, &res.unwrap().name); + let edge = get_node(&inc, mig, &res.unwrap().name); assert_eq!(edge.fields(), &["name", "literal"]); assert_eq!(edge.description(), format!("π[1, lit: 1]")); }); } + #[test] + fn it_incorporates_arithmetic_projection() { + // set up graph + let mut g = Blender::new(); + let mut inc = SqlIncorporator::default(); + g.migrate(|mig| { + assert!( + inc.add_query("CREATE TABLE users (id int, age int);", None, mig) + .is_ok() + ); + + let res = inc.add_query("SELECT 2 * users.age FROM users;", None, mig); + assert!(res.is_ok()); + + // leaf view node + let edge = get_node(&inc, mig, &res.unwrap().name); + assert_eq!(edge.fields(), &["arithmetic"]); + assert_eq!(edge.description(), format!("π[(lit: 2) * 1]")); + }); + } + #[test] fn it_incorporates_join_with_nested_query() { let mut g = Blender::new(); @@ -1294,13 +1324,13 @@ mod tests { ], ); // join node - let new_join_view = get_node(&inc, &*mig, &format!("q_{:x}_n0", qid)); + let new_join_view = get_node(&inc, mig, &format!("q_{:x}_n0", qid)); assert_eq!( new_join_view.fields(), &["id", "name", "id", "author", "title"] ); // leaf node - let new_leaf_view = get_node(&inc, &*mig, &q.unwrap().name); + let new_leaf_view = get_node(&inc, mig, &q.unwrap().name); assert_eq!(new_leaf_view.fields(), &["name", "title"]); assert_eq!(new_leaf_view.description(), format!("π[1, 4]")); }); diff --git a/src/sql/passes/count_star_rewrite.rs b/src/sql/passes/count_star_rewrite.rs index 54598f293..46e907f94 100644 --- a/src/sql/passes/count_star_rewrite.rs +++ b/src/sql/passes/count_star_rewrite.rs @@ -96,6 +96,7 @@ impl CountStarRewrite for SqlQuery { &mut FieldExpression::All => panic!(err), &mut FieldExpression::AllInTable(_) => panic!(err), &mut FieldExpression::Literal(_) => (), + &mut FieldExpression::Arithmetic(_) => (), &mut FieldExpression::Col(ref mut c) => { rewrite_count_star(c, &tables, &avoid_cols) } diff --git a/src/sql/passes/implied_tables.rs b/src/sql/passes/implied_tables.rs index 9393a68e9..573f3cb2b 100644 --- a/src/sql/passes/implied_tables.rs +++ b/src/sql/passes/implied_tables.rs @@ -1,5 +1,5 @@ -use nom_sql::{Column, ConditionExpression, ConditionTree, FieldExpression, JoinRightSide, - SqlQuery, Table}; +use nom_sql::{ArithmeticBase, Column, ConditionExpression, ConditionTree, FieldExpression, + JoinRightSide, SqlQuery, Table}; use std::collections::HashMap; @@ -174,6 +174,15 @@ impl ImpliedTableExpansion for SqlQuery { &mut FieldExpression::All => panic!(err), &mut FieldExpression::AllInTable(_) => panic!(err), &mut FieldExpression::Literal(_) => (), + &mut FieldExpression::Arithmetic(ref mut e) => { + if let ArithmeticBase::Column(ref mut c) = e.left { + *c = expand_columns(c.clone(), &tables); + } + + if let ArithmeticBase::Column(ref mut c) = e.right { + *c = expand_columns(c.clone(), &tables); + } + } &mut FieldExpression::Col(ref mut f) => { *f = expand_columns(f.clone(), &tables); } diff --git a/src/sql/passes/star_expansion.rs b/src/sql/passes/star_expansion.rs index c19e97a6b..2aa1af1c6 100644 --- a/src/sql/passes/star_expansion.rs +++ b/src/sql/passes/star_expansion.rs @@ -37,6 +37,9 @@ impl StarExpansion for SqlQuery { let v: Vec<_> = expand_table(t).collect(); v.into_iter() } + FieldExpression::Arithmetic(a) => { + vec![FieldExpression::Arithmetic(a)].into_iter() + } FieldExpression::Literal(l) => vec![FieldExpression::Literal(l)].into_iter(), FieldExpression::Col(c) => vec![FieldExpression::Col(c)].into_iter(), }) diff --git a/src/sql/query_graph.rs b/src/sql/query_graph.rs index 9a81d1d95..4945a9300 100644 --- a/src/sql/query_graph.rs +++ b/src/sql/query_graph.rs @@ -1,5 +1,6 @@ -use nom_sql::{Column, ConditionBase, ConditionExpression, ConditionTree, FieldExpression, - JoinConstraint, JoinOperator, JoinRightSide, Literal, Operator}; +use nom_sql::{ArithmeticBase, ArithmeticExpression, Column, ConditionBase, ConditionExpression, + ConditionTree, FieldExpression, JoinConstraint, JoinOperator, JoinRightSide, + Literal, Operator}; use nom_sql::SelectStatement; use nom_sql::ConditionExpression::*; @@ -18,15 +19,28 @@ pub struct LiteralColumn { pub value: Literal, } +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct ArithmeticColumn { + pub name: String, + pub table: Option, + pub expression: ArithmeticExpression, +} + #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub enum OutputColumn { Data(Column), + Arithmetic(ArithmeticColumn), Literal(LiteralColumn), } impl Ord for OutputColumn { fn cmp(&self, other: &OutputColumn) -> Ordering { match *self { + OutputColumn::Arithmetic(ArithmeticColumn { + ref name, + ref table, + .. + }) | OutputColumn::Data(Column { ref name, ref table, @@ -37,6 +51,11 @@ impl Ord for OutputColumn { ref table, .. }) => match *other { + OutputColumn::Arithmetic(ArithmeticColumn { + name: ref other_name, + table: ref other_table, + .. + }) | OutputColumn::Data(Column { name: ref other_name, table: ref other_table, @@ -62,6 +81,11 @@ impl Ord for OutputColumn { impl PartialOrd for OutputColumn { fn partial_cmp(&self, other: &OutputColumn) -> Option { match *self { + OutputColumn::Arithmetic(ArithmeticColumn { + ref name, + ref table, + .. + }) | OutputColumn::Data(Column { ref name, ref table, @@ -72,6 +96,11 @@ impl PartialOrd for OutputColumn { ref table, .. }) => match *other { + OutputColumn::Arithmetic(ArithmeticColumn { + name: ref other_name, + table: ref other_table, + .. + }) | OutputColumn::Data(Column { name: ref other_name, table: ref other_table, @@ -419,6 +448,7 @@ pub fn to_query_graph(st: &SelectStatement) -> Result { // No need to do anything for literals here, as they aren't associated with a // relation (and thus have no QGN) FieldExpression::Literal(_) => None, + FieldExpression::Arithmetic(_) => None, FieldExpression::Col(ref c) => { match c.table.as_ref() { None => { @@ -639,6 +669,23 @@ pub fn to_query_graph(st: &SelectStatement) -> Result { } } + // Adds a computed column to the query graph if the given column has a function: + let add_computed_column = |query_graph: &mut QueryGraph, column: &Column| { + match column.function { + None => (), // we've already dealt with this column as part of some relation + Some(_) => { + // add a special node representing the computed columns; if it already + // exists, add another computed column to it + let n = query_graph + .relations + .entry(String::from("computed_columns")) + .or_insert_with(|| new_node(String::from("computed_columns"), vec![], st)); + + n.columns.push(column.clone()); + } + } + }; + // 4. Add query graph nodes for any computed columns, which won't be represented in the // nodes corresponding to individual relations. for field in st.fields.iter() { @@ -653,20 +700,23 @@ pub fn to_query_graph(st: &SelectStatement) -> Result { value: l.clone(), })); } - FieldExpression::Col(ref c) => { - match c.function { - None => (), // we've already dealt with this column as part of some relation - Some(_) => { - // add a special node representing the computed columns; if it already - // exists, add another computed column to it - let n = qg.relations - .entry(String::from("computed_columns")) - .or_insert_with( - || new_node(String::from("computed_columns"), vec![], st), - ); - n.columns.push(c.clone()); - } + FieldExpression::Arithmetic(ref a) => { + if let ArithmeticBase::Column(ref c) = a.left { + add_computed_column(&mut qg, c); + } + + if let ArithmeticBase::Column(ref c) = a.right { + add_computed_column(&mut qg, c); } + + qg.columns.push(OutputColumn::Arithmetic(ArithmeticColumn { + name: String::from("arithmetic"), + table: None, + expression: a.clone(), + })); + } + FieldExpression::Col(ref c) => { + add_computed_column(&mut qg, c); qg.columns.push(OutputColumn::Data(c.clone())); } } diff --git a/tests/lib.rs b/tests/lib.rs index ed2c7cc8e..50e207bd9 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -293,7 +293,7 @@ fn it_works_deletion() { fn it_works_with_sql_recipe() { let mut g = distributary::Blender::new(); let sql = " - CREATE Table Car (id int, brand varchar(255), PRIMARY KEY(id)); + CREATE TABLE Car (id int, brand varchar(255), PRIMARY KEY(id)); CountCars: SELECT COUNT(*) FROM Car WHERE brand = ?; "; @@ -322,6 +322,149 @@ fn it_works_with_sql_recipe() { assert_eq!(result[0][0], 2.into()); } +#[test] +fn it_works_with_simple_arithmetic() { + let mut g = distributary::Blender::new(); + let sql = " + CREATE TABLE Car (id int, price int, PRIMARY KEY(id)); + CarPrice: SELECT 2 * price FROM Car WHERE id = ?; + "; + + let recipe = g.migrate(|mig| { + let mut recipe = distributary::Recipe::from_str(&sql, None).unwrap(); + recipe.activate(mig, false).unwrap(); + recipe + }); + + let car_index = recipe.node_addr_for("Car").unwrap(); + let count_index = recipe.node_addr_for("CarPrice").unwrap(); + let mut mutator = g.get_mutator(car_index); + let getter = g.get_getter(count_index).unwrap(); + let id: distributary::DataType = 1.into(); + let price: distributary::DataType = 123.into(); + mutator.put(vec![id.clone(), price]).unwrap(); + + // Let writes propagate: + sleep(); + + // Retrieve the result of the count query: + let result = getter.lookup(&id, true).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0][1], 246.into()); +} + +#[test] +fn it_works_with_multiple_arithmetic_expressions() { + let mut g = distributary::Blender::new(); + let sql = " + CREATE TABLE Car (id int, price int, PRIMARY KEY(id)); + CarPrice: SELECT 10 * 10, 2 * price, 10 * price, FROM Car WHERE id = ?; + "; + + let recipe = g.migrate(|mig| { + let mut recipe = distributary::Recipe::from_str(&sql, None).unwrap(); + recipe.activate(mig, false).unwrap(); + recipe + }); + + let car_index = recipe.node_addr_for("Car").unwrap(); + let count_index = recipe.node_addr_for("CarPrice").unwrap(); + let mut mutator = g.get_mutator(car_index); + let getter = g.get_getter(count_index).unwrap(); + let id: distributary::DataType = 1.into(); + let price: distributary::DataType = 123.into(); + mutator.put(vec![id.clone(), price]).unwrap(); + + // Let writes propagate: + sleep(); + + // Retrieve the result of the count query: + let result = getter.lookup(&id, true).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0][1], 100.into()); + assert_eq!(result[0][2], 246.into()); + assert_eq!(result[0][3], 1230.into()); +} + +#[test] +fn it_works_with_join_arithmetic() { + let mut g = distributary::Blender::new(); + let sql = " + CREATE TABLE Car (car_id int, price_id int, PRIMARY KEY(car_id)); + CREATE TABLE Price (price_id int, price int, PRIMARY KEY(price_id)); + CREATE TABLE Sales (sales_id int, price_id int, fraction float, PRIMARY KEY(sales_id)); + CarPrice: SELECT price * fraction FROM Car \ + JOIN Price ON Car.price_id = Price.price_id \ + JOIN Sales ON Price.price_id = Sales.price_id \ + WHERE car_id = ?; + "; + + let recipe = g.migrate(|mig| { + let mut recipe = distributary::Recipe::from_str(&sql, None).unwrap(); + recipe.activate(mig, false).unwrap(); + recipe + }); + + let car_index = recipe.node_addr_for("Car").unwrap(); + let price_index = recipe.node_addr_for("Price").unwrap(); + let sales_index = recipe.node_addr_for("Sales").unwrap(); + let query_index = recipe.node_addr_for("CarPrice").unwrap(); + let mut car_mutator = g.get_mutator(car_index); + let mut price_mutator = g.get_mutator(price_index); + let mut sales_mutator = g.get_mutator(sales_index); + let getter = g.get_getter(query_index).unwrap(); + let id = 1; + let price = 123; + let fraction = 0.7; + car_mutator.put(vec![id.into(), id.into()]).unwrap(); + price_mutator.put(vec![id.into(), price.into()]).unwrap(); + sales_mutator + .put(vec![id.into(), id.into(), fraction.into()]) + .unwrap(); + + // Let writes propagate: + sleep(); + + // Retrieve the result of the count query: + let result = getter.lookup(&id.into(), true).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0][1], (price as f64 * fraction).into()); +} + +#[test] +fn it_works_with_function_arithmetic() { + let mut g = distributary::Blender::new(); + let sql = " + CREATE TABLE Bread (id int, price int, PRIMARY KEY(id)); + Price: SELECT 2 * MAX(price) FROM Bread; + "; + + let recipe = g.migrate(|mig| { + let mut recipe = distributary::Recipe::from_str(&sql, None).unwrap(); + recipe.activate(mig, false).unwrap(); + recipe + }); + + let bread_index = recipe.node_addr_for("Bread").unwrap(); + let query_index = recipe.node_addr_for("Price").unwrap(); + let mut mutator = g.get_mutator(bread_index); + let getter = g.get_getter(query_index).unwrap(); + let max_price = 20; + for (i, price) in (10..max_price + 1).enumerate() { + let id = (i + 1) as i32; + mutator.put(vec![id.into(), price.into()]).unwrap(); + } + + // Let writes propagate: + sleep(); + + // Retrieve the result of the count query: + let key = distributary::DataType::BigInt(max_price * 2); + let result = getter.lookup(&key, true).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0][0], key); +} + #[test] fn votes() { use distributary::{Aggregation, Base, Join, JoinType, Union}; @@ -750,7 +893,7 @@ fn migrate_added_columns() { let b = mig.add_ingredient( "x", &["c", "b"], - distributary::Project::new(a, &[2, 0], None), + distributary::Project::new(a, &[2, 0], None, None), ); mig.maintain(b, 1); b @@ -871,7 +1014,7 @@ fn key_on_added() { let b = mig.add_ingredient( "x", &["c", "b"], - distributary::Project::new(a, &[2, 1], None), + distributary::Project::new(a, &[2, 1], None, None), ); mig.maintain(b, 0); b @@ -1004,7 +1147,7 @@ fn full_aggregation_with_bogokey() { let bogo = mig.add_ingredient( "bogo", &["x", "bogo"], - distributary::Project::new(base, &[0], Some(vec![0.into()])), + distributary::Project::new(base, &[0], Some(vec![0.into()]), None), ); let agg = mig.add_ingredient( "agg",