Skip to content

Commit

Permalink
Refactor Expression errors
Browse files Browse the repository at this point in the history
Remove anyhow from expression
  • Loading branch information
kawasin73 committed Dec 5, 2023
1 parent d15357a commit 976fdc9
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 74 deletions.
79 changes: 57 additions & 22 deletions src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
// limitations under the License.

use std::cmp::Ordering;

use anyhow::bail;
use std::fmt::Display;

use crate::parser::BinaryOp;
use crate::parser::CompareOp;
Expand All @@ -32,6 +31,51 @@ use crate::value::Value;
use crate::value::ValueCmp;
use crate::value::DEFAULT_COLLATION;

#[derive(Debug)]
pub enum Error {
CollationNotFound,
ColumnNotFound,
NoTableContext,
FailGetColumn(Box<dyn std::error::Error + Sync + Send>),
}

impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::CollationNotFound => None,
Self::ColumnNotFound => None,
Self::NoTableContext => None,
Self::FailGetColumn(e) => Some(e.as_ref()),
}
}
}

impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::CollationNotFound => {
write!(f, "collation not found")
}
Self::ColumnNotFound => {
write!(f, "column not found")
}
Self::NoTableContext => {
write!(f, "no table context")
}
Self::FailGetColumn(e) => {
write!(f, "fail to get column: {}", e)
}
}
}
}

