Skip to content

Commit

Permalink
Support value comparison
Browse files Browse the repository at this point in the history
https://www.sqlite.org/datatype3.html#comparison_expressions

Expr now supports NULL. If NULL returned on filter, it skip the row.
  • Loading branch information
kawasin73 committed Aug 14, 2023
1 parent 50be195 commit 349d2a5
Show file tree
Hide file tree
Showing 5 changed files with 389 additions and 81 deletions.
112 changes: 71 additions & 41 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ use std::path::Path;

use anyhow::bail;
use anyhow::Context;
use pager::PageId;

// TODO: This is to suppress the unused warning.
// pub use crate::btree::*;
use crate::btree::BtreeContext;
use crate::cursor::BtreeCursor;
use crate::cursor::BtreePayload;
use crate::pager::PageId;
use crate::pager::Pager;
use crate::parser::parse_select;
use crate::parser::BinaryOperator;
Expand All @@ -50,6 +50,7 @@ use crate::schema::Schema;
use crate::schema::Table;
use crate::token::get_token_no_space;
use crate::token::Token;
use crate::value::TypeAffinity;
pub use crate::value::Value;

const SQLITE_MAX_PAGE_SIZE: u32 = 65536;
Expand Down Expand Up @@ -154,11 +155,10 @@ impl Connection {
}
ResultColumn::ColumnName(column_name) => {
let column_name = column_name.dequote();
let column_idx =
table.get_column_index(&column_name).ok_or(anyhow::anyhow!(
"column not found: {}",
std::str::from_utf8(&column_name).unwrap_or_default()
))?;
let (column_idx, _) = table.get_column(&column_name).ok_or(anyhow::anyhow!(
"column not found: {}",
std::str::from_utf8(&column_name).unwrap_or_default()
))?;
columns.push(column_idx);
}
}
Expand All @@ -175,7 +175,7 @@ impl Connection {
right,
}) = &selection
{
if let Selection::Column(column_number) = left.as_ref() {
if let Selection::Column((column_number, _)) = left.as_ref() {
if let Selection::Integer(key) = right.as_ref() {
let mut next_index = table.indexes.as_ref();
while let Some(index) = next_index {
Expand Down Expand Up @@ -211,12 +211,13 @@ impl Connection {
}

enum Selection {
Column(ColumnNumber),
Column((ColumnNumber, TypeAffinity)),
BinaryOperator {
operator: BinaryOperator,
left: Box<Selection>,
right: Box<Selection>,
},
Null,
Integer(i64),
Real(f64),
Text(Vec<u8>),
Expand All @@ -226,6 +227,7 @@ enum Selection {
impl Selection {
fn from(expr: Expr, table: &Table) -> anyhow::Result<Self> {
match expr {
Expr::Null => Ok(Self::Null),
Expr::Integer(i) => Ok(Self::Integer(i)),
Expr::Real(f) => Ok(Self::Real(f)),
Expr::Text(text) => Ok(Self::Text(text.dequote())),
Expand All @@ -242,7 +244,7 @@ impl Selection {
Expr::Column(column_name) => {
let column_name = column_name.dequote();
table
.get_column_index(&column_name)
.get_column(&column_name)
.map(Self::Column)
.ok_or(anyhow::anyhow!(
"column not found: {}",
Expand All @@ -252,46 +254,69 @@ impl Selection {
}
}

fn execute<'a>(&'a self, row: &'a RowRef) -> anyhow::Result<Value<'a>> {
fn execute<'a>(&'a self, row: &'a RowRef) -> anyhow::Result<(Value<'a>, Option<TypeAffinity>)> {
match self {
Self::Column(idx) => row.get(idx),
Self::Column((idx, affinity)) => Ok((row.get(idx)?, Some(*affinity))),
Self::BinaryOperator {
operator,
left,
right,
} => {
let left = left.execute(row)?;
let right = right.execute(row)?;
let result = if operator == &BinaryOperator::Eq {
left == right
} else if operator == &BinaryOperator::Ne {
left != right
} else {
let Value::Integer(left) = left else {
bail!("invalid value for selection: {:?}", left);
};
let Value::Integer(right) = right else {
bail!("invalid value for selection: {:?}", left);
};
match operator {
BinaryOperator::Eq => left == right,
BinaryOperator::Ne => left != right,
BinaryOperator::Lt => left < right,
BinaryOperator::Le => left <= right,
BinaryOperator::Gt => left > right,
BinaryOperator::Ge => left >= right,
let (left_value, left_affinity) = left.execute(row)?;
let (right_value, right_affinity) = right.execute(row)?;

match (&left_value, &right_value) {
(Value::Null, _) => return Ok((Value::Null, None)),
(_, Value::Null) => return Ok((Value::Null, None)),
_ => {}
}

// TODO: Type Conversions Prior To Comparison
match (left_affinity, right_affinity) {
(
Some(TypeAffinity::Integer)
| Some(TypeAffinity::Real)
| Some(TypeAffinity::Numeric),
Some(TypeAffinity::Text) | Some(TypeAffinity::Blob) | None,
) => {
// TODO: Apply numeric affinity to the right operand.
}
(
Some(TypeAffinity::Text) | Some(TypeAffinity::Blob) | None,
Some(TypeAffinity::Integer)
| Some(TypeAffinity::Real)
| Some(TypeAffinity::Numeric),
) => {
// TODO: Apply numeric affinity to the left operand.
}
(Some(TypeAffinity::Text), None) => {
// TODO: Apply text affinity to the right operands.
}
(None, Some(TypeAffinity::Text)) => {
// TODO: Apply text affinity to the left operands.
}
_ => {}
}

let result = match operator {
BinaryOperator::Eq => left_value == right_value,
BinaryOperator::Ne => left_value != right_value,
BinaryOperator::Lt => left_value < right_value,
BinaryOperator::Le => left_value <= right_value,
BinaryOperator::Gt => left_value > right_value,
BinaryOperator::Ge => left_value >= right_value,
};
if result {
Ok(Value::Integer(1))
Ok((Value::Integer(1), None))
} else {
Ok(Value::Integer(0))
Ok((Value::Integer(0), None))
}
}
Self::Integer(value) => Ok(Value::Integer(*value)),
Self::Real(value) => Ok(Value::Real(*value)),
Self::Text(value) => Ok(Value::Text(value)),
Self::Blob(value) => Ok(Value::Blob(value)),
Self::Null => Ok((Value::Null, None)),
Self::Integer(value) => Ok((Value::Integer(*value), None)),
Self::Real(value) => Ok((Value::Real(*value), None)),
Self::Text(value) => Ok((Value::Text(value), None)),
Self::Blob(value) => Ok((Value::Blob(value), None)),
}
}
}
Expand Down Expand Up @@ -319,8 +344,12 @@ impl<'conn> Statement<'conn> {
left,
right,
}) => match (left.as_ref(), right.as_ref()) {
(Selection::Column(ColumnNumber::RowId), Selection::Integer(value)) => Some(*value),
(Selection::Integer(value), Selection::Column(ColumnNumber::RowId)) => Some(*value),
(Selection::Column((ColumnNumber::RowId, _)), Selection::Integer(value)) => {
Some(*value)
}
(Selection::Integer(value), Selection::Column((ColumnNumber::RowId, _))) => {
Some(*value)
}
_ => None,
},
_ => None,
Expand Down Expand Up @@ -440,8 +469,9 @@ impl<'conn> Rows<'conn> {
use_local_buffer,
content_offset,
};
if selection.execute(&column_value_loader)? == Value::Integer(0) {
continue;
match selection.execute(&column_value_loader)?.0 {
Value::Null | Value::Integer(0) => continue,
_ => {}
}
}

Expand Down
5 changes: 5 additions & 0 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ pub enum Expr<'a> {
left: Box<Expr<'a>>,
right: Box<Expr<'a>>,
},
Null,
Integer(i64),
Real(f64),
Text(MaybeQuotedBytes<'a>),
Expand All @@ -357,6 +358,7 @@ fn parse_expr(input: &[u8]) -> Result<(usize, Expr)> {
let input_len = input.len();
let (n, left) = match get_token_no_space(input) {
Some((n, Token::Identifier(id))) => (n, Expr::Column(id)),
Some((n, Token::Null)) => (n, Expr::Null),
Some((n, Token::Integer(buf))) => {
let v = parse_integer_literal(buf);
if v < 0 {
Expand Down Expand Up @@ -608,6 +610,9 @@ mod tests {

#[test]
fn test_parse_expr() {
// Parse null
assert_eq!(parse_expr(b"null").unwrap(), (4, Expr::Null));

// Parse integer
assert_eq!(
parse_expr(b"123456789a").unwrap(),
Expand Down
67 changes: 30 additions & 37 deletions src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ impl Index {
for column in &create_index.columns {
// TODO: use the reference of given column name.
let column_name = column.name.dequote();
let Some(column_number) = table.get_column_index(&column_name) else {
let Some((column_number, _)) = table.get_column(&column_name) else {
bail!(
"column {:?} in create index sql is not found in table {:?}",
column.name,
Expand Down Expand Up @@ -369,21 +369,23 @@ impl Table {
))
}

pub fn get_column_index(&self, column: &[u8]) -> Option<ColumnNumber> {
let column = CaseInsensitiveBytes::from(column);
if let Some(i) = self
pub fn get_column(&self, name: &[u8]) -> Option<(ColumnNumber, TypeAffinity)> {
let column = CaseInsensitiveBytes::from(name);
if let Some((i, column)) = self
.columns
.iter()
.position(|c| CaseInsensitiveBytes::from(&c.name) == column)
.enumerate()
.find(|(_, c)| CaseInsensitiveBytes::from(&c.name) == column)
{
let column = &self.columns[i];
if column.primary_key && column.type_affinity == TypeAffinity::Integer {
Some(ColumnNumber::RowId)
} else {
Some(ColumnNumber::Column(i))
}
let column_number =
if column.primary_key && column.type_affinity == TypeAffinity::Integer {
ColumnNumber::RowId
} else {
ColumnNumber::Column(i)
};
Some((column_number, column.type_affinity))
} else if column.equal_to_lower_bytes(&b"rowid"[..]) {
Some(ColumnNumber::RowId)
Some((ColumnNumber::RowId, TypeAffinity::Integer))
} else {
None
}
Expand Down Expand Up @@ -730,43 +732,34 @@ mod tests {
let schema = generate_schema(file.path());

let table = schema.get_table(b"example").unwrap();
assert_eq!(
table.get_column_index(b"col"),
Some(ColumnNumber::Column(0))
);
assert_eq!(table.get_column_index(b"rowid"), Some(ColumnNumber::RowId));
assert_eq!(table.get_column_index(b"invalid"), None);
assert_eq!(table.get_column(b"col").unwrap().0, ColumnNumber::Column(0));
assert_eq!(table.get_column(b"rowid").unwrap().0, ColumnNumber::RowId);
assert!(table.get_column(b"invalid").is_none());

let table = schema.get_table(b"example2").unwrap();
assert_eq!(
table.get_column_index(b"col1"),
Some(ColumnNumber::Column(0))
table.get_column(b"col1").unwrap().0,
ColumnNumber::Column(0)
);
assert_eq!(
table.get_column_index(b"col2"),
Some(ColumnNumber::Column(1))
table.get_column(b"col2").unwrap().0,
ColumnNumber::Column(1)
);
assert_eq!(
table.get_column_index(b"rowid"),
Some(ColumnNumber::Column(2))
table.get_column(b"rowid").unwrap().0,
ColumnNumber::Column(2)
);
assert_eq!(table.get_column_index(b"invalid"), None);
assert!(table.get_column(b"invalid").is_none());

let table = schema.get_table(b"example3").unwrap();
assert_eq!(table.get_column_index(b"id"), Some(ColumnNumber::RowId));
assert_eq!(
table.get_column_index(b"col"),
Some(ColumnNumber::Column(1))
);
assert_eq!(table.get_column_index(b"rowid"), Some(ColumnNumber::RowId));
assert_eq!(table.get_column(b"id").unwrap().0, ColumnNumber::RowId);
assert_eq!(table.get_column(b"col").unwrap().0, ColumnNumber::Column(1));
assert_eq!(table.get_column(b"rowid").unwrap().0, ColumnNumber::RowId);

let table = schema.get_table(b"example4").unwrap();
assert_eq!(table.get_column_index(b"id"), Some(ColumnNumber::Column(0)));
assert_eq!(
table.get_column_index(b"col"),
Some(ColumnNumber::Column(1))
);
assert_eq!(table.get_column_index(b"rowid"), Some(ColumnNumber::RowId));
assert_eq!(table.get_column(b"id").unwrap().0, ColumnNumber::Column(0));
assert_eq!(table.get_column(b"col").unwrap().0, ColumnNumber::Column(1));
assert_eq!(table.get_column(b"rowid").unwrap().0, ColumnNumber::RowId);
}

#[test]
Expand Down
Loading

0 comments on commit 349d2a5

Please sign in to comment.