diff --git a/src/expression.rs b/src/expression.rs index e8fb675..e10420c 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -13,8 +13,7 @@ // limitations under the License. use std::cmp::Ordering; - -use anyhow::bail; +use std::fmt::Display; use crate::parser::BinaryOp; use crate::parser::CompareOp; @@ -32,6 +31,51 @@ use crate::value::Value; use crate::value::ValueCmp; use crate::value::DEFAULT_COLLATION; +#[derive(Debug)] +pub enum Error { + CollationNotFound, + ColumnNotFound, + NoTableContext, + FailGetColumn(Box), +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::CollationNotFound => None, + Self::ColumnNotFound => None, + Self::NoTableContext => None, + Self::FailGetColumn(e) => Some(e.as_ref()), + } + } +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::CollationNotFound => { + write!(f, "collation not found") + } + Self::ColumnNotFound => { + write!(f, "column not found") + } + Self::NoTableContext => { + write!(f, "no table context") + } + Self::FailGetColumn(e) => { + write!(f, "fail to get column: {}", e) + } + } + } +} + +pub type Result = std::result::Result; +pub type ExecutionResult<'a> = Result<( + Option>, + Option, + Option<(&'a Collation, CollateOrigin)>, +)>; + #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum CollateOrigin { Column, @@ -48,7 +92,10 @@ fn filter_expression_collation( } pub trait DataContext { - fn get_column_value(&self, column_idx: &ColumnNumber) -> anyhow::Result>; + fn get_column_value( + &self, + column_idx: &ColumnNumber, + ) -> std::result::Result, Box>; } #[derive(Debug, Clone)] @@ -75,14 +122,8 @@ pub enum Expression { Const(ConstantValue), } -type ExecutionResult<'a> = ( - Option>, - Option, - Option<(&'a Collation, CollateOrigin)>, -); - impl Expression { - pub fn from(expr: Expr, table: Option<&Table>) -> anyhow::Result { + pub fn from(expr: Expr, table: Option<&Table>) -> Result { match expr { Expr::Null => Ok(Self::Null), Expr::Integer(i) => Ok(Self::Const(ConstantValue::Integer(i))), @@ -98,7 +139,7 @@ impl Expression { collation_name, } => Ok(Self::Collate { expr: Box::new(Self::from(*expr, table)?), - collation: calc_collation(&collation_name)?, + collation: calc_collation(&collation_name).ok_or(Error::CollationNotFound)?, }), Expr::BinaryOperator { operator, @@ -115,12 +156,9 @@ impl Expression { table .get_column(&column_name) .map(Self::Column) - .ok_or(anyhow::anyhow!( - "column not found: {}", - std::str::from_utf8(&column_name).unwrap_or_default() - )) + .ok_or(Error::ColumnNotFound) } else { - bail!("no table context is not specified"); + Err(Error::NoTableContext) } } Expr::Cast { expr, type_name } => Ok(Self::Cast { @@ -133,20 +171,17 @@ impl Expression { /// Execute the expression and return the result. /// /// TODO: The row should be a context object. - pub fn execute<'a, D: DataContext>( - &'a self, - row: Option<&'a D>, - ) -> anyhow::Result> { + pub fn execute<'a, D: DataContext>(&'a self, row: Option<&'a D>) -> ExecutionResult<'a> { match self { Self::Column((idx, affinity, collation)) => { if let Some(row) = row { Ok(( - row.get_column_value(idx)?, + row.get_column_value(idx).map_err(Error::FailGetColumn)?, Some(*affinity), Some((collation, CollateOrigin::Column)), )) } else { - bail!("column value is not available"); + Err(Error::NoTableContext) } } Self::UnaryOperator { operator, expr } => { diff --git a/src/lib.rs b/src/lib.rs index 9410e5d..aee48eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -83,6 +83,7 @@ const MAX_ROWID: i64 = i64::MAX; pub enum Error<'a> { Parse(parser::Error<'a>), Cursor(cursor::Error), + Expression(expression::Error), UniqueConstraintViolation, DataTypeMismatch, Unsupported(&'static str), @@ -101,12 +102,20 @@ impl From for Error<'_> { } } +impl From for Error<'_> { + fn from(e: expression::Error) -> Self { + Self::Expression(e) + } +} + impl From for Error<'_> { fn from(e: anyhow::Error) -> Self { Self::Other(e) } } +impl std::error::Error for Error<'_> {} + impl Display for Error<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -116,6 +125,9 @@ impl Display for Error<'_> { Error::Cursor(e) => { write!(f, "Btree cursor error: {}", e) } + Error::Expression(e) => { + write!(f, "expression error: {}", e) + } Error::DataTypeMismatch => { write!(f, "data type mismatch") } @@ -638,7 +650,7 @@ pub struct Rows<'conn> { } impl<'conn> Rows<'conn> { - pub fn next_row(&mut self) -> anyhow::Result>> { + pub fn next_row(&mut self) -> Result>> { if self.completed { return Ok(None); } @@ -656,7 +668,7 @@ impl<'conn> Rows<'conn> { } Err(e) => { self.completed = true; - return Err(e); + return Err(Error::Other(e)); } } @@ -667,7 +679,7 @@ impl<'conn> Rows<'conn> { headers = parse_record_header(&payload)?; if headers.is_empty() { - bail!("empty header payload"); + return Err(Error::Other(anyhow::anyhow!("empty header payload"))); } content_offset = headers[0].1; @@ -680,7 +692,9 @@ impl<'conn> Rows<'conn> { tmp_buf.resize(content_size, 0); let n = payload.load(content_offset, &mut tmp_buf)?; if n != content_size { - bail!("payload does not have enough size"); + return Err(Error::Other(anyhow::anyhow!( + "payload does not have enough size" + ))); } }; @@ -781,7 +795,10 @@ pub struct RowData<'a> { } impl<'a> DataContext for RowData<'a> { - fn get_column_value(&self, column_idx: &ColumnNumber) -> anyhow::Result> { + fn get_column_value( + &self, + column_idx: &ColumnNumber, + ) -> std::result::Result, Box> { match column_idx { ColumnNumber::Column(idx) => { if let Some((serial_type, offset)) = self.headers.get(*idx) { @@ -790,9 +807,13 @@ impl<'a> DataContext for RowData<'a> { } else { &self.tmp_buf }; - serial_type - .parse(&contents_buffer[offset - self.content_offset..]) - .context("parse value") + let offset = offset - self.content_offset; + if contents_buffer.len() < offset + || contents_buffer.len() - offset < serial_type.content_size() as usize + { + return Err(anyhow::anyhow!("payload does not have enough size").into()); + } + Ok(serial_type.parse(&contents_buffer[offset..])) } else { Ok(None) } @@ -808,7 +829,7 @@ pub struct Row<'a> { } impl<'a> Row<'a> { - pub fn parse(&self) -> anyhow::Result> { + pub fn parse(&self) -> Result> { let mut columns = Vec::with_capacity(self.stmt.columns.len()); for expr in self.stmt.columns.iter() { let (value, _, _) = expr.execute(Some(&self.data))?; diff --git a/src/record.rs b/src/record.rs index bdcbd92..6677fdf 100644 --- a/src/record.rs +++ b/src/record.rs @@ -70,32 +70,31 @@ impl SerialType { } } - pub fn parse<'a>(&self, buf: &'a [u8]) -> anyhow::Result>> { - let v = match self.0 { + /// Parse the buffer into [Value]. + /// + /// The buffer must be at least [content_size] bytes. + pub fn parse<'a>(&self, buf: &'a [u8]) -> Option> { + match self.0 { 0 => None, 1 => Some(Value::Integer( - i8::from_be_bytes(buf[..1].try_into()?) as i64 + i8::from_be_bytes(buf[..1].try_into().unwrap()) as i64, )), 2 => Some(Value::Integer( - i16::from_be_bytes(buf[..2].try_into()?) as i64 + i16::from_be_bytes(buf[..2].try_into().unwrap()) as i64, )), // TODO: use std::mem::transmute. 3 => { - if buf.len() < 3 { - bail!("buffer size {} does not match integer 3", buf.len()); - } + assert!(buf.len() >= 3); Some(Value::Integer( ((buf[0] as i64) << 56 | (buf[1] as i64) << 48 | (buf[2] as i64) << 40) >> 40, )) } 4 => Some(Value::Integer( - i32::from_be_bytes(buf[..4].try_into()?) as i64 + i32::from_be_bytes(buf[..4].try_into().unwrap()) as i64, )), // TODO: use std::mem::transmute. 5 => { - if buf.len() < 6 { - bail!("buffer size {} does not match integer 6", buf.len()); - } + assert!(buf.len() >= 6); Some(Value::Integer( ((buf[0] as i64) << 56 | (buf[1] as i64) << 48 @@ -106,9 +105,11 @@ impl SerialType { >> 16, )) } - 6 => Some(Value::Integer(i64::from_be_bytes(buf[..8].try_into()?))), + 6 => Some(Value::Integer(i64::from_be_bytes( + buf[..8].try_into().unwrap(), + ))), 7 => { - let f = f64::from_be_bytes(buf[..8].try_into()?); + let f = f64::from_be_bytes(buf[..8].try_into().unwrap()); if f.is_nan() { None } else { @@ -118,17 +119,11 @@ impl SerialType { 8 => Some(Value::Integer(0)), 9 => Some(Value::Integer(1)), 10 | 11 => { - bail!("reserved record is not implemented"); + unreachable!("reserved record is not implemented"); } n => { let size = ((n - 12) >> 1) as usize; - if buf.len() < size { - bail!( - "buffer size {} is smaller than content size {}", - buf.len(), - size - ); - } + assert!(buf.len() >= size); let buf = &buf[..size]; let v = if n & 1 == 0 { Value::Blob(Buffer::Ref(buf)) @@ -137,8 +132,7 @@ impl SerialType { }; Some(v) } - }; - Ok(v) + } } } @@ -193,7 +187,7 @@ impl<'a, P: LocalPayload, E: Debug> Record<'a, P, E> { } else { &self.payload.buf()[offset..offset + content_size] }; - serial_type.parse(buf) + Ok(serial_type.parse(buf)) } } @@ -674,22 +668,15 @@ mod tests { #[test] fn test_parse_real() { assert_eq!( - SerialType(7).parse(0_f64.to_be_bytes().as_slice()).unwrap(), + SerialType(7).parse(0_f64.to_be_bytes().as_slice()), Some(Value::Real(0.0)) ); assert_eq!( - SerialType(7) - .parse(1.1_f64.to_be_bytes().as_slice()) - .unwrap(), + SerialType(7).parse(1.1_f64.to_be_bytes().as_slice()), Some(Value::Real(1.1)) ); // NaN - assert_eq!( - SerialType(7) - .parse(f64::NAN.to_be_bytes().as_slice()) - .unwrap(), - None - ); + assert_eq!(SerialType(7).parse(f64::NAN.to_be_bytes().as_slice()), None); } #[test] diff --git a/src/schema.rs b/src/schema.rs index 0b43e03..e20428b 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -134,11 +134,16 @@ impl Schema { pub fn generate(stmt: SelectStatement, schema_table: Table) -> anyhow::Result { let stmt = stmt; - let mut rows = stmt.query()?; + let mut rows = stmt + .query() + .map_err(|e| anyhow::anyhow!("query: {:?}", e))?; let mut tables = HashMap::new(); let mut indexes = HashMap::new(); - while let Some(row) = rows.next_row()? { - let columns = row.parse()?; + while let Some(row) = rows + .next_row() + .map_err(|e| anyhow::anyhow!("next row: {:?}", e))? + { + let columns = row.parse().map_err(|e| anyhow::anyhow!("parse: {:?}", e))?; let schema = SchemaRecord::parse(&columns)?; match schema.type_ { b"table" => { @@ -338,18 +343,18 @@ pub fn calc_type_affinity(type_name: &[MaybeQuotedBytes]) -> TypeAffinity { /// This now supports BINARY, NOCASE, and RTRIM only. /// /// TODO: Support user defined collating sequence. -pub fn calc_collation(collation_name: &MaybeQuotedBytes) -> anyhow::Result { +pub fn calc_collation(collation_name: &MaybeQuotedBytes) -> Option { // TODO: Validate with iterator. let collation_name = collation_name.dequote(); let case_insensitive_collation_name = CaseInsensitiveBytes::from(collation_name.as_slice()); if case_insensitive_collation_name.equal_to_lower_bytes(b"binary") { - Ok(Collation::Binary) + Some(Collation::Binary) } else if case_insensitive_collation_name.equal_to_lower_bytes(b"nocase") { - Ok(Collation::NoCase) + Some(Collation::NoCase) } else if case_insensitive_collation_name.equal_to_lower_bytes(b"rtrim") { - Ok(Collation::RTrim) + Some(Collation::RTrim) } else { - bail!("invalid collation: {:?}", collation_name); + None } } @@ -397,7 +402,8 @@ impl Table { let mut collation = DEFAULT_COLLATION.clone(); for constraint in &column_def.constraints { if let ColumnConstraint::Collate(collation_name) = constraint { - collation = calc_collation(collation_name)?; + collation = calc_collation(collation_name) + .ok_or_else(|| anyhow::anyhow!("collation is not found"))?; } }