Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 100 additions & 9 deletions crates/duckdb/src/core/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use crate::ffi::*;
/// <https://duckdb.org/docs/api/c/types>
#[repr(u32)]
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum LogicalTypeId {
/// Invalid
Invalid = DUCKDB_TYPE_DUCKDB_TYPE_INVALID,
/// Boolean
Boolean = DUCKDB_TYPE_DUCKDB_TYPE_BOOLEAN,
/// Tinyint
Expand Down Expand Up @@ -66,14 +69,37 @@ pub enum LogicalTypeId {
Uuid = DUCKDB_TYPE_DUCKDB_TYPE_UUID,
/// Union
Union = DUCKDB_TYPE_DUCKDB_TYPE_UNION,
/// Bit
Bit = DUCKDB_TYPE_DUCKDB_TYPE_BIT,
/// Time TZ
TimeTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIME_TZ,
/// Timestamp TZ
TimestampTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ,
/// Unsigned Hugeint
UHugeint = DUCKDB_TYPE_DUCKDB_TYPE_UHUGEINT,
/// Array
Array = DUCKDB_TYPE_DUCKDB_TYPE_ARRAY,
/// Any
Any = DUCKDB_TYPE_DUCKDB_TYPE_ANY,
/// Bignum
Bignum = DUCKDB_TYPE_DUCKDB_TYPE_BIGNUM,
/// SqlNull
SqlNull = DUCKDB_TYPE_DUCKDB_TYPE_SQLNULL,
/// String Literal
StringLiteral = DUCKDB_TYPE_DUCKDB_TYPE_STRING_LITERAL,
/// Integer Literal
IntegerLiteral = DUCKDB_TYPE_DUCKDB_TYPE_INTEGER_LITERAL,
/// Time NS
TimeNs = DUCKDB_TYPE_DUCKDB_TYPE_TIME_NS,
/// DuckDB returned a type that this wrapper does not yet recognize
Unsupported = u32::MAX,
}

impl From<u32> for LogicalTypeId {
/// Convert from u32 to LogicalTypeId
fn from(value: u32) -> Self {
match value {
DUCKDB_TYPE_DUCKDB_TYPE_INVALID => Self::Invalid,
DUCKDB_TYPE_DUCKDB_TYPE_BOOLEAN => Self::Boolean,
DUCKDB_TYPE_DUCKDB_TYPE_TINYINT => Self::Tinyint,
DUCKDB_TYPE_DUCKDB_TYPE_SMALLINT => Self::Smallint,
Expand Down Expand Up @@ -102,8 +128,19 @@ impl From<u32> for LogicalTypeId {
DUCKDB_TYPE_DUCKDB_TYPE_MAP => Self::Map,
DUCKDB_TYPE_DUCKDB_TYPE_UUID => Self::Uuid,
DUCKDB_TYPE_DUCKDB_TYPE_UNION => Self::Union,
DUCKDB_TYPE_DUCKDB_TYPE_BIT => Self::Bit,
DUCKDB_TYPE_DUCKDB_TYPE_TIME_TZ => Self::TimeTZ,
DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ => Self::TimestampTZ,
_ => panic!(),
DUCKDB_TYPE_DUCKDB_TYPE_UHUGEINT => Self::UHugeint,
DUCKDB_TYPE_DUCKDB_TYPE_ARRAY => Self::Array,
DUCKDB_TYPE_DUCKDB_TYPE_ANY => Self::Any,
DUCKDB_TYPE_DUCKDB_TYPE_BIGNUM => Self::Bignum,
DUCKDB_TYPE_DUCKDB_TYPE_SQLNULL => Self::SqlNull,
DUCKDB_TYPE_DUCKDB_TYPE_STRING_LITERAL => Self::StringLiteral,
DUCKDB_TYPE_DUCKDB_TYPE_INTEGER_LITERAL => Self::IntegerLiteral,
DUCKDB_TYPE_DUCKDB_TYPE_TIME_NS => Self::TimeNs,
// Unknown / forward compatible types
_ => Self::Unsupported,
}
}
}
Expand All @@ -119,6 +156,8 @@ impl Debug for LogicalTypeHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
let id = self.id();
match id {
LogicalTypeId::Invalid => write!(f, "Invalid"),
LogicalTypeId::Unsupported => write!(f, "Unsupported({})", self.raw_id()),
LogicalTypeId::Struct => {
write!(f, "struct<")?;
for i in 0..self.num_children() {
Expand All @@ -129,7 +168,7 @@ impl Debug for LogicalTypeHandle {
}
write!(f, ">")
}
_ => write!(f, "{:?}", self.id()),
_ => write!(f, "{:?}", id),
}
}
}
Expand Down Expand Up @@ -248,8 +287,25 @@ impl LogicalTypeHandle {

/// Logical type ID
pub fn id(&self) -> LogicalTypeId {
let duckdb_type_id = unsafe { duckdb_get_type_id(self.ptr) };
duckdb_type_id.into()
self.raw_id().into()
}

/// Fallible variant of the logical type id lookup.
///
/// Converts the raw id reported by DuckDB into a [`LogicalTypeId`].
///
/// # Errors
///
/// Returns `Err(raw_id)`, carrying the untranslated id, when the conversion
/// yields [`LogicalTypeId::Unsupported`]. Every other variant, `Invalid`
/// included, is returned as `Ok`.
pub fn try_id(&self) -> Result<LogicalTypeId, u32> {
    let raw_type_id = self.raw_id();
    let type_id = LogicalTypeId::from(raw_type_id);
    if matches!(type_id, LogicalTypeId::Unsupported) {
        Err(raw_type_id)
    } else {
        Ok(type_id)
    }
}

/// Raw logical type id returned by the DuckDB C API.
///
/// Returns the `u32` produced by `duckdb_get_type_id` for this handle's
/// pointer, without converting it to a [`LogicalTypeId`].
pub fn raw_id(&self) -> u32 {
// NOTE(review): soundness relies on `self.ptr` being a valid logical type
// pointer for the lifetime of `self` — confirm against the constructor.
unsafe { duckdb_get_type_id(self.ptr) }
}

/// Logical type children num
Expand All @@ -258,6 +314,7 @@ impl LogicalTypeHandle {
LogicalTypeId::Struct => unsafe { duckdb_struct_type_child_count(self.ptr) as usize },
LogicalTypeId::Union => unsafe { duckdb_union_type_member_count(self.ptr) as usize },
LogicalTypeId::List => 1,
LogicalTypeId::Array => 1,
_ => 0,
}
}
Expand All @@ -270,6 +327,7 @@ impl LogicalTypeHandle {
let child_name_ptr = match self.id() {
LogicalTypeId::Struct => duckdb_struct_type_child_name(self.ptr, idx as u64),
LogicalTypeId::Union => duckdb_union_type_member_name(self.ptr, idx as u64),
LogicalTypeId::Unsupported => panic!("unsupported logical type {}", self.raw_id()),
_ => panic!("not a struct or union"),
};
let c_str = CString::from_raw(child_name_ptr);
Expand All @@ -284,7 +342,9 @@ impl LogicalTypeHandle {
match self.id() {
LogicalTypeId::Struct => duckdb_struct_type_child_type(self.ptr, idx as u64),
LogicalTypeId::Union => duckdb_union_type_member_type(self.ptr, idx as u64),
_ => panic!("not a struct or union"),
LogicalTypeId::Array => duckdb_array_type_child_type(self.ptr),
LogicalTypeId::Unsupported => panic!("unsupported logical type {}", self.raw_id()),
_ => panic!("not a struct, union, or array"),
}
};
unsafe { Self::new(c_logical_type) }
Expand Down Expand Up @@ -319,26 +379,36 @@ mod test {

#[test]
fn test_struct() {
    // A struct type with a single boolean field named "hello".
    let fields = &[("hello", LogicalTypeHandle::from(LogicalTypeId::Boolean))];
    let typ = LogicalTypeHandle::struct_type(fields);

    // The handle must report the field count, name, and child type it was built from.
    assert_eq!(typ.num_children(), 1);
    assert_eq!(typ.child_name(0), "hello");
    assert_eq!(typ.child(0).id(), LogicalTypeId::Boolean);
}

#[test]
fn test_array() {
    // Wrap an integer element type in an array type created with argument 4.
    let element_type = LogicalTypeHandle::from(LogicalTypeId::Integer);
    let array_type = LogicalTypeHandle::array(&element_type, 4);

    // An array exposes exactly one child: its element type.
    assert_eq!(array_type.id(), LogicalTypeId::Array);
    assert_eq!(array_type.num_children(), 1);
    assert_eq!(array_type.child(0).id(), LogicalTypeId::Integer);
}

#[test]
fn test_decimal() {
    // Decimal type created with width 10 and scale 2.
    let typ = LogicalTypeHandle::decimal(10, 2);

    // The id and both decimal parameters must round-trip through the handle.
    assert_eq!(typ.id(), LogicalTypeId::Decimal);
    assert_eq!(typ.decimal_width(), 10);
    assert_eq!(typ.decimal_scale(), 2);
}

#[test]
fn test_decimal_methods() {
let typ = LogicalTypeHandle::from(crate::core::LogicalTypeId::Varchar);
let typ = LogicalTypeHandle::from(LogicalTypeId::Varchar);

assert_eq!(typ.decimal_width(), 0);
assert_eq!(typ.decimal_scale(), 0);
Expand All @@ -360,4 +430,25 @@ mod test {
assert_eq!(typ.child_name(1), "world");
assert_eq!(typ.child(1).id(), LogicalTypeId::Integer);
}

#[test]
fn test_invalid_type() {
    use crate::ffi::{duckdb_create_logical_type, DUCKDB_TYPE_DUCKDB_TYPE_INVALID};

    // Build a handle around the INVALID type id (what DuckDB returns in certain error cases).
    let handle = unsafe {
        let raw_invalid = duckdb_create_logical_type(DUCKDB_TYPE_DUCKDB_TYPE_INVALID);
        LogicalTypeHandle::new(raw_invalid)
    };

    // All three id accessors must agree that the type is `Invalid`.
    assert_eq!(handle.id(), LogicalTypeId::Invalid);
    assert_eq!(handle.try_id().unwrap(), LogicalTypeId::Invalid);
    assert_eq!(handle.raw_id(), DUCKDB_TYPE_DUCKDB_TYPE_INVALID);

    // The Debug rendering is the bare variant name.
    assert_eq!(format!("{handle:?}"), "Invalid");
}

#[test]
fn test_unknown_type() {
    // A raw id with no dedicated match arm must fall through to `Unsupported`.
    let unknown_raw_id: u32 = 999_999;
    let converted = LogicalTypeId::from(unknown_raw_id);

    assert_eq!(converted, LogicalTypeId::Unsupported);
}
}
57 changes: 24 additions & 33 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ use arrow::{
record_batch::RecordBatch,
};

use libduckdb_sys::{
duckdb_date, duckdb_hugeint, duckdb_interval, duckdb_string_t, duckdb_time, duckdb_timestamp, duckdb_vector,
};
use libduckdb_sys::{duckdb_date, duckdb_string_t, duckdb_time, duckdb_timestamp, duckdb_vector};
use num::{cast::AsPrimitive, ToPrimitive};

/// A pointer to the Arrow record batch for the table function.
Expand Down Expand Up @@ -261,8 +259,11 @@ pub fn flat_vector_to_arrow_array(
vector: &mut FlatVector,
len: usize,
) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
let type_id = vector.logical_type().id();
let raw_type_id = vector.logical_type().raw_id();
let type_id = LogicalTypeId::from(raw_type_id);
match type_id {
LogicalTypeId::Invalid => Err("Cannot convert invalid logical type returned by DuckDB".into()),
LogicalTypeId::Unsupported => Err(format!("Unsupported DuckDB logical type ID {raw_type_id}").into()),
LogicalTypeId::Integer => {
let data = vector.as_slice_with_len::<i32>(len);

Expand Down Expand Up @@ -461,35 +462,25 @@ pub fn flat_vector_to_arrow_array(

Ok(Arc::new(structs))
}
LogicalTypeId::Struct => {
todo!()
}
LogicalTypeId::Decimal => {
todo!()
}
LogicalTypeId::Map => {
todo!()
}
LogicalTypeId::List => {
todo!()
}
LogicalTypeId::Union => {
todo!()
}
LogicalTypeId::Interval => {
let _data = vector.as_slice_with_len::<duckdb_interval>(len);
todo!()
}
LogicalTypeId::Hugeint => {
let _data = vector.as_slice_with_len::<duckdb_hugeint>(len);
todo!()
}
LogicalTypeId::Enum => {
todo!()
}
LogicalTypeId::Uuid => {
todo!()
}
LogicalTypeId::Interval => todo!(),
LogicalTypeId::Hugeint => todo!(),
LogicalTypeId::Decimal => todo!(),
LogicalTypeId::Enum => todo!(),
LogicalTypeId::List => todo!(),
LogicalTypeId::Struct => todo!(),
LogicalTypeId::Map => todo!(),
LogicalTypeId::Array => todo!(),
LogicalTypeId::Uuid => todo!(),
LogicalTypeId::Union => todo!(),
LogicalTypeId::Bit => todo!(),
LogicalTypeId::TimeTZ => todo!(),
LogicalTypeId::UHugeint => todo!(),
LogicalTypeId::Any => todo!(),
LogicalTypeId::Bignum => todo!(),
LogicalTypeId::SqlNull => todo!(),
LogicalTypeId::StringLiteral => todo!(),
LogicalTypeId::IntegerLiteral => todo!(),
LogicalTypeId::TimeNs => todo!(),
}
}

Expand Down