From c9995950f8801eca384fd89ac7d310205ac0ce8d Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Sat, 5 Mar 2022 20:43:42 +0100 Subject: [PATCH] Support to read/write from/to ODBC (#849) Co-authored-by: Markus Klein --- .github/workflows/integration-odbc.yml | 40 ++++ Cargo.toml | 5 + README.md | 5 +- arrow-odbc-integration-testing/Cargo.toml | 11 ++ .../docker-compose.yml | 10 + arrow-odbc-integration-testing/src/lib.rs | 41 ++++ arrow-odbc-integration-testing/src/read.rs | 139 ++++++++++++++ arrow-odbc-integration-testing/src/write.rs | 133 +++++++++++++ examples/io_odbc.rs | 83 ++++++++ guide/src/io/README.md | 1 + guide/src/io/odbc.md | 8 + src/io/mod.rs | 4 + src/io/odbc/mod.rs | 11 ++ src/io/odbc/read/deserialize.rs | 141 ++++++++++++++ src/io/odbc/read/mod.rs | 37 ++++ src/io/odbc/read/schema.rs | 80 ++++++++ src/io/odbc/write/mod.rs | 71 +++++++ src/io/odbc/write/schema.rs | 38 ++++ src/io/odbc/write/serialize.rs | 179 ++++++++++++++++++ 19 files changed, 1036 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/integration-odbc.yml create mode 100644 arrow-odbc-integration-testing/Cargo.toml create mode 100644 arrow-odbc-integration-testing/docker-compose.yml create mode 100644 arrow-odbc-integration-testing/src/lib.rs create mode 100644 arrow-odbc-integration-testing/src/read.rs create mode 100644 arrow-odbc-integration-testing/src/write.rs create mode 100644 examples/io_odbc.rs create mode 100644 guide/src/io/odbc.md create mode 100644 src/io/odbc/mod.rs create mode 100644 src/io/odbc/read/deserialize.rs create mode 100644 src/io/odbc/read/mod.rs create mode 100644 src/io/odbc/read/schema.rs create mode 100644 src/io/odbc/write/mod.rs create mode 100644 src/io/odbc/write/schema.rs create mode 100644 src/io/odbc/write/serialize.rs diff --git a/.github/workflows/integration-odbc.yml b/.github/workflows/integration-odbc.yml new file mode 100644 index 00000000000..2ee3bf096f2 --- /dev/null +++ b/.github/workflows/integration-odbc.yml @@ -0,0 +1,40 @@ +name: Integration ODBC + +on: [push, pull_request] + +env: + CARGO_TERM_COLOR: always + +jobs: + linux: + name: Test + runs-on: ubuntu-latest + + services: + sqlserver: + image: mcr.microsoft.com/mssql/server:2017-latest-ubuntu + ports: + - 1433:1433 + env: + ACCEPT_EULA: Y + SA_PASSWORD: My@Test@Password1 + + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Install ODBC Drivers + run: | + curl https://packages.microsoft.com/keys/microsoft.asc | apt-key add - + curl https://packages.microsoft.com/config/ubuntu/20.04/prod.list > /etc/apt/sources.list.d/mssql-release.list + apt-get update + ACCEPT_EULA=Y apt-get install -y msodbcsql17 + ln -s /opt/microsoft/msodbcsql17/lib64/libmsodbcsql-17.*.so.* /opt/microsoft/msodbcsql17/lib64/libmsodbcsql-17.so + shell: sudo bash {0} + - name: Setup Rust toolchain + run: | + rustup toolchain install stable + rustup default stable + rustup component add rustfmt clippy + - uses: Swatinem/rust-cache@v1 + - name: Test + run: cd arrow-odbc-integration-testing && cargo test diff --git a/Cargo.toml b/Cargo.toml index d5d9a832d66..c6a9e189f35 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,6 +87,9 @@ strength_reduce = { version = "0.2", optional = true } # For instruction multiversioning multiversion = { version = "0.6.1", optional = true } +# For support for odbc +odbc-api = { version = "0.35", optional = true } + [dev-dependencies] criterion = "0.3" flate2 = "1" @@ -106,6 +109,7 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = [] full = [ + "io_odbc", "io_csv", 
"io_csv_async", "io_json", @@ -126,6 +130,7 @@ full = [ # parses timezones used in timestamp conversions "chrono-tz", ] +io_odbc = ["odbc-api"] io_csv = ["io_csv_read", "io_csv_write"] io_csv_async = ["io_csv_read_async"] io_csv_read = ["csv", "lexical-core"] diff --git a/README.md b/README.md index e54883f6eed..f7457762491 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,9 @@ documentation of each of its APIs. * Apache Arrow IPC (all types) * Apache Arrow Flight (all types) * Apache Parquet (except deep nested types) - * Apache Avro (not all types yet) + * Apache Avro (all types) * NJSON + * ODBC (some types) * Extensive suite of compute operations * aggregations * arithmetics @@ -58,8 +59,10 @@ documentation of each of its APIs. This crate uses `unsafe` when strickly necessary: * when the compiler can't prove certain invariants and * FFI + We have extensive tests over these, all of which run and pass under MIRI. Most uses of `unsafe` fall into 3 categories: + * The Arrow format has invariants over utf8 that can't be written in safe Rust * `TrustedLen` and trait specialization are still nightly features * FFI diff --git a/arrow-odbc-integration-testing/Cargo.toml b/arrow-odbc-integration-testing/Cargo.toml new file mode 100644 index 00000000000..add16ba5d2d --- /dev/null +++ b/arrow-odbc-integration-testing/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "arrow-odbc-integration-testing" +version = "0.1.0" +authors = ["Jorge C. Leitao "] +edition = "2021" + +[dependencies] +arrow2 = { path = "../", default-features = false, features = ["io_odbc"] } +lazy_static = "1.4.0" +# Function name macro is used to ensure unique table names in test +stdext = "0.3.1" diff --git a/arrow-odbc-integration-testing/docker-compose.yml b/arrow-odbc-integration-testing/docker-compose.yml new file mode 100644 index 00000000000..9c344136652 --- /dev/null +++ b/arrow-odbc-integration-testing/docker-compose.yml @@ -0,0 +1,10 @@ +services: + + mssql: + image: mcr.microsoft.com/mssql/server:2019-latest + ports: + - 1433:1433 + + environment: + - MSSQL_SA_PASSWORD=My@Test@Password1 + command: ["/opt/mssql/bin/sqlservr", "--accept-eula", "--reset-sa-password"] diff --git a/arrow-odbc-integration-testing/src/lib.rs b/arrow-odbc-integration-testing/src/lib.rs new file mode 100644 index 00000000000..bfc24d65dc3 --- /dev/null +++ b/arrow-odbc-integration-testing/src/lib.rs @@ -0,0 +1,41 @@ +#![cfg(test)] + +mod read; +mod write; + +use arrow2::io::odbc::api::{Connection, Environment, Error as OdbcError}; +use lazy_static::lazy_static; + +lazy_static! { + /// This is an example for using doc comment attributes + pub static ref ENV: Environment = Environment::new().unwrap(); +} + +/// Connection string for our test instance of Microsoft SQL Server +const MSSQL: &str = + "Driver={ODBC Driver 17 for SQL Server};Server=localhost;UID=SA;PWD=My@Test@Password1;"; + +/// Creates the table and assures it is empty. Columns are named a,b,c, etc. 
+pub fn setup_empty_table( + conn: &Connection<'_>, + table_name: &str, + column_types: &[&str], +) -> std::result::Result<(), OdbcError> { + let drop_table = &format!("DROP TABLE IF EXISTS {}", table_name); + + let column_names = &["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"]; + let cols = column_types + .iter() + .zip(column_names) + .map(|(ty, name)| format!("{} {}", name, ty)) + .collect::>() + .join(", "); + + let create_table = format!( + "CREATE TABLE {} (id int IDENTITY(1,1),{});", + table_name, cols + ); + conn.execute(drop_table, ())?; + conn.execute(&create_table, ())?; + Ok(()) +} diff --git a/arrow-odbc-integration-testing/src/read.rs b/arrow-odbc-integration-testing/src/read.rs new file mode 100644 index 00000000000..145ec852dcf --- /dev/null +++ b/arrow-odbc-integration-testing/src/read.rs @@ -0,0 +1,139 @@ +use stdext::function_name; + +use arrow2::array::{Array, BinaryArray, BooleanArray, Int32Array, Utf8Array}; +use arrow2::chunk::Chunk; +use arrow2::datatypes::Field; +use arrow2::error::Result; +use arrow2::io::odbc::api::{Connection, Cursor}; +use arrow2::io::odbc::read::{buffer_from_metadata, deserialize, infer_schema}; + +use super::{setup_empty_table, ENV, MSSQL}; + +#[test] +fn int() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let expected = vec![Chunk::new(vec![Box::new(Int32Array::from_slice([1])) as _])]; + + test(expected, "INT", "(1)", table_name) +} + +#[test] +fn int_nullable() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let expected = vec![Chunk::new(vec![ + Box::new(Int32Array::from([Some(1), None])) as _, + ])]; + + test(expected, "INT", "(1),(NULL)", table_name) +} + +#[test] +fn bool() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let expected = vec![Chunk::new(vec![ + Box::new(BooleanArray::from_slice([true])) as _ + ])]; + + test(expected, "BIT", "(1)", table_name) +} + +#[test] +fn bool_nullable() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let expected = vec![Chunk::new(vec![ + Box::new(BooleanArray::from([Some(true), None])) as _, + ])]; + + test(expected, "BIT", "(1),(NULL)", table_name) +} + +#[test] +fn binary() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let expected = vec![Chunk::new(vec![ + Box::new(BinaryArray::::from([Some(b"ab")])) as _, + ])]; + + test( + expected, + "VARBINARY(2)", + "(CAST('ab' AS VARBINARY(2)))", + table_name, + ) +} + +#[test] +fn binary_nullable() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let expected = + vec![Chunk::new(vec![ + Box::new(BinaryArray::::from([Some(b"ab"), None, Some(b"ac")])) as _, + ])]; + + test( + expected, + "VARBINARY(2)", + "(CAST('ab' AS VARBINARY(2))),(NULL),(CAST('ac' AS VARBINARY(2)))", + table_name, + ) +} + +#[test] +fn utf8_nullable() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let expected = + vec![Chunk::new(vec![ + Box::new(Utf8Array::::from([Some("ab"), None, Some("ac")])) as _, + ])]; + + test(expected, "VARCHAR(2)", "('ab'),(NULL),('ac')", table_name) +} + +fn test( + expected: Vec>>, + type_: &str, + insert: &str, + table_name: &str, +) -> Result<()> { + let connection = ENV.connect_with_connection_string(MSSQL).unwrap(); + setup_empty_table(&connection, table_name, &[type_]).unwrap(); + connection + .execute(&format!("INSERT INTO {table_name} (a) VALUES {insert}"), ()) + .unwrap(); + + // When + let 
query = format!("SELECT a FROM {table_name} ORDER BY id"); + + let chunks = read(&connection, &query)?.1; + + assert_eq!(chunks, expected); + Ok(()) +} + +pub fn read( + connection: &Connection<'_>, + query: &str, +) -> Result<(Vec, Vec>>)> { + let mut a = connection.prepare(query).unwrap(); + let fields = infer_schema(&a)?; + + let max_batch_size = 100; + let buffer = buffer_from_metadata(&a, max_batch_size).unwrap(); + + let cursor = a.execute(()).unwrap().unwrap(); + let mut cursor = cursor.bind_buffer(buffer).unwrap(); + + let mut chunks = vec![]; + while let Some(batch) = cursor.fetch().unwrap() { + let arrays = (0..batch.num_cols()) + .zip(fields.iter()) + .map(|(index, field)| { + let column_view = batch.column(index); + deserialize(column_view, field.data_type.clone()) + }) + .collect::>(); + chunks.push(Chunk::new(arrays)); + } + + Ok((fields, chunks)) +} diff --git a/arrow-odbc-integration-testing/src/write.rs b/arrow-odbc-integration-testing/src/write.rs new file mode 100644 index 00000000000..bcf12761abd --- /dev/null +++ b/arrow-odbc-integration-testing/src/write.rs @@ -0,0 +1,133 @@ +use stdext::function_name; + +use arrow2::array::{Array, BinaryArray, BooleanArray, Int32Array, Utf8Array}; +use arrow2::chunk::Chunk; +use arrow2::datatypes::{DataType, Field}; +use arrow2::error::Result; +use arrow2::io::odbc::write::{buffer_from_description, infer_descriptions, serialize}; + +use super::read::read; +use super::{setup_empty_table, ENV, MSSQL}; + +fn test( + expected: Chunk>, + fields: Vec, + type_: &str, + table_name: &str, +) -> Result<()> { + let connection = ENV.connect_with_connection_string(MSSQL).unwrap(); + setup_empty_table(&connection, table_name, &[type_]).unwrap(); + + let query = &format!("INSERT INTO {table_name} (a) VALUES (?)"); + let mut a = connection.prepare(query).unwrap(); + + let mut buffer = buffer_from_description(infer_descriptions(&fields)?, expected.len()); + + // write + buffer.set_num_rows(expected.len()); + let array = &expected.columns()[0]; + + serialize(array.as_ref(), &mut buffer.column_mut(0))?; + + a.execute(&buffer).unwrap(); + + // read + let query = format!("SELECT a FROM {table_name} ORDER BY id"); + let chunks = read(&connection, &query)?.1; + + assert_eq!(chunks[0], expected); + Ok(()) +} + +#[test] +fn int() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let table_name = format!("write_{}", table_name); + let expected = Chunk::new(vec![Box::new(Int32Array::from_slice([1])) as _]); + + test( + expected, + vec![Field::new("a", DataType::Int32, false)], + "INT", + &table_name, + ) +} + +#[test] +fn int_nullable() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let table_name = format!("write_{}", table_name); + let expected = Chunk::new(vec![Box::new(Int32Array::from([Some(1), None])) as _]); + + test( + expected, + vec![Field::new("a", DataType::Int32, true)], + "INT", + &table_name, + ) +} + +#[test] +fn bool() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let table_name = format!("write_{}", table_name); + let expected = Chunk::new(vec![Box::new(BooleanArray::from_slice([true, false])) as _]); + + test( + expected, + vec![Field::new("a", DataType::Boolean, false)], + "BIT", + &table_name, + ) +} + +#[test] +fn bool_nullable() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let table_name = format!("write_{}", table_name); + let expected = Chunk::new(vec![ + 
Box::new(BooleanArray::from([Some(true), Some(false), None])) as _, + ]); + + test( + expected, + vec![Field::new("a", DataType::Boolean, true)], + "BIT", + &table_name, + ) +} + +#[test] +fn utf8() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let table_name = format!("write_{}", table_name); + let expected = + Chunk::new(vec![ + Box::new(Utf8Array::::from([Some("aa"), None, Some("aaaa")])) as _, + ]); + + test( + expected, + vec![Field::new("a", DataType::Utf8, true)], + "VARCHAR(4)", + &table_name, + ) +} + +#[test] +fn binary() -> Result<()> { + let table_name = function_name!().rsplit_once(':').unwrap().1; + let table_name = format!("write_{}", table_name); + let expected = Chunk::new(vec![Box::new(BinaryArray::::from([ + Some(&b"aa"[..]), + None, + Some(&b"aaaa"[..]), + ])) as _]); + + test( + expected, + vec![Field::new("a", DataType::Binary, true)], + "VARBINARY(4)", + &table_name, + ) +} diff --git a/examples/io_odbc.rs b/examples/io_odbc.rs new file mode 100644 index 00000000000..9305fab6e24 --- /dev/null +++ b/examples/io_odbc.rs @@ -0,0 +1,83 @@ +//! Demo of how to write to, and read from, an ODBC connector +//! +//! On an Ubuntu, you need to run the following (to install the driver): +//! ```bash +//! sudo apt install libsqliteodbc sqlite3 unixodbc-dev +//! sudo sed --in-place 's/libsqlite3odbc.so/\/usr\/lib\/x86_64-linux-gnu\/odbc\/libsqlite3odbc.so/' /etc/odbcinst.ini +//! ``` +use arrow2::array::{Array, Int32Array, Utf8Array}; +use arrow2::chunk::Chunk; +use arrow2::datatypes::{DataType, Field}; +use arrow2::error::Result; +use arrow2::io::odbc::api; +use arrow2::io::odbc::api::Cursor; +use arrow2::io::odbc::read; +use arrow2::io::odbc::write; + +fn main() -> Result<()> { + let connector = "Driver={SQLite3};Database=sqlite-test.db"; + let env = api::Environment::new()?; + let connection = env.connect_with_connection_string(connector)?; + + // let's create an empty table with a schema + connection.execute("DROP TABLE IF EXISTS example;", ())?; + connection.execute("CREATE TABLE example (c1 INT, c2 TEXT);", ())?; + + // and now let's write some data into it (from arrow arrays!) 
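+    // the overall flow: prepare an INSERT statement, describe the target columns as Arrow fields,
+    // let the writer manage matching ODBC buffers, then serialize each chunk into them and execute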
+ // first, we prepare the statement + let query = "INSERT INTO example (c1, c2) VALUES (?, ?)"; + let prepared = connection.prepare(query).unwrap(); + + // secondly, we initialize buffers from odbc-api + let fields = vec![ + // (for now) the types here must match the tables' schema + Field::new("unused", DataType::Int32, true), + Field::new("unused", DataType::LargeUtf8, true), + ]; + + // third, we initialize the writer + let mut writer = write::Writer::try_new(prepared, fields)?; + + // say we have (or receive from a channel) a chunk: + let chunk = Chunk::new(vec![ + Box::new(Int32Array::from_slice([1, 2, 3])) as Box, + Box::new(Utf8Array::::from([Some("Hello"), None, Some("World")])), + ]); + + // we write it like this + writer.write(&chunk)?; + + // and we can later read from it + let chunks = read(&connection, "SELECT c1 FROM example")?; + + // and the result should be the same + assert_eq!(chunks[0].columns()[0], chunk.columns()[0]); + + Ok(()) +} + +/// Reads chunks from a query done against an ODBC connection +pub fn read(connection: &api::Connection<'_>, query: &str) -> Result>>> { + let mut a = connection.prepare(query)?; + let fields = read::infer_schema(&a)?; + + let max_batch_size = 100; + let buffer = read::buffer_from_metadata(&a, max_batch_size)?; + + let cursor = a.execute(())?.unwrap(); + let mut cursor = cursor.bind_buffer(buffer)?; + + let mut chunks = vec![]; + while let Some(batch) = cursor.fetch()? { + let arrays = (0..batch.num_cols()) + .zip(fields.iter()) + .map(|(index, field)| { + let column_view = batch.column(index); + read::deserialize(column_view, field.data_type.clone()) + }) + .collect::>(); + chunks.push(Chunk::new(arrays)); + } + + Ok(chunks) +} diff --git a/guide/src/io/README.md b/guide/src/io/README.md index f8e8bb8ca64..5a8cd477949 100644 --- a/guide/src/io/README.md +++ b/guide/src/io/README.md @@ -7,5 +7,6 @@ This crate offers optional features that enable interoperability with different * Parquet (`io_parquet`) * JSON and NDJSON (`io_json`) * Avro (`io_avro` and `io_avro_async`) +* ODBC-compliant databases (`io_odbc`) In this section you can find a guide and examples for each one of them. diff --git a/guide/src/io/odbc.md b/guide/src/io/odbc.md new file mode 100644 index 00000000000..7e362daf7c6 --- /dev/null +++ b/guide/src/io/odbc.md @@ -0,0 +1,8 @@ +# ODBC + +When compiled with feature `io_odbc`, this crate can be used to read from, and write to +any [ODBC](https://en.wikipedia.org/wiki/Open_Database_Connectivity) interface: + +```rust +{{#include ../../../examples/odbc.rs}} +``` diff --git a/src/io/mod.rs b/src/io/mod.rs index 1b5c39ed7c8..9343d4281ce 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -1,6 +1,10 @@ #![forbid(unsafe_code)] //! Contains modules to interface with other formats such as [`csv`], //! [`parquet`], [`json`], [`ipc`], [`mod@print`] and [`avro`]. + +#[cfg(feature = "io_odbc")] +pub mod odbc; + #[cfg(any( feature = "io_csv_read", feature = "io_csv_read_async", diff --git a/src/io/odbc/mod.rs b/src/io/odbc/mod.rs new file mode 100644 index 00000000000..d681ed3c2fd --- /dev/null +++ b/src/io/odbc/mod.rs @@ -0,0 +1,11 @@ +//! 
API to serialize and deserialize data from and to ODBC +pub use odbc_api as api; + +pub mod read; +pub mod write; + +impl From for crate::error::ArrowError { + fn from(error: api::Error) -> Self { + crate::error::ArrowError::External("".to_string(), Box::new(error)) + } +} diff --git a/src/io/odbc/read/deserialize.rs b/src/io/odbc/read/deserialize.rs new file mode 100644 index 00000000000..e55fc1a2598 --- /dev/null +++ b/src/io/odbc/read/deserialize.rs @@ -0,0 +1,141 @@ +use odbc_api::buffers::{BinColumnIt, TextColumnIt}; +use odbc_api::Bit; + +use crate::array::{Array, BinaryArray, BooleanArray, PrimitiveArray, Utf8Array}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::types::NativeType; + +use super::super::api::buffers::AnyColumnView; + +/// Deserializes a [`AnyColumnView`] into an array of [`DataType`]. +/// This is CPU-bounded +pub fn deserialize(column: AnyColumnView, data_type: DataType) -> Box { + match column { + AnyColumnView::Text(iter) => Box::new(utf8(data_type, iter)) as _, + AnyColumnView::WText(_) => todo!(), + AnyColumnView::Binary(iter) => Box::new(binary(data_type, iter)) as _, + AnyColumnView::Date(_) => todo!(), + AnyColumnView::Time(_) => todo!(), + AnyColumnView::Timestamp(_) => todo!(), + AnyColumnView::F64(values) => Box::new(primitive(data_type, values)) as _, + AnyColumnView::F32(values) => Box::new(primitive(data_type, values)) as _, + AnyColumnView::I8(values) => Box::new(primitive(data_type, values)) as _, + AnyColumnView::I16(values) => Box::new(primitive(data_type, values)) as _, + AnyColumnView::I32(values) => Box::new(primitive(data_type, values)) as _, + AnyColumnView::I64(values) => Box::new(primitive(data_type, values)) as _, + AnyColumnView::U8(values) => Box::new(primitive(data_type, values)) as _, + AnyColumnView::Bit(values) => Box::new(bool(data_type, values)) as _, + AnyColumnView::NullableDate(_) => todo!(), + AnyColumnView::NullableTime(_) => todo!(), + AnyColumnView::NullableTimestamp(_) => todo!(), + AnyColumnView::NullableF64(slice) => Box::new(primitive_optional( + data_type, + slice.raw_values().0, + slice.raw_values().1, + )) as _, + AnyColumnView::NullableF32(slice) => Box::new(primitive_optional( + data_type, + slice.raw_values().0, + slice.raw_values().1, + )) as _, + AnyColumnView::NullableI8(slice) => Box::new(primitive_optional( + data_type, + slice.raw_values().0, + slice.raw_values().1, + )) as _, + AnyColumnView::NullableI16(slice) => Box::new(primitive_optional( + data_type, + slice.raw_values().0, + slice.raw_values().1, + )) as _, + AnyColumnView::NullableI32(slice) => Box::new(primitive_optional( + data_type, + slice.raw_values().0, + slice.raw_values().1, + )) as _, + AnyColumnView::NullableI64(slice) => Box::new(primitive_optional( + data_type, + slice.raw_values().0, + slice.raw_values().1, + )) as _, + AnyColumnView::NullableU8(slice) => Box::new(primitive_optional( + data_type, + slice.raw_values().0, + slice.raw_values().1, + )) as _, + AnyColumnView::NullableBit(slice) => Box::new(bool_optional( + data_type, + slice.raw_values().0, + slice.raw_values().1, + )) as _, + } +} + +fn bitmap(values: &[isize]) -> Option { + MutableBitmap::from_trusted_len_iter(values.iter().map(|x| *x != -1)).into() +} + +fn primitive(data_type: DataType, values: &[T]) -> PrimitiveArray { + PrimitiveArray::from_data(data_type, values.to_vec().into(), None) +} + +fn primitive_optional( + data_type: DataType, + values: &[T], + indicators: &[isize], +) -> PrimitiveArray { + 
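+    // ODBC reports nulls through a parallel indicator slice: an indicator of -1 (`NULL_DATA`) marks a null slot,
+    // any other value means the element is valid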
let validity = bitmap(indicators); + PrimitiveArray::from_data(data_type, values.to_vec().into(), validity) +} + +fn bool(data_type: DataType, values: &[Bit]) -> BooleanArray { + let values = values.iter().map(|x| x.as_bool()); + let values = Bitmap::from_trusted_len_iter(values); + BooleanArray::from_data(data_type, values, None) +} + +fn bool_optional(data_type: DataType, values: &[Bit], indicators: &[isize]) -> BooleanArray { + let validity = bitmap(indicators); + let values = values.iter().map(|x| x.as_bool()); + let values = Bitmap::from_trusted_len_iter(values); + BooleanArray::from_data(data_type, values, validity) +} + +fn binary_generic<'a>( + iter: impl Iterator>, +) -> (Buffer, Buffer, Option) { + let length = iter.size_hint().0; + let mut validity = MutableBitmap::with_capacity(length); + let mut values = Vec::::with_capacity(0); + + let mut offsets = Vec::with_capacity(length + 1); + offsets.push(0i32); + + for item in iter { + if let Some(item) = item { + values.extend_from_slice(item); + validity.push(true); + } else { + validity.push(false); + } + offsets.push(values.len() as i32) + } + + (offsets.into(), values.into(), validity.into()) +} + +fn binary(data_type: DataType, iter: BinColumnIt) -> BinaryArray { + let (offsets, values, validity) = binary_generic(iter); + + // this O(N) check is not necessary + BinaryArray::from_data(data_type, offsets, values, validity) +} + +fn utf8(data_type: DataType, iter: TextColumnIt) -> Utf8Array { + let (offsets, values, validity) = binary_generic(iter); + + // this O(N) check is necessary for the utf8 validity + Utf8Array::from_data(data_type, offsets, values, validity) +} diff --git a/src/io/odbc/read/mod.rs b/src/io/odbc/read/mod.rs new file mode 100644 index 00000000000..b4077332ff3 --- /dev/null +++ b/src/io/odbc/read/mod.rs @@ -0,0 +1,37 @@ +//! APIs to read from ODBC +mod deserialize; +mod schema; + +pub use deserialize::deserialize; +pub use schema::infer_schema; + +use super::api; + +/// Creates a [`api::buffers::ColumnarBuffer`] from the metadata. +/// # Errors +/// Iff the driver provides an incorrect [`ResultSetMetadata`] +pub fn buffer_from_metadata( + resut_set_metadata: &impl api::ResultSetMetadata, + max_batch_size: usize, +) -> std::result::Result, api::Error> { + let num_cols: u16 = resut_set_metadata.num_result_cols()? 
as u16; + + let descs = (0..num_cols) + .map(|index| { + let mut column_description = api::ColumnDescription::default(); + + resut_set_metadata.describe_col(index + 1, &mut column_description)?; + + Ok(api::buffers::BufferDescription { + nullable: column_description.could_be_nullable(), + kind: api::buffers::BufferKind::from_data_type(column_description.data_type) + .unwrap(), + }) + }) + .collect::, api::Error>>()?; + + Ok(api::buffers::buffer_from_description( + max_batch_size, + descs.into_iter(), + )) +} diff --git a/src/io/odbc/read/schema.rs b/src/io/odbc/read/schema.rs new file mode 100644 index 00000000000..dba4c233738 --- /dev/null +++ b/src/io/odbc/read/schema.rs @@ -0,0 +1,80 @@ +use crate::datatypes::{DataType, Field, TimeUnit}; +use crate::error::Result; + +use super::super::api; +use super::super::api::ResultSetMetadata; + +/// Infers the Arrow [`Field`]s from a [`ResultSetMetadata`] +pub fn infer_schema(resut_set_metadata: &impl ResultSetMetadata) -> Result> { + let num_cols: u16 = resut_set_metadata.num_result_cols().unwrap() as u16; + + let fields = (0..num_cols) + .map(|index| { + let mut column_description = api::ColumnDescription::default(); + resut_set_metadata + .describe_col(index + 1, &mut column_description) + .unwrap(); + + column_to_field(&column_description) + }) + .collect(); + Ok(fields) +} + +fn column_to_field(column_description: &api::ColumnDescription) -> Field { + Field::new( + &column_description + .name_to_string() + .expect("Column name must be representable in utf8"), + column_to_data_type(&column_description.data_type), + column_description.could_be_nullable(), + ) +} + +fn column_to_data_type(data_type: &api::DataType) -> DataType { + use api::DataType as OdbcDataType; + match data_type { + OdbcDataType::Numeric { + precision: p @ 0..=38, + scale, + } + | OdbcDataType::Decimal { + precision: p @ 0..=38, + scale, + } => DataType::Decimal(*p, (*scale) as usize), + OdbcDataType::Integer => DataType::Int32, + OdbcDataType::SmallInt => DataType::Int16, + OdbcDataType::Real | OdbcDataType::Float { precision: 0..=24 } => DataType::Float32, + OdbcDataType::Float { precision: _ } | OdbcDataType::Double => DataType::Float64, + OdbcDataType::Date => DataType::Date32, + OdbcDataType::Timestamp { precision: 0 } => DataType::Timestamp(TimeUnit::Second, None), + OdbcDataType::Timestamp { precision: 1..=3 } => { + DataType::Timestamp(TimeUnit::Millisecond, None) + } + OdbcDataType::Timestamp { precision: 4..=6 } => { + DataType::Timestamp(TimeUnit::Microsecond, None) + } + OdbcDataType::Timestamp { precision: _ } => DataType::Timestamp(TimeUnit::Nanosecond, None), + OdbcDataType::BigInt => DataType::Int64, + OdbcDataType::TinyInt => DataType::Int8, + OdbcDataType::Bit => DataType::Boolean, + OdbcDataType::Binary { length } => DataType::FixedSizeBinary(*length), + OdbcDataType::LongVarbinary { length: _ } | OdbcDataType::Varbinary { length: _ } => { + DataType::Binary + } + OdbcDataType::Unknown + | OdbcDataType::Time { precision: _ } + | OdbcDataType::Numeric { .. } + | OdbcDataType::Decimal { .. 
} + | OdbcDataType::Other { + data_type: _, + column_size: _, + decimal_digits: _, + } + | OdbcDataType::WChar { length: _ } + | OdbcDataType::Char { length: _ } + | OdbcDataType::WVarchar { length: _ } + | OdbcDataType::LongVarchar { length: _ } + | OdbcDataType::Varchar { length: _ } => DataType::Utf8, + } +} diff --git a/src/io/odbc/write/mod.rs b/src/io/odbc/write/mod.rs new file mode 100644 index 00000000000..245f2455bb8 --- /dev/null +++ b/src/io/odbc/write/mod.rs @@ -0,0 +1,71 @@ +//! APIs to write to ODBC +mod schema; +mod serialize; + +use crate::{array::Array, chunk::Chunk, datatypes::Field, error::Result}; + +use super::api; +pub use schema::infer_descriptions; +pub use serialize::serialize; + +/// Creates a [`api::buffers::ColumnarBuffer`] from [`api::ColumnDescription`]s. +/// +/// This is useful when separating the serialization (CPU-bounded) to writing to the DB (IO-bounded). +pub fn buffer_from_description( + descriptions: Vec, + capacity: usize, +) -> api::buffers::ColumnarBuffer { + let descs = descriptions + .into_iter() + .map(|description| api::buffers::BufferDescription { + nullable: description.could_be_nullable(), + kind: api::buffers::BufferKind::from_data_type(description.data_type).unwrap(), + }); + + api::buffers::buffer_from_description(capacity, descs) +} + +/// A writer of [`Chunk`]s to an ODBC [`api::Prepared`] statement. +/// # Implementation +/// This struct mixes CPU-bounded and IO-bounded tasks and is not ideal +/// for an `async` context. +pub struct Writer<'a> { + fields: Vec, + buffer: api::buffers::ColumnarBuffer, + prepared: api::Prepared<'a>, +} + +impl<'a> Writer<'a> { + /// Creates a new [`Writer`]. + /// # Errors + /// Errors iff any of the types from [`Field`] is not supported. + pub fn try_new(prepared: api::Prepared<'a>, fields: Vec) -> Result { + let buffer = buffer_from_description(infer_descriptions(&fields)?, 0); + Ok(Self { + fields, + buffer, + prepared, + }) + } + + /// Writes a chunk to the writer. + /// # Errors + /// Errors iff the execution of the statement fails. 
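+    /// # Implementation
+    /// If the chunk holds more rows than the current buffers, the buffers are re-allocated to the chunk's length before serializing.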
+ pub fn write>(&mut self, chunk: &Chunk) -> Result<()> { + if chunk.len() > self.buffer.num_rows() { + // if the chunk is larger, we re-allocate new buffers to hold it + self.buffer = buffer_from_description(infer_descriptions(&self.fields)?, chunk.len()); + } + + self.buffer.set_num_rows(chunk.len()); + + // serialize (CPU-bounded) + for (i, column) in chunk.arrays().iter().enumerate() { + serialize(column.as_ref(), &mut self.buffer.column_mut(i))?; + } + + // write (IO-bounded) + self.prepared.execute(&self.buffer)?; + Ok(()) + } +} diff --git a/src/io/odbc/write/schema.rs b/src/io/odbc/write/schema.rs new file mode 100644 index 00000000000..9e4b61f704e --- /dev/null +++ b/src/io/odbc/write/schema.rs @@ -0,0 +1,38 @@ +use super::super::api; + +use crate::datatypes::{DataType, Field}; +use crate::error::{ArrowError, Result}; + +/// Infers the [`api::ColumnDescription`] from the fields +pub fn infer_descriptions(fields: &[Field]) -> Result> { + fields + .iter() + .map(|field| { + let nullability = if field.is_nullable { + api::Nullability::Nullable + } else { + api::Nullability::NoNulls + }; + let data_type = data_type_to(field.data_type())?; + Ok(api::ColumnDescription { + name: api::U16String::from_str(&field.name).into_vec(), + nullability, + data_type, + }) + }) + .collect() +} + +fn data_type_to(data_type: &DataType) -> Result { + Ok(match data_type { + DataType::Boolean => api::DataType::Bit, + DataType::Int16 => api::DataType::SmallInt, + DataType::Int32 => api::DataType::Integer, + DataType::Float32 => api::DataType::Float { precision: 24 }, + DataType::Float64 => api::DataType::Float { precision: 53 }, + DataType::FixedSizeBinary(length) => api::DataType::Binary { length: *length }, + DataType::Binary | DataType::LargeBinary => api::DataType::Varbinary { length: 0 }, + DataType::Utf8 | DataType::LargeUtf8 => api::DataType::Varchar { length: 0 }, + other => return Err(ArrowError::nyi(format!("{other:?} to ODBC"))), + }) +} diff --git a/src/io/odbc/write/serialize.rs b/src/io/odbc/write/serialize.rs new file mode 100644 index 00000000000..3128ceb964b --- /dev/null +++ b/src/io/odbc/write/serialize.rs @@ -0,0 +1,179 @@ +use api::buffers::{BinColumnWriter, TextColumnWriter}; + +use crate::array::*; +use crate::bitmap::Bitmap; +use crate::datatypes::DataType; +use crate::error::{ArrowError, Result}; +use crate::types::NativeType; + +use super::super::api; +use super::super::api::buffers::NullableSliceMut; + +/// Serializes an [`Array`] to [`api::buffers::AnyColumnViewMut`] +/// This operation is CPU-bounded +pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut) -> Result<()> { + match array.data_type() { + DataType::Boolean => { + if let api::buffers::AnyColumnViewMut::Bit(values) = column { + bool(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else if let api::buffers::AnyColumnViewMut::NullableBit(values) = column { + bool_optional(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else { + Err(ArrowError::nyi("serialize bool to non-bool ODBC")) + } + } + DataType::Int16 => { + if let api::buffers::AnyColumnViewMut::I16(values) = column { + primitive(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else if let api::buffers::AnyColumnViewMut::NullableI16(values) = column { + primitive_optional(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else { + Err(ArrowError::nyi("serialize i16 to non-i16 ODBC")) + } + } + DataType::Int32 => { + if let api::buffers::AnyColumnViewMut::I32(values) = column { + 
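+                // non-nullable i32 buffer: copy the Arrow values in directly; there are no indicators to fill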
primitive(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else if let api::buffers::AnyColumnViewMut::NullableI32(values) = column { + primitive_optional(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else { + Err(ArrowError::nyi("serialize i32 to non-i32 ODBC")) + } + } + DataType::Float32 => { + if let api::buffers::AnyColumnViewMut::F32(values) = column { + primitive(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else if let api::buffers::AnyColumnViewMut::NullableF32(values) = column { + primitive_optional(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else { + Err(ArrowError::nyi("serialize f32 to non-f32 ODBC")) + } + } + DataType::Float64 => { + if let api::buffers::AnyColumnViewMut::F64(values) = column { + primitive(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else if let api::buffers::AnyColumnViewMut::NullableF64(values) = column { + primitive_optional(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else { + Err(ArrowError::nyi("serialize f64 to non-f64 ODBC")) + } + } + DataType::Utf8 => { + if let api::buffers::AnyColumnViewMut::Text(values) = column { + utf8::(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else { + Err(ArrowError::nyi("serialize utf8 to non-text ODBC")) + } + } + DataType::LargeUtf8 => { + if let api::buffers::AnyColumnViewMut::Text(values) = column { + utf8::(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else { + Err(ArrowError::nyi("serialize utf8 to non-text ODBC")) + } + } + DataType::Binary => { + if let api::buffers::AnyColumnViewMut::Binary(values) = column { + binary::(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else { + Err(ArrowError::nyi("serialize utf8 to non-binary ODBC")) + } + } + DataType::LargeBinary => { + if let api::buffers::AnyColumnViewMut::Binary(values) = column { + binary::(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else { + Err(ArrowError::nyi("serialize utf8 to non-text ODBC")) + } + } + DataType::FixedSizeBinary(_) => { + if let api::buffers::AnyColumnViewMut::Binary(values) = column { + fixed_binary(array.as_any().downcast_ref().unwrap(), values); + Ok(()) + } else { + Err(ArrowError::nyi("serialize fixed to non-binary ODBC")) + } + } + other => Err(ArrowError::nyi(format!("{other:?} to ODBC"))), + } +} + +fn bool(array: &BooleanArray, values: &mut [api::Bit]) { + array + .values() + .iter() + .zip(values.iter_mut()) + .for_each(|(from, to)| *to = api::Bit(from as u8)); +} + +fn bool_optional(array: &BooleanArray, values: &mut NullableSliceMut) { + let (values, indicators) = values.raw_values(); + array + .values() + .iter() + .zip(values.iter_mut()) + .for_each(|(from, to)| *to = api::Bit(from as u8)); + write_validity(array.validity(), indicators); +} + +fn primitive(array: &PrimitiveArray, values: &mut [T]) { + values.copy_from_slice(array.values()) +} + +fn write_validity(validity: Option<&Bitmap>, indicators: &mut [isize]) { + if let Some(validity) = validity { + indicators + .iter_mut() + .zip(validity.iter()) + .for_each(|(indicator, is_valid)| *indicator = if is_valid { 0 } else { -1 }) + } else { + indicators.iter_mut().for_each(|x| *x = 0) + } +} + +fn primitive_optional(array: &PrimitiveArray, values: &mut NullableSliceMut) { + let (values, indicators) = values.raw_values(); + values.copy_from_slice(array.values()); + write_validity(array.validity(), indicators); +} + +fn fixed_binary(array: &FixedSizeBinaryArray, writer: &mut BinColumnWriter) { + 
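+    // the column writer must know the maximum element length before values are written;
+    // for a fixed-size binary array every element has length `array.size()`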
    writer.set_max_len(array.size());
+    writer.write(array.iter())
+}
+
+fn binary(array: &BinaryArray, writer: &mut BinColumnWriter) {
+    let max_len = array
+        .offsets()
+        .windows(2)
+        .map(|x| (x[1] - x[0]).to_usize())
+        .max()
+        .unwrap_or(0);
+    writer.set_max_len(max_len);
+    writer.write(array.iter())
+}
+
+fn utf8(array: &Utf8Array, writer: &mut TextColumnWriter) {
+    let max_len = array
+        .offsets()
+        .windows(2)
+        .map(|x| (x[1] - x[0]).to_usize())
+        .max()
+        .unwrap_or(0);
+    writer.set_max_len(max_len);
+    writer.write(array.iter().map(|x| x.map(|x| x.as_bytes())))
+}
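The `write::Writer` introduced above keeps its ODBC buffers between calls and only re-allocates them when a larger chunk arrives, so a single writer can serve a whole stream of chunks. The following is a minimal sketch of that usage and is not part of the patch: it assumes the same SQLite DSN as `examples/io_odbc.rs`, and the `write_all` helper and the `example` table are illustrative names only.

```rust
use arrow2::array::{Array, Int32Array};
use arrow2::chunk::Chunk;
use arrow2::datatypes::{DataType, Field};
use arrow2::error::Result;
use arrow2::io::odbc::{api, write};

/// Hypothetical helper: writes every chunk through a single prepared statement,
/// re-using the writer's ODBC buffers across chunks.
fn write_all(
    connection: &api::Connection<'_>,
    chunks: impl IntoIterator<Item = Chunk<Box<dyn Array>>>,
) -> Result<()> {
    let prepared = connection.prepare("INSERT INTO example (c1) VALUES (?)")?;
    let fields = vec![Field::new("c1", DataType::Int32, true)];
    let mut writer = write::Writer::try_new(prepared, fields)?;
    for chunk in chunks {
        // buffers are re-allocated internally only if `chunk` is larger than any seen before
        writer.write(&chunk)?;
    }
    Ok(())
}

fn main() -> Result<()> {
    let env = api::Environment::new()?;
    let connection =
        env.connect_with_connection_string("Driver={SQLite3};Database=sqlite-test.db")?;
    connection.execute("CREATE TABLE IF NOT EXISTS example (c1 INT);", ())?;

    // two batches of different sizes; the second one triggers a buffer re-allocation
    let chunks = vec![
        Chunk::new(vec![Box::new(Int32Array::from_slice([1, 2])) as Box<dyn Array>]),
        Chunk::new(vec![Box::new(Int32Array::from_slice([3, 4, 5])) as Box<dyn Array>]),
    ];
    write_all(&connection, chunks)
}
```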