From db7ec5c2d0a726b99daf014a70cdee8c15d3721b Mon Sep 17 00:00:00 2001 From: Dmitry Patsura Date: Tue, 21 Jun 2022 20:43:58 +0300 Subject: [PATCH] feat(cubesql): Support Numeric type (text + binary) in pg-wire --- rust/cubesql/Cargo.lock | 15 +++-- rust/cubesql/cubesql/Cargo.toml | 2 + rust/cubesql/cubesql/e2e/tests/postgres.rs | 4 ++ .../e2e__tests__postgres__pg_test_types.snap | 18 +++-- .../cubesql/src/compile/engine/provider.rs | 1 + rust/cubesql/cubesql/src/error.rs | 6 ++ rust/cubesql/cubesql/src/sql/dataframe.rs | 65 ++++++++++++++++--- .../cubesql/src/sql/postgres/extended.rs | 1 + .../cubesql/src/sql/postgres/pg_type.rs | 1 + .../cubesql/src/sql/postgres/writer.rs | 60 +++++++++++++++-- rust/cubesql/cubesql/src/sql/types.rs | 2 + 11 files changed, 148 insertions(+), 27 deletions(-) diff --git a/rust/cubesql/Cargo.lock b/rust/cubesql/Cargo.lock index 743828da6e6b..042aa3cb5296 100644 --- a/rust/cubesql/Cargo.lock +++ b/rust/cubesql/Cargo.lock @@ -792,6 +792,7 @@ dependencies = [ "paste", "pg-srv", "portpicker", + "postgres-types", "pretty_assertions", "rand 0.8.5", "regex", @@ -2621,9 +2622,9 @@ dependencies = [ [[package]] name = "postgres-protocol" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79ec03bce71f18b4a27c4c64c6ba2ddf74686d69b91d8714fb32ead3adaed713" +checksum = "878c6cbf956e03af9aa8204b407b9cbf47c072164800aa918c516cd4b056c50c" dependencies = [ "base64 0.13.0", "byteorder", @@ -2639,9 +2640,9 @@ dependencies = [ [[package]] name = "postgres-types" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04619f94ba0cc80999f4fc7073607cb825bc739a883cb6d20900fc5e009d6b0d" +checksum = "ebd6e8b7189a73169290e89bd24c771071f1012d8fe6f738f5226531f0b03d89" dependencies = [ "bytes 1.1.0", "chrono", @@ -3025,9 +3026,9 @@ dependencies = [ [[package]] name = "rust_decimal" -version = "1.23.1" +version = "1.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22dc69eadbf0ee2110b8d20418c0c6edbaefec2811c4963dc17b6344e11fe0f8" +checksum = "34a3bb58e85333f1ab191bf979104b586ebd77475bc6681882825f4532dfe87c" dependencies = [ "arrayvec 0.7.2", "byteorder", @@ -3801,7 +3802,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ee73e6e4924fe940354b8d4d98cad5231175d615cd855b758adc658c0aac6a0" dependencies = [ "cfg-if 1.0.0", - "rand 0.6.5", + "rand 0.8.5", "static_assertions", ] diff --git a/rust/cubesql/cubesql/Cargo.toml b/rust/cubesql/cubesql/Cargo.toml index d1d7eed379a0..8d3bf140df9c 100644 --- a/rust/cubesql/cubesql/Cargo.toml +++ b/rust/cubesql/cubesql/Cargo.toml @@ -28,6 +28,8 @@ rand = "0.8.3" smallvec = "1.7.0" byteorder = "1.3.4" log = "=0.4.11" +rust_decimal = { version = "1.25", features = ["c-repr", "db-postgres"]} +postgres-types = "0.2.3" # Locked, because starting from 1.15 this crate switch from chrono to time # which panic with Could not determine the UTC offset on this system. # It's a problem with determing local_offset_at for local-offset feature diff --git a/rust/cubesql/cubesql/e2e/tests/postgres.rs b/rust/cubesql/cubesql/e2e/tests/postgres.rs index b7337205b8e2..f9ac966aaff0 100644 --- a/rust/cubesql/cubesql/e2e/tests/postgres.rs +++ b/rust/cubesql/cubesql/e2e/tests/postgres.rs @@ -605,6 +605,10 @@ impl AsyncTestSuite for PostgresIntegrationTestSuite { true as bool_true, false as bool_false, 'test' as str, + CAST(1.25 as DECIMAL(15, 0)) as d0, + CAST(1.25 as DECIMAL(15, 2)) as d2, + CAST(1.25 as DECIMAL(15, 5)) as d5, + CAST(1.25 as DECIMAL(15, 10)) as d10, ARRAY['test1', 'test2'] as str_arr, ARRAY[1,2,3] as i64_arr, ARRAY[1.2,2.3,3.4] as f64_arr, diff --git a/rust/cubesql/cubesql/e2e/tests/snapshots/e2e__tests__postgres__pg_test_types.snap b/rust/cubesql/cubesql/e2e/tests/snapshots/e2e__tests__postgres__pg_test_types.snap index 360cee5589fc..a685d1d31d1b 100644 --- a/rust/cubesql/cubesql/e2e/tests/snapshots/e2e__tests__postgres__pg_test_types.snap +++ b/rust/cubesql/cubesql/e2e/tests/snapshots/e2e__tests__postgres__pg_test_types.snap @@ -1,7 +1,7 @@ --- source: cubesql/e2e/tests/postgres.rs -assertion_line: 233 -expression: "self.print_query_result(res, with_description).await" +assertion_line: 261 +expression: "self.print_query_result(res, with_description, true).await" --- Utf8(NULL) type: 25 (text) f32 type: 700 (float4) @@ -15,12 +15,16 @@ u64 type: 20 (int8) bool_true type: 16 (bool) bool_false type: 16 (bool) str type: 25 (text) +d0 type: 1700 (numeric) +d2 type: 1700 (numeric) +d5 type: 1700 (numeric) +d10 type: 1700 (numeric) str_arr type: 1009 (_text) i64_arr type: 1016 (_int8) f64_arr type: 1022 (_float8) tsmp type: 1114 (timestamp) -+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+-------------+---------+-------------+----------------------------+ -| Utf8(NULL) | f32 | f64 | i16 | u16 | i32 | u32 | i64 | u64 | bool_true | bool_false | str | str_arr | i64_arr | f64_arr | tsmp | -+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+-------------+---------+-------------+----------------------------+ -| NULL | 1.234 | 1.234 | 1 | 1 | 1 | 1 | 1 | 1 | true | false | test | test1,test2 | 1,2,3 | 1.2,2.3,3.4 | 2022-04-25 16:25:01.164774 | -+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+-------------+---------+-------------+----------------------------+ ++------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+-------------+---------+-------------+----------------------------+ +| Utf8(NULL) | f32 | f64 | i16 | u16 | i32 | u32 | i64 | u64 | bool_true | bool_false | str | d0 | d2 | d5 | d10 | str_arr | i64_arr | f64_arr | tsmp | ++------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+-------------+---------+-------------+----------------------------+ +| NULL | 1.234 | 1.234 | 1 | 1 | 1 | 1 | 1 | 1 | true | false | test | 1 | 1.25 | 1.25000 | 1.2500000000 | test1,test2 | 1,2,3 | 1.2,2.3,3.4 | 2022-04-25 16:25:01.164774 | ++------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+-------------+---------+-------------+----------------------------+ diff --git a/rust/cubesql/cubesql/src/compile/engine/provider.rs b/rust/cubesql/cubesql/src/compile/engine/provider.rs index f5e369db8e2b..ef929a1ea679 100644 --- a/rust/cubesql/cubesql/src/compile/engine/provider.rs +++ b/rust/cubesql/cubesql/src/compile/engine/provider.rs @@ -503,6 +503,7 @@ impl TableProvider for CubeTableProvider { ColumnType::Int32 => DataType::Int64, ColumnType::Int64 => DataType::Int64, ColumnType::Blob => DataType::Utf8, + ColumnType::Decimal(p, s) => DataType::Decimal(p, s), ColumnType::List(field) => DataType::List(field.clone()), ColumnType::Timestamp => { DataType::Timestamp(TimeUnit::Millisecond, None) diff --git a/rust/cubesql/cubesql/src/error.rs b/rust/cubesql/cubesql/src/error.rs index d71ab8140de7..2d880a9e7092 100644 --- a/rust/cubesql/cubesql/src/error.rs +++ b/rust/cubesql/cubesql/src/error.rs @@ -139,6 +139,12 @@ impl From for CubeError { } } +impl From for CubeError { + fn from(v: rust_decimal::Error) -> Self { + CubeError::internal(format!("{:?}", v)) + } +} + impl From for CubeError { fn from(v: tokio::task::JoinError) -> Self { CubeError::internal(v.to_string()) diff --git a/rust/cubesql/cubesql/src/sql/dataframe.rs b/rust/cubesql/cubesql/src/sql/dataframe.rs index 3cea1a4ffc56..4d1942119733 100644 --- a/rust/cubesql/cubesql/src/sql/dataframe.rs +++ b/rust/cubesql/cubesql/src/sql/dataframe.rs @@ -10,16 +10,17 @@ use chrono_tz::Tz; use comfy_table::{Cell, Table}; use datafusion::arrow::{ array::{ - Array, ArrayRef, BooleanArray, Float16Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, - IntervalYearMonthArray, LargeStringArray, ListArray, StringArray, - TimestampMicrosecondArray, TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array, - UInt8Array, + Array, ArrayRef, BooleanArray, DecimalArray, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeStringArray, ListArray, + StringArray, TimestampMicrosecondArray, TimestampNanosecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }, datatypes::{DataType, IntervalUnit, TimeUnit}, record_batch::RecordBatch, temporal_conversions, }; +use rust_decimal::prelude::*; use std::{ fmt::{self, Debug, Formatter}, io, @@ -92,9 +93,10 @@ pub enum TableValue { Int32(i32), Int64(i64), Boolean(bool), - List(ArrayRef), Float32(f32), Float64(f64), + List(ArrayRef), + Decimal128(Decimal128Value), Timestamp(TimestampValue), } @@ -110,6 +112,7 @@ impl ToString for TableValue { TableValue::Float32(v) => v.to_string(), TableValue::Float64(v) => v.to_string(), TableValue::Timestamp(v) => v.to_string(), + TableValue::Decimal128(v) => v.to_string(), TableValue::List(v) => { let mut values: Vec = Vec::with_capacity(v.len()); @@ -283,6 +286,46 @@ impl ToString for TimestampValue { } } +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct Decimal128Value { + n: i128, + // number of digits after . + scale: usize, +} + +impl Decimal128Value { + pub fn new(n: i128, scale: usize) -> Self { + Self { n, scale } + } + + pub fn as_decimal(&self) -> Result { + Ok(Decimal::try_from_i128_with_scale( + self.n, + self.scale as u32, + )?) + } +} + +impl ToString for Decimal128Value { + fn to_string(&self) -> String { + let as_str = self.n.to_string(); + + if self.scale == 0 { + as_str + } else { + let (sign, rest) = as_str.split_at(if self.n >= 0 { 0 } else { 1 }); + + if rest.len() > self.scale { + let (whole, decimal) = as_str.split_at(as_str.len() - self.scale); + format!("{}.{}", whole, decimal) + } else { + // String has to be padded + format!("{}0.{:0>w$}", sign, rest, w = self.scale) + } + } + } +} + macro_rules! convert_array_cast_native { ($V: expr, (Vec)) => {{ $V.to_vec() @@ -315,6 +358,7 @@ pub fn arrow_to_column_type(arrow_type: DataType) -> Result Ok(ColumnType::Boolean), DataType::List(field) => Ok(ColumnType::List(field)), DataType::Int32 | DataType::UInt32 => Ok(ColumnType::Int32), + DataType::Decimal(_, _) => Ok(ColumnType::Int32), DataType::Int8 | DataType::Int16 | DataType::Int64 @@ -359,6 +403,9 @@ pub fn batch_to_dataframe(batches: &Vec) -> Result convert_array!(array, num_rows, rows, Int32Array, Int32, i32), DataType::UInt64 => convert_array!(array, num_rows, rows, UInt64Array, Int64, i64), DataType::Int64 => convert_array!(array, num_rows, rows, Int64Array, Int64, i64), + DataType::Boolean => { + convert_array!(array, num_rows, rows, BooleanArray, Boolean, bool) + } DataType::Float32 => { convert_array!(array, num_rows, rows, Float32Array, Float32, f32) } @@ -443,13 +490,13 @@ pub fn batch_to_dataframe(batches: &Vec) -> Result { - let a = array.as_any().downcast_ref::().unwrap(); + DataType::Decimal(_, s) => { + let a = array.as_any().downcast_ref::().unwrap(); for i in 0..num_rows { rows[i].push(if a.is_null(i) { TableValue::Null } else { - TableValue::Boolean(a.value(i)) + TableValue::Decimal128(Decimal128Value::new(a.value(i), *s)) }); } } diff --git a/rust/cubesql/cubesql/src/sql/postgres/extended.rs b/rust/cubesql/cubesql/src/sql/postgres/extended.rs index 93bb0ce1567c..d147c1dc5be2 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/extended.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/extended.rs @@ -205,6 +205,7 @@ impl Portal { TableValue::Float64(v) => writer.write_value(*v)?, TableValue::List(v) => writer.write_value(v.clone())?, TableValue::Timestamp(v) => writer.write_value(v.clone())?, + TableValue::Decimal128(v) => writer.write_value(v.clone())?, }; } diff --git a/rust/cubesql/cubesql/src/sql/postgres/pg_type.rs b/rust/cubesql/cubesql/src/sql/postgres/pg_type.rs index 6d4332672bab..76c16e51defe 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/pg_type.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/pg_type.rs @@ -12,6 +12,7 @@ pub fn df_type_to_pg_tid(dt: &DataType) -> Result { DataType::UInt64 => Ok(PgTypeId::INT8), DataType::Float32 => Ok(PgTypeId::FLOAT4), DataType::Float64 => Ok(PgTypeId::FLOAT8), + DataType::Decimal(_, _) => Ok(PgTypeId::NUMERIC), DataType::Utf8 | DataType::LargeUtf8 => Ok(PgTypeId::TEXT), DataType::Timestamp(_, tz) => match tz { None => Ok(PgTypeId::TIMESTAMP), diff --git a/rust/cubesql/cubesql/src/sql/postgres/writer.rs b/rust/cubesql/cubesql/src/sql/postgres/writer.rs index fc7e6abb009e..0fec0638e124 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/writer.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/writer.rs @@ -1,4 +1,7 @@ -use crate::sql::{dataframe::TimestampValue, df_type_to_pg_tid}; +use crate::sql::{ + dataframe::{Decimal128Value, TimestampValue}, + df_type_to_pg_tid, +}; use bytes::{BufMut, BytesMut}; use chrono::{ format::{ @@ -17,9 +20,10 @@ use datafusion::arrow::{ }; use pg_srv::{ protocol, - protocol::{Format, Serialize}, - ProtocolError, + protocol::{ErrorCode, ErrorResponse, Format, Serialize}, + PgTypeId, ProtocolError, }; +use postgres_types::{ToSql, Type}; use std::{convert::TryFrom, io, io::Error, mem}; pub trait ToPostgresValue { @@ -160,6 +164,26 @@ impl ToPostgresValue for Option { } } +/// https://github.com/postgres/postgres/blob/REL_14_4/src/backend/utils/adt/numeric.c#L1022 +impl ToPostgresValue for Decimal128Value { + fn to_text(&self, buf: &mut BytesMut) -> Result<(), ProtocolError> { + self.to_string().to_text(buf) + } + + fn to_binary(&self, buf: &mut BytesMut) -> Result<(), ProtocolError> { + let mut tmp = postgres_types::private::BytesMut::new(); + self.as_decimal() + .map_err(|err| ErrorResponse::error(ErrorCode::InternalError, err.to_string()))? + .to_sql(&Type::from_oid(PgTypeId::NUMERIC as u32).unwrap(), &mut tmp) + .map_err(|err| ErrorResponse::error(ErrorCode::InternalError, err.to_string()))?; + + buf.put_i32(tmp.len() as i32); + buf.extend_from_slice(&tmp[..]); + + Ok(()) + } +} + impl ToPostgresValue for ArrayRef { fn to_text(&self, buf: &mut BytesMut) -> Result<(), ProtocolError> { let mut values: Vec = Vec::with_capacity(self.len()); @@ -381,7 +405,7 @@ impl<'a> Serialize for BatchWriter { #[cfg(test)] mod tests { use crate::sql::{ - dataframe::TimestampValue, + dataframe::{Decimal128Value, TimestampValue}, shim::ConnectionError, writer::{BatchWriter, ToPostgresValue}, }; @@ -499,6 +523,34 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_backend_writer_binary_numeric() -> Result<(), ConnectionError> { + // DECLARE test BINARY CURSOR FOR SELECT CAST(1 as decimal(10, 5)) UNION ALL SELECT CAST(2 as decimal(25, 15)); + // fetch 2 in test; + let mut cursor = Cursor::new(vec![]); + + let mut writer = BatchWriter::new(Format::Binary); + writer.write_value(Decimal128Value::new(1, 5))?; + writer.end_row()?; + + writer.write_value(Decimal128Value::new(2, 15))?; + writer.end_row()?; + + buffer::write_direct(&mut cursor, writer).await?; + + assert_eq!( + cursor.get_ref()[0..], + vec![ + // row + 68, 0, 0, 0, 20, 0, 1, 0, 0, 0, 10, 0, 1, 255, 254, 0, 0, 0, 5, 3, 232, + // row + 68, 0, 0, 0, 20, 0, 1, 0, 0, 0, 10, 0, 1, 255, 252, 0, 0, 0, 15, 0, 20 + ] + ); + + Ok(()) + } + #[tokio::test] async fn test_backend_writer_binary_int8_array() -> Result<(), ConnectionError> { let mut cursor = Cursor::new(vec![]); diff --git a/rust/cubesql/cubesql/src/sql/types.rs b/rust/cubesql/cubesql/src/sql/types.rs index 270a0182c27d..52741cb7ecae 100644 --- a/rust/cubesql/cubesql/src/sql/types.rs +++ b/rust/cubesql/cubesql/src/sql/types.rs @@ -16,6 +16,7 @@ pub enum ColumnType { Int64, Blob, Timestamp, + Decimal(usize, usize), List(Box), } @@ -42,6 +43,7 @@ impl ColumnType { ColumnType::String | ColumnType::VarStr => PgTypeId::TEXT, ColumnType::Timestamp => PgTypeId::TIMESTAMP, ColumnType::Double => PgTypeId::NUMERIC, + ColumnType::Decimal(_, _) => PgTypeId::NUMERIC, ColumnType::List(field) => match field.data_type() { DataType::Binary => PgTypeId::ARRAYBYTEA, DataType::Boolean => PgTypeId::ARRAYBOOL,