pub type Result<T> = std::result::Result<T, Error>;
pub type ExecutionResult<'a> = Result<(
Option<Value<'a>>,
Option<TypeAffinity>,
Option<(&'a Collation, CollateOrigin)>,
)>;

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CollateOrigin {
Column,
Expand All @@ -48,7 +92,10 @@ fn filter_expression_collation(
}

pub trait DataContext {
fn get_column_value(&self, column_idx: &ColumnNumber) -> anyhow::Result<Option<Value>>;
fn get_column_value(
&self,
column_idx: &ColumnNumber,
) -> std::result::Result<Option<Value>, Box<dyn std::error::Error + Sync + Send>>;
}

#[derive(Debug, Clone)]
Expand All @@ -75,14 +122,8 @@ pub enum Expression {
Const(ConstantValue),
}

type ExecutionResult<'a> = (
Option<Value<'a>>,
Option<TypeAffinity>,
Option<(&'a Collation, CollateOrigin)>,
);

impl Expression {
pub fn from(expr: Expr, table: Option<&Table>) -> anyhow::Result<Self> {
pub fn from(expr: Expr, table: Option<&Table>) -> Result<Self> {
match expr {
Expr::Null => Ok(Self::Null),
Expr::Integer(i) => Ok(Self::Const(ConstantValue::Integer(i))),
Expand All @@ -98,7 +139,7 @@ impl Expression {
collation_name,
} => Ok(Self::Collate {
expr: Box::new(Self::from(*expr, table)?),
collation: calc_collation(&collation_name)?,
collation: calc_collation(&collation_name).ok_or(Error::CollationNotFound)?,
}),
Expr::BinaryOperator {
operator,
Expand All @@ -115,12 +156,9 @@ impl Expression {
table
.get_column(&column_name)
.map(Self::Column)
.ok_or(anyhow::anyhow!(
"column not found: {}",
std::str::from_utf8(&column_name).unwrap_or_default()
))
.ok_or(Error::ColumnNotFound)
} else {
bail!("no table context is not specified");
Err(Error::NoTableContext)
}
}
Expr::Cast { expr, type_name } => Ok(Self::Cast {
Expand All @@ -133,20 +171,17 @@ impl Expression {
/// Execute the expression and return the result.
///
/// TODO: The row should be a context object.
pub fn execute<'a, D: DataContext>(
&'a self,
row: Option<&'a D>,
) -> anyhow::Result<ExecutionResult<'a>> {
pub fn execute<'a, D: DataContext>(&'a self, row: Option<&'a D>) -> ExecutionResult<'a> {
match self {
Self::Column((idx, affinity, collation)) => {
if let Some(row) = row {
Ok((
row.get_column_value(idx)?,
row.get_column_value(idx).map_err(Error::FailGetColumn)?,
Some(*affinity),
Some((collation, CollateOrigin::Column)),
))
} else {
bail!("column value is not available");
Err(Error::NoTableContext)
}
}
Self::UnaryOperator { operator, expr } => {
Expand Down
39 changes: 30 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ const MAX_ROWID: i64 = i64::MAX;
pub enum Error<'a> {
Parse(parser::Error<'a>),
Cursor(cursor::Error),
Expression(expression::Error),
UniqueConstraintViolation,
DataTypeMismatch,
Unsupported(&'static str),
Expand All @@ -101,12 +102,20 @@ impl From<cursor::Error> for Error<'_> {
}
}

impl From<expression::Error> for Error<'_> {
fn from(e: expression::Error) -> Self {
Self::Expression(e)
}
}

impl From<anyhow::Error> for Error<'_> {
fn from(e: anyhow::Error) -> Self {
Self::Other(e)
}
}

impl std::error::Error for Error<'_> {}

impl Display for Error<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand All @@ -116,6 +125,9 @@ impl Display for Error<'_> {
Error::Cursor(e) => {
write!(f, "Btree cursor error: {}", e)
}
Error::Expression(e) => {
write!(f, "expression error: {}", e)
}
Error::DataTypeMismatch => {
write!(f, "data type mismatch")
}
Expand Down Expand Up @@ -638,7 +650,7 @@ pub struct Rows<'conn> {
}

impl<'conn> Rows<'conn> {
pub fn next_row(&mut self) -> anyhow::Result<Option<Row<'_>>> {
pub fn next_row(&mut self) -> Result<Option<Row<'_>>> {
if self.completed {
return Ok(None);
}
Expand All @@ -656,7 +668,7 @@ impl<'conn> Rows<'conn> {
}
Err(e) => {
self.completed = true;
return Err(e);
return Err(Error::Other(e));
}
}

Expand All @@ -667,7 +679,7 @@ impl<'conn> Rows<'conn> {
headers = parse_record_header(&payload)?;

if headers.is_empty() {
bail!("empty header payload");
return Err(Error::Other(anyhow::anyhow!("empty header payload")));
}

content_offset = headers[0].1;
Expand All @@ -680,7 +692,9 @@ impl<'conn> Rows<'conn> {
tmp_buf.resize(content_size, 0);
let n = payload.load(content_offset, &mut tmp_buf)?;
if n != content_size {
bail!("payload does not have enough size");
return Err(Error::Other(anyhow::anyhow!(
"payload does not have enough size"
)));
}
};

Expand Down Expand Up @@ -781,7 +795,10 @@ pub struct RowData<'a> {
}

impl<'a> DataContext for RowData<'a> {
fn get_column_value(&self, column_idx: &ColumnNumber) -> anyhow::Result<Option<Value>> {
fn get_column_value(
&self,
column_idx: &ColumnNumber,
) -> std::result::Result<Option<Value>, Box<dyn std::error::Error + Sync + Send>> {
match column_idx {
ColumnNumber::Column(idx) => {
if let Some((serial_type, offset)) = self.headers.get(*idx) {
Expand All @@ -790,9 +807,13 @@ impl<'a> DataContext for RowData<'a> {
} else {
&self.tmp_buf
};
serial_type
.parse(&contents_buffer[offset - self.content_offset..])
.context("parse value")
let offset = offset - self.content_offset;
if contents_buffer.len() < offset
|| contents_buffer.len() - offset < serial_type.content_size() as usize
{
return Err(anyhow::anyhow!("payload does not have enough size").into());
}
Ok(serial_type.parse(&contents_buffer[offset..]))
} else {
Ok(None)
}
Expand All @@ -808,7 +829,7 @@ pub struct Row<'a> {
}

impl<'a> Row<'a> {
pub fn parse(&self) -> anyhow::Result<Columns<'_>> {
pub fn parse(&self) -> Result<Columns<'_>> {
let mut columns = Vec::with_capacity(self.stmt.columns.len());
for expr in self.stmt.columns.iter() {
let (value, _, _) = expr.execute(Some(&self.data))?;
Expand Down
55 changes: 21 additions & 34 deletions src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,32 +70,31 @@ impl SerialType {
}
}

pub fn parse<'a>(&self, buf: &'a [u8]) -> anyhow::Result<Option<Value<'a>>> {
let v = match self.0 {
/// Parse the buffer into [Value].
///
/// The buffer must be at least [content_size] bytes.
pub fn parse<'a>(&self, buf: &'a [u8]) -> Option<Value<'a>> {
match self.0 {
0 => None,
1 => Some(Value::Integer(
i8::from_be_bytes(buf[..1].try_into()?) as i64
i8::from_be_bytes(buf[..1].try_into().unwrap()) as i64,
)),
2 => Some(Value::Integer(
i16::from_be_bytes(buf[..2].try_into()?) as i64
i16::from_be_bytes(buf[..2].try_into().unwrap()) as i64,
)),
// TODO: use std::mem::transmute.
3 => {
if buf.len() < 3 {
bail!("buffer size {} does not match integer 3", buf.len());
}
assert!(buf.len() >= 3);
Some(Value::Integer(
((buf[0] as i64) << 56 | (buf[1] as i64) << 48 | (buf[2] as i64) << 40) >> 40,
))
}
4 => Some(Value::Integer(
i32::from_be_bytes(buf[..4].try_into()?) as i64
i32::from_be_bytes(buf[..4].try_into().unwrap()) as i64,
)),
// TODO: use std::mem::transmute.
5 => {
if buf.len() < 6 {
bail!("buffer size {} does not match integer 6", buf.len());
}
assert!(buf.len() >= 6);
Some(Value::Integer(
((buf[0] as i64) << 56
| (buf[1] as i64) << 48
Expand All @@ -106,9 +105,11 @@ impl SerialType {
>> 16,
))
}
6 => Some(Value::Integer(i64::from_be_bytes(buf[..8].try_into()?))),
6 => Some(Value::Integer(i64::from_be_bytes(
buf[..8].try_into().unwrap(),
))),
7 => {
let f = f64::from_be_bytes(buf[..8].try_into()?);
let f = f64::from_be_bytes(buf[..8].try_into().unwrap());
if f.is_nan() {
None
} else {
Expand All @@ -118,17 +119,11 @@ impl SerialType {
8 => Some(Value::Integer(0)),
9 => Some(Value::Integer(1)),
10 | 11 => {
bail!("reserved record is not implemented");
unreachable!("reserved record is not implemented");
}
n => {
let size = ((n - 12) >> 1) as usize;
if buf.len() < size {
bail!(
"buffer size {} is smaller than content size {}",
buf.len(),
size
);
}
assert!(buf.len() >= size);
let buf = &buf[..size];
let v = if n & 1 == 0 {
Value::Blob(Buffer::Ref(buf))
Expand All @@ -137,8 +132,7 @@ impl SerialType {
};
Some(v)
}
};
Ok(v)
}
}
}

Expand Down Expand Up @@ -193,7 +187,7 @@ impl<'a, P: LocalPayload<E>, E: Debug> Record<'a, P, E> {
} else {
&self.payload.buf()[offset..offset + content_size]
};
serial_type.parse(buf)
Ok(serial_type.parse(buf))
}
}

Expand Down Expand Up @@ -674,22 +668,15 @@ mod tests {
#[test]
fn test_parse_real() {
assert_eq!(
SerialType(7).parse(0_f64.to_be_bytes().as_slice()).unwrap(),
SerialType(7).parse(0_f64.to_be_bytes().as_slice()),
Some(Value::Real(0.0))
);
assert_eq!(
SerialType(7)
.parse(1.1_f64.to_be_bytes().as_slice())
.unwrap(),
SerialType(7).parse(1.1_f64.to_be_bytes().as_slice()),
Some(Value::Real(1.1))
);
// NaN
assert_eq!(
SerialType(7)
.parse(f64::NAN.to_be_bytes().as_slice())
.unwrap(),
None
);
assert_eq!(SerialType(7).parse(f64::NAN.to_be_bytes().as_slice()), None);
}

#[test]
Expand Down
Loading

0 comments on commit 976fdc9

Please sign in to comment.