diff --git a/crates/duckdb/src/core/logical_type.rs b/crates/duckdb/src/core/logical_type.rs index 8e6b5ce1..34d74acd 100644 --- a/crates/duckdb/src/core/logical_type.rs +++ b/crates/duckdb/src/core/logical_type.rs @@ -9,7 +9,10 @@ use crate::ffi::*; /// #[repr(u32)] #[derive(Debug, PartialEq, Eq)] +#[non_exhaustive] pub enum LogicalTypeId { + /// Invalid + Invalid = DUCKDB_TYPE_DUCKDB_TYPE_INVALID, /// Boolean Boolean = DUCKDB_TYPE_DUCKDB_TYPE_BOOLEAN, /// Tinyint @@ -66,14 +69,37 @@ pub enum LogicalTypeId { Uuid = DUCKDB_TYPE_DUCKDB_TYPE_UUID, /// Union Union = DUCKDB_TYPE_DUCKDB_TYPE_UNION, + /// Bit + Bit = DUCKDB_TYPE_DUCKDB_TYPE_BIT, + /// Time TZ + TimeTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIME_TZ, /// Timestamp TZ TimestampTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ, + /// Unsigned Hugeint + UHugeint = DUCKDB_TYPE_DUCKDB_TYPE_UHUGEINT, + /// Array + Array = DUCKDB_TYPE_DUCKDB_TYPE_ARRAY, + /// Any + Any = DUCKDB_TYPE_DUCKDB_TYPE_ANY, + /// Bignum + Bignum = DUCKDB_TYPE_DUCKDB_TYPE_BIGNUM, + /// SqlNull + SqlNull = DUCKDB_TYPE_DUCKDB_TYPE_SQLNULL, + /// String Literal + StringLiteral = DUCKDB_TYPE_DUCKDB_TYPE_STRING_LITERAL, + /// Integer Literal + IntegerLiteral = DUCKDB_TYPE_DUCKDB_TYPE_INTEGER_LITERAL, + /// Time NS + TimeNs = DUCKDB_TYPE_DUCKDB_TYPE_TIME_NS, + /// DuckDB returned a type that this wrapper does not yet recognize + Unsupported = u32::MAX, } impl From for LogicalTypeId { /// Convert from u32 to LogicalTypeId fn from(value: u32) -> Self { match value { + DUCKDB_TYPE_DUCKDB_TYPE_INVALID => Self::Invalid, DUCKDB_TYPE_DUCKDB_TYPE_BOOLEAN => Self::Boolean, DUCKDB_TYPE_DUCKDB_TYPE_TINYINT => Self::Tinyint, DUCKDB_TYPE_DUCKDB_TYPE_SMALLINT => Self::Smallint, @@ -102,8 +128,19 @@ impl From for LogicalTypeId { DUCKDB_TYPE_DUCKDB_TYPE_MAP => Self::Map, DUCKDB_TYPE_DUCKDB_TYPE_UUID => Self::Uuid, DUCKDB_TYPE_DUCKDB_TYPE_UNION => Self::Union, + DUCKDB_TYPE_DUCKDB_TYPE_BIT => Self::Bit, + DUCKDB_TYPE_DUCKDB_TYPE_TIME_TZ => Self::TimeTZ, DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ => Self::TimestampTZ, - _ => panic!(), + DUCKDB_TYPE_DUCKDB_TYPE_UHUGEINT => Self::UHugeint, + DUCKDB_TYPE_DUCKDB_TYPE_ARRAY => Self::Array, + DUCKDB_TYPE_DUCKDB_TYPE_ANY => Self::Any, + DUCKDB_TYPE_DUCKDB_TYPE_BIGNUM => Self::Bignum, + DUCKDB_TYPE_DUCKDB_TYPE_SQLNULL => Self::SqlNull, + DUCKDB_TYPE_DUCKDB_TYPE_STRING_LITERAL => Self::StringLiteral, + DUCKDB_TYPE_DUCKDB_TYPE_INTEGER_LITERAL => Self::IntegerLiteral, + DUCKDB_TYPE_DUCKDB_TYPE_TIME_NS => Self::TimeNs, + // Unknown / forward compatible types + _ => Self::Unsupported, } } } @@ -119,6 +156,8 @@ impl Debug for LogicalTypeHandle { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { let id = self.id(); match id { + LogicalTypeId::Invalid => write!(f, "Invalid"), + LogicalTypeId::Unsupported => write!(f, "Unsupported({})", self.raw_id()), LogicalTypeId::Struct => { write!(f, "struct<")?; for i in 0..self.num_children() { @@ -129,7 +168,7 @@ impl Debug for LogicalTypeHandle { } write!(f, ">") } - _ => write!(f, "{:?}", self.id()), + _ => write!(f, "{:?}", id), } } } @@ -248,8 +287,25 @@ impl LogicalTypeHandle { /// Logical type ID pub fn id(&self) -> LogicalTypeId { - let duckdb_type_id = unsafe { duckdb_get_type_id(self.ptr) }; - duckdb_type_id.into() + self.raw_id().into() + } + + /// Logical type ID, with forward-compatibility awareness. + /// + /// Returns `Ok(LogicalTypeId)` for all known ids (including `Invalid`), and + /// `Err(raw_id)` when DuckDB returns an id this wrapper does not yet + /// recognize. + pub fn try_id(&self) -> Result { + let raw = self.raw_id(); + match LogicalTypeId::from(raw) { + LogicalTypeId::Unsupported => Err(raw), + id => Ok(id), + } + } + + /// Raw logical type id returned by DuckDB C API + pub fn raw_id(&self) -> u32 { + unsafe { duckdb_get_type_id(self.ptr) } } /// Logical type children num @@ -258,6 +314,7 @@ impl LogicalTypeHandle { LogicalTypeId::Struct => unsafe { duckdb_struct_type_child_count(self.ptr) as usize }, LogicalTypeId::Union => unsafe { duckdb_union_type_member_count(self.ptr) as usize }, LogicalTypeId::List => 1, + LogicalTypeId::Array => 1, _ => 0, } } @@ -270,6 +327,7 @@ impl LogicalTypeHandle { let child_name_ptr = match self.id() { LogicalTypeId::Struct => duckdb_struct_type_child_name(self.ptr, idx as u64), LogicalTypeId::Union => duckdb_union_type_member_name(self.ptr, idx as u64), + LogicalTypeId::Unsupported => panic!("unsupported logical type {}", self.raw_id()), _ => panic!("not a struct or union"), }; let c_str = CString::from_raw(child_name_ptr); @@ -284,7 +342,9 @@ impl LogicalTypeHandle { match self.id() { LogicalTypeId::Struct => duckdb_struct_type_child_type(self.ptr, idx as u64), LogicalTypeId::Union => duckdb_union_type_member_type(self.ptr, idx as u64), - _ => panic!("not a struct or union"), + LogicalTypeId::Array => duckdb_array_type_child_type(self.ptr), + LogicalTypeId::Unsupported => panic!("unsupported logical type {}", self.raw_id()), + _ => panic!("not a struct, union, or array"), } }; unsafe { Self::new(c_logical_type) } @@ -319,26 +379,36 @@ mod test { #[test] fn test_struct() { - let fields = &[("hello", LogicalTypeHandle::from(crate::core::LogicalTypeId::Boolean))]; + let fields = &[("hello", LogicalTypeHandle::from(LogicalTypeId::Boolean))]; let typ = LogicalTypeHandle::struct_type(fields); assert_eq!(typ.num_children(), 1); assert_eq!(typ.child_name(0), "hello"); - assert_eq!(typ.child(0).id(), crate::core::LogicalTypeId::Boolean); + assert_eq!(typ.child(0).id(), LogicalTypeId::Boolean); + } + + #[test] + fn test_array() { + let child = LogicalTypeHandle::from(LogicalTypeId::Integer); + let array = LogicalTypeHandle::array(&child, 4); + + assert_eq!(array.id(), LogicalTypeId::Array); + assert_eq!(array.num_children(), 1); + assert_eq!(array.child(0).id(), LogicalTypeId::Integer); } #[test] fn test_decimal() { let typ = LogicalTypeHandle::decimal(10, 2); - assert_eq!(typ.id(), crate::core::LogicalTypeId::Decimal); + assert_eq!(typ.id(), LogicalTypeId::Decimal); assert_eq!(typ.decimal_width(), 10); assert_eq!(typ.decimal_scale(), 2); } #[test] fn test_decimal_methods() { - let typ = LogicalTypeHandle::from(crate::core::LogicalTypeId::Varchar); + let typ = LogicalTypeHandle::from(LogicalTypeId::Varchar); assert_eq!(typ.decimal_width(), 0); assert_eq!(typ.decimal_scale(), 0); @@ -360,4 +430,25 @@ mod test { assert_eq!(typ.child_name(1), "world"); assert_eq!(typ.child(1).id(), LogicalTypeId::Integer); } + + #[test] + fn test_invalid_type() { + use crate::ffi::{duckdb_create_logical_type, DUCKDB_TYPE_DUCKDB_TYPE_INVALID}; + + // Create an invalid logical type (what DuckDB returns in certain error cases) + let invalid_type = + unsafe { LogicalTypeHandle::new(duckdb_create_logical_type(DUCKDB_TYPE_DUCKDB_TYPE_INVALID)) }; + + assert_eq!(invalid_type.id(), LogicalTypeId::Invalid); + assert_eq!(invalid_type.try_id().unwrap(), LogicalTypeId::Invalid); + assert_eq!(invalid_type.raw_id(), DUCKDB_TYPE_DUCKDB_TYPE_INVALID); + + let debug_str = format!("{invalid_type:?}"); + assert_eq!(debug_str, "Invalid"); + } + + #[test] + fn test_unknown_type() { + assert_eq!(LogicalTypeId::from(999_999), LogicalTypeId::Unsupported); + } } diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index dfc14f95..10df8a85 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -25,9 +25,7 @@ use arrow::{ record_batch::RecordBatch, }; -use libduckdb_sys::{ - duckdb_date, duckdb_hugeint, duckdb_interval, duckdb_string_t, duckdb_time, duckdb_timestamp, duckdb_vector, -}; +use libduckdb_sys::{duckdb_date, duckdb_string_t, duckdb_time, duckdb_timestamp, duckdb_vector}; use num::{cast::AsPrimitive, ToPrimitive}; /// A pointer to the Arrow record batch for the table function. @@ -261,8 +259,11 @@ pub fn flat_vector_to_arrow_array( vector: &mut FlatVector, len: usize, ) -> Result, Box> { - let type_id = vector.logical_type().id(); + let raw_type_id = vector.logical_type().raw_id(); + let type_id = LogicalTypeId::from(raw_type_id); match type_id { + LogicalTypeId::Invalid => Err("Cannot convert invalid logical type returned by DuckDB".into()), + LogicalTypeId::Unsupported => Err(format!("Unsupported DuckDB logical type ID {raw_type_id}").into()), LogicalTypeId::Integer => { let data = vector.as_slice_with_len::(len); @@ -461,35 +462,25 @@ pub fn flat_vector_to_arrow_array( Ok(Arc::new(structs)) } - LogicalTypeId::Struct => { - todo!() - } - LogicalTypeId::Decimal => { - todo!() - } - LogicalTypeId::Map => { - todo!() - } - LogicalTypeId::List => { - todo!() - } - LogicalTypeId::Union => { - todo!() - } - LogicalTypeId::Interval => { - let _data = vector.as_slice_with_len::(len); - todo!() - } - LogicalTypeId::Hugeint => { - let _data = vector.as_slice_with_len::(len); - todo!() - } - LogicalTypeId::Enum => { - todo!() - } - LogicalTypeId::Uuid => { - todo!() - } + LogicalTypeId::Interval => todo!(), + LogicalTypeId::Hugeint => todo!(), + LogicalTypeId::Decimal => todo!(), + LogicalTypeId::Enum => todo!(), + LogicalTypeId::List => todo!(), + LogicalTypeId::Struct => todo!(), + LogicalTypeId::Map => todo!(), + LogicalTypeId::Array => todo!(), + LogicalTypeId::Uuid => todo!(), + LogicalTypeId::Union => todo!(), + LogicalTypeId::Bit => todo!(), + LogicalTypeId::TimeTZ => todo!(), + LogicalTypeId::UHugeint => todo!(), + LogicalTypeId::Any => todo!(), + LogicalTypeId::Bignum => todo!(), + LogicalTypeId::SqlNull => todo!(), + LogicalTypeId::StringLiteral => todo!(), + LogicalTypeId::IntegerLiteral => todo!(), + LogicalTypeId::TimeNs => todo!(), } }