Skip to content

Commit

Permalink
feat(cubesql): Support Numeric type (text + binary) in pg-wire
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr committed Jun 22, 2022
1 parent 779d35d commit db7ec5c
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 27 deletions.
15 changes: 8 additions & 7 deletions rust/cubesql/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions rust/cubesql/cubesql/Cargo.toml
Expand Up @@ -28,6 +28,8 @@ rand = "0.8.3"
smallvec = "1.7.0"
byteorder = "1.3.4"
log = "=0.4.11"
rust_decimal = { version = "1.25", features = ["c-repr", "db-postgres"]}
postgres-types = "0.2.3"
# Locked, because starting from 1.15 this crate switched from chrono to time,
# which panics with "Could not determine the UTC offset on this system".
# It's a problem with determining local_offset_at for the local-offset feature
Expand Down
4 changes: 4 additions & 0 deletions rust/cubesql/cubesql/e2e/tests/postgres.rs
Expand Up @@ -605,6 +605,10 @@ impl AsyncTestSuite for PostgresIntegrationTestSuite {
true as bool_true,
false as bool_false,
'test' as str,
CAST(1.25 as DECIMAL(15, 0)) as d0,
CAST(1.25 as DECIMAL(15, 2)) as d2,
CAST(1.25 as DECIMAL(15, 5)) as d5,
CAST(1.25 as DECIMAL(15, 10)) as d10,
ARRAY['test1', 'test2'] as str_arr,
ARRAY[1,2,3] as i64_arr,
ARRAY[1.2,2.3,3.4] as f64_arr,
Expand Down
@@ -1,7 +1,7 @@
---
source: cubesql/e2e/tests/postgres.rs
assertion_line: 233
expression: "self.print_query_result(res, with_description).await"
assertion_line: 261
expression: "self.print_query_result(res, with_description, true).await"
---
Utf8(NULL) type: 25 (text)
f32 type: 700 (float4)
Expand All @@ -15,12 +15,16 @@ u64 type: 20 (int8)
bool_true type: 16 (bool)
bool_false type: 16 (bool)
str type: 25 (text)
d0 type: 1700 (numeric)
d2 type: 1700 (numeric)
d5 type: 1700 (numeric)
d10 type: 1700 (numeric)
str_arr type: 1009 (_text)
i64_arr type: 1016 (_int8)
f64_arr type: 1022 (_float8)
tsmp type: 1114 (timestamp)
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+-------------+---------+-------------+----------------------------+
| Utf8(NULL) | f32 | f64 | i16 | u16 | i32 | u32 | i64 | u64 | bool_true | bool_false | str | str_arr | i64_arr | f64_arr | tsmp |
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+-------------+---------+-------------+----------------------------+
| NULL | 1.234 | 1.234 | 1 | 1 | 1 | 1 | 1 | 1 | true | false | test | test1,test2 | 1,2,3 | 1.2,2.3,3.4 | 2022-04-25 16:25:01.164774 |
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+-------------+---------+-------------+----------------------------+
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+-------------+---------+-------------+----------------------------+
| Utf8(NULL) | f32 | f64 | i16 | u16 | i32 | u32 | i64 | u64 | bool_true | bool_false | str | d0 | d2 | d5 | d10 | str_arr | i64_arr | f64_arr | tsmp |
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+-------------+---------+-------------+----------------------------+
| NULL | 1.234 | 1.234 | 1 | 1 | 1 | 1 | 1 | 1 | true | false | test | 1 | 1.25 | 1.25000 | 1.2500000000 | test1,test2 | 1,2,3 | 1.2,2.3,3.4 | 2022-04-25 16:25:01.164774 |
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+-------------+---------+-------------+----------------------------+
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/compile/engine/provider.rs
Expand Up @@ -503,6 +503,7 @@ impl TableProvider for CubeTableProvider {
ColumnType::Int32 => DataType::Int64,
ColumnType::Int64 => DataType::Int64,
ColumnType::Blob => DataType::Utf8,
ColumnType::Decimal(p, s) => DataType::Decimal(p, s),
ColumnType::List(field) => DataType::List(field.clone()),
ColumnType::Timestamp => {
DataType::Timestamp(TimeUnit::Millisecond, None)
Expand Down
6 changes: 6 additions & 0 deletions rust/cubesql/cubesql/src/error.rs
Expand Up @@ -139,6 +139,12 @@ impl From<ParserError> for CubeError {
}
}

/// Converts a `rust_decimal` error into an internal `CubeError`.
impl From<rust_decimal::Error> for CubeError {
    fn from(v: rust_decimal::Error) -> Self {
        // Use the human-readable `Display` message rather than the `Debug`
        // representation, which would leak enum variant names into the error text.
        CubeError::internal(v.to_string())
    }
}

impl From<tokio::task::JoinError> for CubeError {
fn from(v: tokio::task::JoinError) -> Self {
CubeError::internal(v.to_string())
Expand Down
65 changes: 56 additions & 9 deletions rust/cubesql/cubesql/src/sql/dataframe.rs
Expand Up @@ -10,16 +10,17 @@ use chrono_tz::Tz;
use comfy_table::{Cell, Table};
use datafusion::arrow::{
array::{
Array, ArrayRef, BooleanArray, Float16Array, Float32Array, Float64Array, Int16Array,
Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
IntervalYearMonthArray, LargeStringArray, ListArray, StringArray,
TimestampMicrosecondArray, TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array,
UInt8Array,
Array, ArrayRef, BooleanArray, DecimalArray, Float16Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeStringArray, ListArray,
StringArray, TimestampMicrosecondArray, TimestampNanosecondArray, UInt16Array, UInt32Array,
UInt64Array, UInt8Array,
},
datatypes::{DataType, IntervalUnit, TimeUnit},
record_batch::RecordBatch,
temporal_conversions,
};
use rust_decimal::prelude::*;
use std::{
fmt::{self, Debug, Formatter},
io,
Expand Down Expand Up @@ -92,9 +93,10 @@ pub enum TableValue {
Int32(i32),
Int64(i64),
Boolean(bool),
List(ArrayRef),
Float32(f32),
Float64(f64),
List(ArrayRef),
Decimal128(Decimal128Value),
Timestamp(TimestampValue),
}

Expand All @@ -110,6 +112,7 @@ impl ToString for TableValue {
TableValue::Float32(v) => v.to_string(),
TableValue::Float64(v) => v.to_string(),
TableValue::Timestamp(v) => v.to_string(),
TableValue::Decimal128(v) => v.to_string(),
TableValue::List(v) => {
let mut values: Vec<String> = Vec::with_capacity(v.len());

Expand Down Expand Up @@ -283,6 +286,46 @@ impl ToString for TimestampValue {
}
}

#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Decimal128Value {
n: i128,
// number of digits after .
scale: usize,
}

impl Decimal128Value {
pub fn new(n: i128, scale: usize) -> Self {
Self { n, scale }
}

pub fn as_decimal(&self) -> Result<Decimal, CubeError> {
Ok(Decimal::try_from_i128_with_scale(
self.n,
self.scale as u32,
)?)
}
}

impl ToString for Decimal128Value {
fn to_string(&self) -> String {
let as_str = self.n.to_string();

if self.scale == 0 {
as_str
} else {
let (sign, rest) = as_str.split_at(if self.n >= 0 { 0 } else { 1 });

if rest.len() > self.scale {
let (whole, decimal) = as_str.split_at(as_str.len() - self.scale);
format!("{}.{}", whole, decimal)
} else {
// String has to be padded
format!("{}0.{:0>w$}", sign, rest, w = self.scale)
}
}
}
}

macro_rules! convert_array_cast_native {
($V: expr, (Vec<u8>)) => {{
$V.to_vec()
Expand Down Expand Up @@ -315,6 +358,7 @@ pub fn arrow_to_column_type(arrow_type: DataType) -> Result<ColumnType, CubeErro
DataType::Boolean => Ok(ColumnType::Boolean),
DataType::List(field) => Ok(ColumnType::List(field)),
DataType::Int32 | DataType::UInt32 => Ok(ColumnType::Int32),
DataType::Decimal(_, _) => Ok(ColumnType::Int32),
DataType::Int8
| DataType::Int16
| DataType::Int64
Expand Down Expand Up @@ -359,6 +403,9 @@ pub fn batch_to_dataframe(batches: &Vec<RecordBatch>) -> Result<DataFrame, CubeE
DataType::Int32 => convert_array!(array, num_rows, rows, Int32Array, Int32, i32),
DataType::UInt64 => convert_array!(array, num_rows, rows, UInt64Array, Int64, i64),
DataType::Int64 => convert_array!(array, num_rows, rows, Int64Array, Int64, i64),
DataType::Boolean => {
convert_array!(array, num_rows, rows, BooleanArray, Boolean, bool)
}
DataType::Float32 => {
convert_array!(array, num_rows, rows, Float32Array, Float32, f32)
}
Expand Down Expand Up @@ -443,13 +490,13 @@ pub fn batch_to_dataframe(batches: &Vec<RecordBatch>) -> Result<DataFrame, CubeE
}
}
}
DataType::Boolean => {
let a = array.as_any().downcast_ref::<BooleanArray>().unwrap();
DataType::Decimal(_, s) => {
let a = array.as_any().downcast_ref::<DecimalArray>().unwrap();
for i in 0..num_rows {
rows[i].push(if a.is_null(i) {
TableValue::Null
} else {
TableValue::Boolean(a.value(i))
TableValue::Decimal128(Decimal128Value::new(a.value(i), *s))
});
}
}
Expand Down
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/sql/postgres/extended.rs
Expand Up @@ -205,6 +205,7 @@ impl Portal {
TableValue::Float64(v) => writer.write_value(*v)?,
TableValue::List(v) => writer.write_value(v.clone())?,
TableValue::Timestamp(v) => writer.write_value(v.clone())?,
TableValue::Decimal128(v) => writer.write_value(v.clone())?,
};
}

Expand Down
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/sql/postgres/pg_type.rs
Expand Up @@ -12,6 +12,7 @@ pub fn df_type_to_pg_tid(dt: &DataType) -> Result<PgTypeId, ProtocolError> {
DataType::UInt64 => Ok(PgTypeId::INT8),
DataType::Float32 => Ok(PgTypeId::FLOAT4),
DataType::Float64 => Ok(PgTypeId::FLOAT8),
DataType::Decimal(_, _) => Ok(PgTypeId::NUMERIC),
DataType::Utf8 | DataType::LargeUtf8 => Ok(PgTypeId::TEXT),
DataType::Timestamp(_, tz) => match tz {
None => Ok(PgTypeId::TIMESTAMP),
Expand Down
60 changes: 56 additions & 4 deletions rust/cubesql/cubesql/src/sql/postgres/writer.rs
@@ -1,4 +1,7 @@
use crate::sql::{dataframe::TimestampValue, df_type_to_pg_tid};
use crate::sql::{
dataframe::{Decimal128Value, TimestampValue},
df_type_to_pg_tid,
};
use bytes::{BufMut, BytesMut};
use chrono::{
format::{
Expand All @@ -17,9 +20,10 @@ use datafusion::arrow::{
};
use pg_srv::{
protocol,
protocol::{Format, Serialize},
ProtocolError,
protocol::{ErrorCode, ErrorResponse, Format, Serialize},
PgTypeId, ProtocolError,
};
use postgres_types::{ToSql, Type};
use std::{convert::TryFrom, io, io::Error, mem};

pub trait ToPostgresValue {
Expand Down Expand Up @@ -160,6 +164,26 @@ impl<T: ToPostgresValue> ToPostgresValue for Option<T> {
}
}

/// Encodes decimal values as PostgreSQL `numeric` in text and binary formats.
///
/// https://github.com/postgres/postgres/blob/REL_14_4/src/backend/utils/adt/numeric.c#L1022
impl ToPostgresValue for Decimal128Value {
    /// Text format: renders the value with `to_string()` and reuses the
    /// string encoder for the result.
    fn to_text(&self, buf: &mut BytesMut) -> Result<(), ProtocolError> {
        let rendered = self.to_string();
        rendered.to_text(buf)
    }

    /// Binary format: a 4-byte big-endian length followed by the payload that
    /// `rust_decimal`'s `ToSql` implementation produces for the NUMERIC type.
    fn to_binary(&self, buf: &mut BytesMut) -> Result<(), ProtocolError> {
        let decimal = self
            .as_decimal()
            .map_err(|err| ErrorResponse::error(ErrorCode::InternalError, err.to_string()))?;
        let numeric_type = Type::from_oid(PgTypeId::NUMERIC as u32).unwrap();

        // Serialize into a scratch buffer first, because the length prefix
        // has to be written before the payload itself.
        let mut payload = postgres_types::private::BytesMut::new();
        decimal
            .to_sql(&numeric_type, &mut payload)
            .map_err(|err| ErrorResponse::error(ErrorCode::InternalError, err.to_string()))?;

        buf.put_i32(payload.len() as i32);
        buf.extend_from_slice(&payload[..]);

        Ok(())
    }
}

impl ToPostgresValue for ArrayRef {
fn to_text(&self, buf: &mut BytesMut) -> Result<(), ProtocolError> {
let mut values: Vec<String> = Vec::with_capacity(self.len());
Expand Down Expand Up @@ -381,7 +405,7 @@ impl<'a> Serialize for BatchWriter {
#[cfg(test)]
mod tests {
use crate::sql::{
dataframe::TimestampValue,
dataframe::{Decimal128Value, TimestampValue},
shim::ConnectionError,
writer::{BatchWriter, ToPostgresValue},
};
Expand Down Expand Up @@ -499,6 +523,34 @@ mod tests {
Ok(())
}

/// Checks the bytes produced by a binary-format `BatchWriter` for two decimal rows.
#[tokio::test]
async fn test_backend_writer_binary_numeric() -> Result<(), ConnectionError> {
    // DECLARE test BINARY CURSOR FOR SELECT CAST(1 as decimal(10, 5)) UNION ALL SELECT CAST(2 as decimal(25, 15));
    // fetch 2 in test;
    let mut writer = BatchWriter::new(Format::Binary);

    // One value per row: (unscaled number, scale).
    for (n, scale) in vec![(1, 5), (2, 15)] {
        writer.write_value(Decimal128Value::new(n, scale))?;
        writer.end_row()?;
    }

    let mut cursor = Cursor::new(vec![]);
    buffer::write_direct(&mut cursor, writer).await?;

    assert_eq!(
        cursor.get_ref()[0..],
        vec![
            // row
            68, 0, 0, 0, 20, 0, 1, 0, 0, 0, 10, 0, 1, 255, 254, 0, 0, 0, 5, 3, 232,
            // row
            68, 0, 0, 0, 20, 0, 1, 0, 0, 0, 10, 0, 1, 255, 252, 0, 0, 0, 15, 0, 20
        ]
    );

    Ok(())
}

#[tokio::test]
async fn test_backend_writer_binary_int8_array() -> Result<(), ConnectionError> {
let mut cursor = Cursor::new(vec![]);
Expand Down
2 changes: 2 additions & 0 deletions rust/cubesql/cubesql/src/sql/types.rs
Expand Up @@ -16,6 +16,7 @@ pub enum ColumnType {
Int64,
Blob,
Timestamp,
Decimal(usize, usize),
List(Box<Field>),
}

Expand All @@ -42,6 +43,7 @@ impl ColumnType {
ColumnType::String | ColumnType::VarStr => PgTypeId::TEXT,
ColumnType::Timestamp => PgTypeId::TIMESTAMP,
ColumnType::Double => PgTypeId::NUMERIC,
ColumnType::Decimal(_, _) => PgTypeId::NUMERIC,
ColumnType::List(field) => match field.data_type() {
DataType::Binary => PgTypeId::ARRAYBYTEA,
DataType::Boolean => PgTypeId::ARRAYBOOL,
Expand Down

0 comments on commit db7ec5c

Please sign in to comment.