diff --git a/python/Cargo.toml b/python/Cargo.toml index cc486b972ec..103052520ba 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -10,14 +10,14 @@ name = "lance" crate-type = ["cdylib"] [dependencies] -arrow-array = "33.0" -arrow-data = "33.0" -arrow-schema = "33.0" +arrow-array = "37.0" +arrow-data = "37.0" +arrow-schema = "37.0" chrono = "0.4.23" tokio = { version = "1.23", features = ["rt-multi-thread"] } futures = "0.3" pyo3 = { version = "0.18.1", features = ["extension-module", "abi3-py38"] } -arrow = { version = "33.0.0", features = ["pyarrow"] } +arrow = { version = "37.0.0", features = ["pyarrow"] } lance = { path = "../rust"} uuid = "1.3.0" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 9cee0002b32..8950936e5ae 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -24,14 +24,14 @@ categories = [ [dependencies] bytes = "1.3" -arrow-arith = "33.0" -arrow-array = "33.0" -arrow-buffer = "33.0" -arrow-cast = "33.0.0" -arrow-data = "33.0" -arrow-ord = "33.0" -arrow-schema = "33.0" -arrow-select = "33.0" +arrow-arith = "37.0" +arrow-array = "37.0" +arrow-buffer = "37.0" +arrow-cast = "37.0.0" +arrow-data = "37.0" +arrow-ord = "37.0" +arrow-schema = "37.0" +arrow-select = "37.0" async-recursion = "1.0" async-trait = "0.1.60" byteorder = "1.4.3" @@ -51,11 +51,11 @@ futures = "0.3" uuid = { version = "1.2", features = ["v4"] } path-absolutize = "3.0.14" shellexpand = "3.0.0" -arrow = { version = "33.0.0", features = ["prettyprint"] } +arrow = { version = "37.0.0", features = ["prettyprint"] } num_cpus = "1.0" sqlparser-lance = "0.32.0" # TODO: use datafusion sub-modules to reduce build size? -datafusion = { version = "19.0.0", default-features = false } +datafusion = { version = "23.0.0", default-features = false } faiss = { version = "0.11.0", features = ["gpu"], optional = true } lapack = "0.19.0" cblas = "0.4.0" diff --git a/rust/src/arrow.rs b/rust/src/arrow.rs index 84f69a945b7..111b62aefe4 100644 --- a/rust/src/arrow.rs +++ b/rust/src/arrow.rs @@ -24,7 +24,7 @@ use arrow_array::{ OffsetSizeTrait, PrimitiveArray, RecordBatch, UInt8Array, }; use arrow_data::ArrayDataBuilder; -use arrow_schema::{DataType, Field, Schema}; +use arrow_schema::{DataType, Field, FieldRef, Fields, Schema}; mod kernels; pub mod linalg; @@ -169,13 +169,13 @@ where { fn try_new(values: T, offsets: &PrimitiveArray) -> Result { let data_type = if Offset::Native::IS_LARGE { - DataType::LargeList(Box::new(Field::new( + DataType::LargeList(Arc::new(Field::new( "item", values.data_type().clone(), true, ))) } else { - DataType::List(Box::new(Field::new( + DataType::List(Arc::new(Field::new( "item", values.data_type().clone(), true, @@ -216,12 +216,12 @@ pub trait FixedSizeListArrayExt { impl FixedSizeListArrayExt for FixedSizeListArray { fn try_new(values: T, list_size: i32) -> Result { let list_type = DataType::FixedSizeList( - Box::new(Field::new("item", values.data_type().clone(), true)), + Arc::new(Field::new("item", values.data_type().clone(), true)), list_size, ); let data = ArrayDataBuilder::new(list_type) .len(values.len() / list_size as usize) - .add_child_data(values.data().clone()) + .add_child_data(values.into_data()) .build()?; Ok(Self::from(data)) @@ -261,7 +261,7 @@ impl FixedSizeBinaryArrayExt for FixedSizeBinaryArray { let data_type = DataType::FixedSizeBinary(stride); let data = ArrayDataBuilder::new(data_type) .len(values.len() / stride as usize) - .add_buffer(values.data().buffers()[0].clone()) + .add_buffer(values.into_data().buffers()[0].clone()) .build()?; Ok(Self::from(data)) } @@ -353,10 +353,10 @@ pub trait RecordBatchExt { impl RecordBatchExt for RecordBatch { fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result { - let mut new_fields = self.schema().fields.clone(); - new_fields.push(field); + let mut new_fields: Vec = self.schema().fields.iter().cloned().collect(); + new_fields.push(FieldRef::new(field)); let new_schema = Arc::new(Schema::new_with_metadata( - new_fields, + Fields::from(new_fields.as_slice()), self.schema().metadata.clone(), )); let mut new_columns = self.columns().to_vec(); @@ -373,9 +373,9 @@ impl RecordBatchExt for RecordBatch { ))); } - let mut fields = self.schema().fields.clone(); + let mut fields: Vec = self.schema().fields.iter().cloned().collect(); let mut columns = Vec::from(self.columns()); - for field in other.schema().fields.as_slice() { + for field in other.schema().fields.iter() { if !fields.iter().any(|f| f.name() == field.name()) { fields.push(field.clone()); columns.push( diff --git a/rust/src/arrow/linalg.rs b/rust/src/arrow/linalg.rs index 38c35a52f5c..13ff208aec3 100644 --- a/rust/src/arrow/linalg.rs +++ b/rust/src/arrow/linalg.rs @@ -17,10 +17,7 @@ use std::sync::Arc; -use arrow::{ - array::{as_primitive_array, Float32Builder}, - datatypes::Float32Type, -}; +use arrow::array::{as_primitive_array, Float32Builder}; use arrow_array::{Array, FixedSizeListArray, Float32Array}; use arrow_schema::DataType; use rand::{distributions::Standard, rngs::SmallRng, seq::IteratorRandom, Rng, SeedableRng}; @@ -254,7 +251,7 @@ impl MatrixView { let mut builder = Float32Builder::with_capacity(n * dim); for idx in chosen.iter() { let s = self.data.slice(idx * dim, dim); - builder.append_slice(as_primitive_array::(s.as_ref()).values()); + builder.append_slice(s.values()); } let data = Arc::new(builder.finish()); Self { @@ -447,13 +444,18 @@ mod tests { -0.6525516, 0.10910681, ]; - assert_relative_eq!(u.data().values(), expected_u.as_slice(), epsilon = 0.0001,); - - assert_relative_eq!( - sigma.values(), - vec![27.46873242, 22.64318501, 8.55838823, 5.9857232, 2.01489966].as_slice(), - epsilon = 0.0001, - ); + u.data() + .values() + .iter() + .zip(expected_u.iter()) + .for_each(|(a, b)| { + assert_relative_eq!(a, b, epsilon = 0.0001); + }); + + let expected = vec![27.46873242, 22.64318501, 8.55838823, 5.9857232, 2.01489966]; + sigma.values().iter().zip(expected).for_each(|(&a, b)| { + assert_relative_eq!(a, b, epsilon = 0.0001); + }); // Obtained from `numpy.linagl.svd()`. let expected_vt = vec![ @@ -483,7 +485,13 @@ mod tests { -0.62652825, -0.43955169, ]; - assert_relative_eq!(vt.data().values(), expected_vt.as_slice(), epsilon = 0.0001,) + vt.data() + .values() + .iter() + .zip(expected_vt) + .for_each(|(&a, b)| { + assert_relative_eq!(a, b, epsilon = 0.0001); + }); } #[test] @@ -499,7 +507,10 @@ mod tests { let b = MatrixView::new(b_data, 2); let c = a.dot(&b).unwrap(); - assert_relative_eq!(c.data.values(), vec![44.0, 50.0, 98.0, 113.0].as_slice(),); + let expected = vec![44.0, 50.0, 98.0, 113.0]; + c.data.values().iter().zip(expected).for_each(|(&a, b)| { + assert_relative_eq!(a, b, epsilon = 0.0001); + }); } #[test] @@ -515,7 +526,10 @@ mod tests { let b = MatrixView::new(b_data, 2); let c_t = b.transpose().dot(&a.transpose()).unwrap(); - assert_relative_eq!(c_t.data.values(), vec![44.0, 98.0, 50.0, 113.0].as_slice(),); + let expected = vec![44.0, 98.0, 50.0, 113.0]; + c_t.data.values().iter().zip(expected).for_each(|(&a, b)| { + assert_relative_eq!(a, b, epsilon = 0.0001); + }); } #[test] diff --git a/rust/src/arrow/schema.rs b/rust/src/arrow/schema.rs index 8243ed267df..c53f72810ad 100644 --- a/rust/src/arrow/schema.rs +++ b/rust/src/arrow/schema.rs @@ -14,7 +14,7 @@ //! Extension to arrow schema -use arrow_schema::{ArrowError, Field, Schema}; +use arrow_schema::{ArrowError, Field, FieldRef, Schema}; /// Extends the functionality of [arrow_schema::Schema]. pub trait SchemaExt { @@ -33,8 +33,8 @@ impl SchemaExt for Schema { self ))); }; - let mut fields = self.fields.clone(); - fields.push(field); + let mut fields: Vec = self.fields().iter().cloned().collect(); + fields.push(FieldRef::new(field)); Ok(Schema::new_with_metadata(fields, self.metadata.clone())) } diff --git a/rust/src/datafusion/physical_expr.rs b/rust/src/datafusion/physical_expr.rs index ec79b453e48..3a770b22349 100644 --- a/rust/src/datafusion/physical_expr.rs +++ b/rust/src/datafusion/physical_expr.rs @@ -51,6 +51,15 @@ impl Column { } } +impl PartialEq for Column { + fn eq(&self, other: &dyn Any) -> bool { + other + .downcast_ref::() + .map(|x| self == x) + .unwrap_or(false) + } +} + impl PhysicalExpr for Column { fn as_any(&self) -> &dyn Any { self @@ -89,15 +98,6 @@ impl PhysicalExpr for Column { } } -impl PartialEq for Column { - fn eq(&self, other: &dyn Any) -> bool { - other - .downcast_ref::() - .map(|x| self == x) - .unwrap_or(false) - } -} - struct ColumnVisitor { columns: Vec, } @@ -129,7 +129,7 @@ mod tests { use super::*; use arrow_array::{ArrayRef, Float32Array, Int32Array, StringArray, StructArray}; - use arrow_schema::Field; + use arrow_schema::{Field, Fields}; #[test] fn test_simple_column() { @@ -138,10 +138,10 @@ mod tests { Field::new("s", DataType::Utf8, true), Field::new( "st", - DataType::Struct(vec![ + DataType::Struct(Fields::from(vec![ Field::new("x", DataType::Float32, false), Field::new("y", DataType::Float32, false), - ]), + ])), true, ), ])); @@ -169,10 +169,10 @@ mod tests { Field::new("s", DataType::Utf8, true), Field::new( "st", - DataType::Struct(vec![ + DataType::Struct(Fields::from(vec![ Field::new("x", DataType::Float32, false), Field::new("y", DataType::Float32, false), - ]), + ])), true, ), ])); diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index f4976eec2f4..8701b8afc40 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -1017,7 +1017,7 @@ mod tests { let schema = Arc::new(ArrowSchema::new(vec![Field::new( "embeddings", DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Float32, true)), + Arc::new(Field::new("item", DataType::Float32, true)), dimension, ), false, diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index c08b7b1e345..5ea9617269c 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -761,7 +761,7 @@ mod test { ArrowField::new( "vec", DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Float32, true)), + Arc::new(ArrowField::new("item", DataType::Float32, true)), 32, ), true, @@ -843,7 +843,7 @@ mod test { ArrowField::new( "vec", DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Float32, true)), + Arc::new(ArrowField::new("item", DataType::Float32, true)), 32, ), true, @@ -894,7 +894,7 @@ mod test { ArrowField::new( "vec", DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Float32, true)), + Arc::new(ArrowField::new("item", DataType::Float32, true)), 32, ), true, @@ -944,7 +944,7 @@ mod test { ArrowField::new( "vec", DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Float32, true)), + Arc::new(ArrowField::new("item", DataType::Float32, true)), 32, ), true, diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index 5628d6df4bb..97a6a7cf988 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -4,6 +4,7 @@ use std::cmp::max; use std::collections::HashMap; use std::fmt::{self}; use std::fmt::{Debug, Formatter}; +use std::sync::Arc; use arrow_array::cast::{as_large_list_array, as_list_array}; use arrow_array::types::{ @@ -184,7 +185,7 @@ impl TryFrom<&LogicalType> for DataType { .parse::() .map_err(|e: _| Error::Schema(e.to_string()))?; Ok(FixedSizeList( - Box::new(ArrowField::new("item", elem_type, true)), + Arc::new(ArrowField::new("item", elem_type, true)), size, )) } @@ -238,10 +239,10 @@ impl TryFrom<&LogicalType> for DataType { Err(Error::Schema(format!("Unsupported timestamp type: {}", lt))) } else { let timeunit = parse_timeunit(splits[1])?; - let tz = if splits[2] == "-" { + let tz: Option> = if splits[2] == "-" { None } else { - Some(splits[2].to_string()) + Some(splits[2].into()) }; Ok(Timestamp(timeunit, tz)) } @@ -302,9 +303,9 @@ impl Field { /// Returns arrow data type. pub fn data_type(&self) -> DataType { match &self.logical_type { - lt if lt.is_list() => DataType::List(Box::new(ArrowField::from(&self.children[0]))), + lt if lt.is_list() => DataType::List(Arc::new(ArrowField::from(&self.children[0]))), lt if lt.is_large_list() => { - DataType::LargeList(Box::new(ArrowField::from(&self.children[0]))) + DataType::LargeList(Arc::new(ArrowField::from(&self.children[0]))) } lt if lt.is_struct() => { DataType::Struct(self.children.iter().map(ArrowField::from).collect()) @@ -365,8 +366,8 @@ impl Field { panic!("Unsupported dictionary key type: {}", key_type); } }, - DataType::Struct(mut subfields) => { - for (i, f) in subfields.iter_mut().enumerate() { + DataType::Struct(subfields) => { + for (i, f) in subfields.iter().enumerate() { let lance_field = self .children .iter_mut() @@ -577,9 +578,10 @@ impl TryFrom<&ArrowField> for Field { fn try_from(field: &ArrowField) -> Result { let children = match field.data_type() { - DataType::Struct(children) => { - children.iter().map(Self::try_from).collect::>()? - } + DataType::Struct(children) => children + .iter() + .map(|f| Self::try_from(f.as_ref())) + .collect::>()?, DataType::List(item) => vec![Self::try_from(item.as_ref())?], DataType::LargeList(item) => vec![Self::try_from(item.as_ref())?], _ => vec![], @@ -840,7 +842,7 @@ impl TryFrom<&ArrowSchema> for Schema { fields: schema .fields .iter() - .map(Field::try_from) + .map(|f| Field::try_from(f.as_ref())) .collect::>()?, metadata: schema.metadata.clone(), }; @@ -903,7 +905,7 @@ impl From<&Schema> for Vec { mod tests { use super::*; - use arrow_schema::{Field as ArrowField, TimeUnit}; + use arrow_schema::{Field as ArrowField, Fields as ArrowFields, TimeUnit}; #[test] fn arrow_field_to_field() { @@ -937,7 +939,7 @@ mod tests { ), ( "timestamp:s:America/New_York", - DataType::Timestamp(TimeUnit::Second, Some("America/New_York".to_string())), + DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())), ), ("time32:s", DataType::Time32(TimeUnit::Second)), ("time32:ms", DataType::Time32(TimeUnit::Millisecond)), @@ -951,7 +953,7 @@ mod tests { ( "fixed_size_list:int32:10", DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Int32, true)), + Arc::new(ArrowField::new("item", DataType::Int32, true)), 10, ), ), @@ -967,7 +969,7 @@ mod tests { #[test] fn test_nested_types() { assert_eq!( - LogicalType::try_from(&DataType::List(Box::new(ArrowField::new( + LogicalType::try_from(&DataType::List(Arc::new(ArrowField::new( "item", DataType::Binary, false @@ -977,9 +979,9 @@ mod tests { "list" ); assert_eq!( - LogicalType::try_from(&DataType::List(Box::new(ArrowField::new( + LogicalType::try_from(&DataType::List(Arc::new(ArrowField::new( "item", - DataType::Struct(vec![]), + DataType::Struct(ArrowFields::empty()), false )))) .unwrap() @@ -987,11 +989,11 @@ mod tests { "list.struct" ); assert_eq!( - LogicalType::try_from(&DataType::Struct(vec![ArrowField::new( + LogicalType::try_from(&DataType::Struct(ArrowFields::from(vec![ArrowField::new( "item", DataType::Binary, false - )])) + )]))) .unwrap() .0, "struct" @@ -1002,7 +1004,11 @@ mod tests { fn struct_field() { let arrow_field = ArrowField::new( "struct", - DataType::Struct(vec![ArrowField::new("a", DataType::Int32, true)]), + DataType::Struct(ArrowFields::from(vec![ArrowField::new( + "a", + DataType::Int32, + true, + )])), false, ); let field = Field::try_from(&arrow_field).unwrap(); @@ -1017,11 +1023,11 @@ mod tests { ArrowField::new("a", DataType::Int32, false), ArrowField::new( "b", - DataType::Struct(vec![ + DataType::Struct(ArrowFields::from(vec![ ArrowField::new("f1", DataType::Utf8, true), ArrowField::new("f2", DataType::Boolean, false), ArrowField::new("f3", DataType::Float32, false), - ]), + ])), true, ), ArrowField::new("c", DataType::Float64, false), @@ -1032,10 +1038,10 @@ mod tests { let expected_arrow_schema = ArrowSchema::new(vec![ ArrowField::new( "b", - DataType::Struct(vec![ + DataType::Struct(ArrowFields::from(vec![ ArrowField::new("f1", DataType::Utf8, true), ArrowField::new("f3", DataType::Float32, false), - ]), + ])), true, ), ArrowField::new("c", DataType::Float64, false), @@ -1049,11 +1055,11 @@ mod tests { ArrowField::new("a", DataType::Int32, false), ArrowField::new( "b", - DataType::Struct(vec![ + DataType::Struct(ArrowFields::from(vec![ ArrowField::new("f1", DataType::Utf8, true), ArrowField::new("f2", DataType::Boolean, false), ArrowField::new("f3", DataType::Float32, false), - ]), + ])), true, ), ArrowField::new("c", DataType::Float64, false), @@ -1064,10 +1070,10 @@ mod tests { let expected_arrow_schema = ArrowSchema::new(vec![ ArrowField::new( "b", - DataType::Struct(vec![ + DataType::Struct(ArrowFields::from(vec![ ArrowField::new("f1", DataType::Utf8, true), ArrowField::new("f3", DataType::Float32, false), - ]), + ])), true, ), ArrowField::new("c", DataType::Float64, false), @@ -1081,11 +1087,11 @@ mod tests { ArrowField::new("a", DataType::Int32, false), ArrowField::new( "b", - DataType::Struct(vec![ + DataType::Struct(ArrowFields::from(vec![ ArrowField::new("f1", DataType::Utf8, true), ArrowField::new("f2", DataType::Boolean, false), ArrowField::new("f3", DataType::Float32, false), - ]), + ])), true, ), ArrowField::new("c", DataType::Float64, false), @@ -1103,11 +1109,11 @@ mod tests { fn test_get_nested_field() { let arrow_schema = ArrowSchema::new(vec![ArrowField::new( "b", - DataType::Struct(vec![ + DataType::Struct(ArrowFields::from(vec![ ArrowField::new("f1", DataType::Utf8, true), ArrowField::new("f2", DataType::Boolean, false), ArrowField::new("f3", DataType::Float32, false), - ]), + ])), true, )]); let schema = Schema::try_from(&arrow_schema).unwrap(); @@ -1122,11 +1128,11 @@ mod tests { ArrowField::new("a", DataType::Int32, false), ArrowField::new( "b", - DataType::Struct(vec![ + DataType::Struct(ArrowFields::from(vec![ ArrowField::new("f1", DataType::Utf8, true), ArrowField::new("f2", DataType::Boolean, false), ArrowField::new("f3", DataType::Float32, false), - ]), + ])), true, ), ArrowField::new("c", DataType::Float64, false), @@ -1139,7 +1145,11 @@ mod tests { let expected_arrow_schema = ArrowSchema::new(vec![ ArrowField::new( "b", - DataType::Struct(vec![ArrowField::new("f1", DataType::Utf8, true)]), + DataType::Struct(ArrowFields::from(vec![ArrowField::new( + "f1", + DataType::Utf8, + true, + )])), true, ), ArrowField::new("c", DataType::Float64, false), diff --git a/rust/src/encodings/binary.rs b/rust/src/encodings/binary.rs index c68e949bfc2..fe831e37cdb 100644 --- a/rust/src/encodings/binary.rs +++ b/rust/src/encodings/binary.rs @@ -15,6 +15,7 @@ //! Var-length binary encoding. //! +use std::borrow::Borrow; use std::marker::PhantomData; use std::ops::{Range, RangeFrom, RangeFull, RangeTo}; use std::sync::Arc; @@ -73,7 +74,7 @@ impl<'a> BinaryEncoder<'a> { let end = offsets[offsets.len() - 1].as_usize(); let b = unsafe { std::slice::from_raw_parts( - arr.data().buffers()[1].as_ptr().offset(start as isize), + arr.to_data().buffers()[1].as_ptr().offset(start as isize), end - start, ) }; @@ -91,7 +92,7 @@ impl<'a> BinaryEncoder<'a> { let positions_offset = self.writer.tell(); let pos_array = pos_builder.finish(); self.writer - .write_all(pos_array.data().buffers()[0].as_slice()) + .write_all(pos_array.to_data().buffers()[0].as_slice()) .await?; Ok(positions_offset) } @@ -193,7 +194,7 @@ impl<'a, T: ByteArrayType> BinaryDecoder<'a, T> { let end = positions.value(range.end); let slice = positions.slice(range.start, range.len() + 1); - let position_slice: &Int64Array = as_primitive_array(slice.as_ref()); + let position_slice: &Int64Array = as_primitive_array(slice.borrow()); let offset_data = if T::Offset::IS_LARGE { subtract_scalar(position_slice, start)?.into_data() } else { @@ -391,7 +392,6 @@ mod tests { use super::*; use arrow_select::concat::concat; - use arrow_array::cast::as_string_array; use arrow_array::{ new_empty_array, types::GenericStringType, GenericStringArray, LargeStringArray, OffsetSizeTrait, StringArray, @@ -468,9 +468,8 @@ mod tests { #[tokio::test] async fn test_write_binary_data_with_offset() { - let slice = StringArray::from(vec![Some("d"), Some("e")]).slice(1, 1); - let array = as_string_array(slice.as_ref()); - test_round_trips(&[array]).await; + let array: StringArray = StringArray::from(vec![Some("d"), Some("e")]).slice(1, 1); + test_round_trips(&[&array]).await; } #[tokio::test] @@ -582,10 +581,7 @@ mod tests { let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap(); let mut encoder = BinaryEncoder::new(&mut object_writer); for i in 0..10 { - let pos = encoder - .encode(&[data.slice(i * 10, 10).as_ref()]) - .await - .unwrap(); + let pos = encoder.encode(&[&data.slice(i * 10, 10)]).await.unwrap(); assert_eq!(pos, (i * (8 * 11) /* offset array */ + (i + 1) * (10 * 10))); } } diff --git a/rust/src/encodings/plain.rs b/rust/src/encodings/plain.rs index 81e01077971..45ab86e0da8 100644 --- a/rust/src/encodings/plain.rs +++ b/rust/src/encodings/plain.rs @@ -91,7 +91,7 @@ impl<'a> PlainEncoder<'a> { let boolean_array = builder.finish(); self.writer - .write_all(boolean_array.data().buffers()[0].as_slice()) + .write_all(boolean_array.into_data().buffers()[0].as_slice()) .await?; Ok(()) } @@ -111,7 +111,7 @@ impl<'a> PlainEncoder<'a> { } else { let byte_width = data_type.byte_width(); for a in arrays.iter() { - let data = a.data(); + let data = a.to_data(); let slice = unsafe { from_raw_parts( data.buffers()[0].as_ptr().add(a.offset() * byte_width), @@ -552,7 +552,7 @@ mod tests { for t in int_types { let buffer = Buffer::from_slice_ref(input.as_slice()); let list_type = - DataType::FixedSizeList(Box::new(Field::new("item", t.clone(), true)), 3); + DataType::FixedSizeList(Arc::new(Field::new("item", t.clone(), true)), 3); let mut arrs: Vec = Vec::new(); for _ in 0..10 { @@ -580,8 +580,8 @@ mod tests { #[tokio::test] async fn test_encode_decode_nested_fixed_size_list() { // FixedSizeList of FixedSizeList - let inner = DataType::FixedSizeList(Box::new(Field::new("item", DataType::Int64, true)), 2); - let t = DataType::FixedSizeList(Box::new(Field::new("item", inner, true)), 2); + let inner = DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 2); + let t = DataType::FixedSizeList(Arc::new(Field::new("item", inner, true)), 2); let mut arrs: Vec = Vec::new(); for _ in 0..10 { @@ -595,7 +595,7 @@ mod tests { // FixedSizeList of FixedSizeBinary let inner = DataType::FixedSizeBinary(2); - let t = DataType::FixedSizeList(Box::new(Field::new("item", inner, true)), 2); + let t = DataType::FixedSizeList(Arc::new(Field::new("item", inner, true)), 2); let mut arrs: Vec = Vec::new(); for _ in 0..10 { @@ -722,7 +722,9 @@ mod tests { for i in (0..1000).step_by(4) { let data = array.slice(i, 4); file_writer - .write(&[&RecordBatch::try_new(arrow_schema.clone(), vec![data]).unwrap()]) + .write(&[ + &RecordBatch::try_new(arrow_schema.clone(), vec![Arc::new(data)]).unwrap(), + ]) .await .unwrap(); } @@ -776,14 +778,11 @@ mod tests { let mut file_writer = FileWriter::try_new(&store, &path, &schema).await.unwrap(); for i in (0..100).step_by(4) { - let data = fixed_size_list.slice(i, 4); - let slice: &FixedSizeListArray = as_fixed_size_list_array(data.as_ref()); + let slice: FixedSizeListArray = fixed_size_list.slice(i, 4); file_writer - .write(&[&RecordBatch::try_new( - arrow_schema.clone(), - vec![Arc::new(slice.clone())], - ) - .unwrap()]) + .write(&[ + &RecordBatch::try_new(arrow_schema.clone(), vec![Arc::new(slice)]).unwrap(), + ]) .await .unwrap(); } diff --git a/rust/src/format/fragment.rs b/rust/src/format/fragment.rs index 10d2c4dfb53..2e796a6b340 100644 --- a/rust/src/format/fragment.rs +++ b/rust/src/format/fragment.rs @@ -113,7 +113,9 @@ impl From<&Fragment> for pb::DataFragment { #[cfg(test)] mod tests { use super::*; - use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; + use arrow_schema::{ + DataType, Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema, + }; #[test] fn test_new_fragment() { @@ -122,10 +124,10 @@ mod tests { let arrow_schema = ArrowSchema::new(vec![ ArrowField::new( "s", - DataType::Struct(vec![ + DataType::Struct(ArrowFields::from(vec![ ArrowField::new("si", DataType::Int32, false), ArrowField::new("sb", DataType::Binary, true), - ]), + ])), true, ), ArrowField::new("bool", DataType::Boolean, true), diff --git a/rust/src/index/vector/diskann.rs b/rust/src/index/vector/diskann.rs index ceea6364581..6d57bb64fa9 100644 --- a/rust/src/index/vector/diskann.rs +++ b/rust/src/index/vector/diskann.rs @@ -132,7 +132,7 @@ mod tests { let schema = Arc::new(ArrowSchema::new(vec![Field::new( "embeddings", DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Float32, true)), + Arc::new(Field::new("item", DataType::Float32, true)), dimension, ), false, diff --git a/rust/src/index/vector/diskann/builder.rs b/rust/src/index/vector/diskann/builder.rs index a8fe147032f..bd64fa77f81 100644 --- a/rust/src/index/vector/diskann/builder.rs +++ b/rust/src/index/vector/diskann/builder.rs @@ -388,7 +388,7 @@ mod tests { let schema = Arc::new(ArrowSchema::new(vec![Field::new( "vector", DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Float32, true)), + Arc::new(Field::new("item", DataType::Float32, true)), dim as i32, ), true, diff --git a/rust/src/index/vector/graph/persisted.rs b/rust/src/index/vector/graph/persisted.rs index 1f4f6c54a06..adbd8c42bd9 100644 --- a/rust/src/index/vector/graph/persisted.rs +++ b/rust/src/index/vector/graph/persisted.rs @@ -213,7 +213,7 @@ pub(crate) async fn write_graph( ), Field::new( NEIGHBORS_COL, - DataType::List(Box::new(Field::new("item", DataType::UInt32, true))), + DataType::List(Arc::new(Field::new("item", DataType::UInt32, true))), false, ), ])); diff --git a/rust/src/index/vector/ivf.rs b/rust/src/index/vector/ivf.rs index 962b0b216ba..77ec1c600ad 100644 --- a/rust/src/index/vector/ivf.rs +++ b/rust/src/index/vector/ivf.rs @@ -347,7 +347,7 @@ impl Ivf { ArrowField::new( RESIDUAL_COLUMN, DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Float32, true)), + Arc::new(ArrowField::new("item", DataType::Float32, true)), dim as i32, ), false, @@ -589,7 +589,7 @@ pub async fn build_ivf_pq_index( ArrowField::new( PQ_CODE_COLUMN, DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::UInt8, true)), + Arc::new(ArrowField::new("item", DataType::UInt8, true)), pq_params.num_sub_vectors as i32, ), false, diff --git a/rust/src/index/vector/kmeans.rs b/rust/src/index/vector/kmeans.rs index 271140c438b..27b947246fc 100644 --- a/rust/src/index/vector/kmeans.rs +++ b/rust/src/index/vector/kmeans.rs @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::borrow::Borrow; use std::sync::Arc; use arrow_array::{ - builder::Float32Builder, cast::as_primitive_array, types::Float32Type, Array, Float32Array, + builder::Float32Builder, cast::as_primitive_array, types::Float32Type, Float32Array, }; use rand::{seq::IteratorRandom, Rng}; @@ -56,7 +57,7 @@ pub async fn train_kmeans( let mut builder = Float32Builder::with_capacity(sample_size * dimension); for idx in chosen.iter() { let s = array.slice(idx * dimension, dimension); - builder.append_slice(as_primitive_array::(s.as_ref()).values()); + builder.append_slice(as_primitive_array::(s.borrow()).values()); } builder.finish() } else { diff --git a/rust/src/index/vector/opq.rs b/rust/src/index/vector/opq.rs index 9b547f9b136..6719db0fbfc 100644 --- a/rust/src/index/vector/opq.rs +++ b/rust/src/index/vector/opq.rs @@ -351,7 +351,7 @@ mod tests { let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( "vector", DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Float32, true)), + Arc::new(ArrowField::new("item", DataType::Float32, true)), 64, ), true, @@ -434,10 +434,10 @@ mod tests { // R^T * R = I let i = r.transpose().dot(&r).unwrap(); - assert_relative_eq!( - i.data().values(), - MatrixView::identity(dim).data().values(), - epsilon = 0.001 - ); + let expected = i.data().values().to_vec(); + let result = MatrixView::identity(dim).data().values().to_vec(); + expected.iter().zip(result).for_each(|(&e, r)| { + assert_relative_eq!(e, r, epsilon = 0.001); + }); } } diff --git a/rust/src/index/vector/pq.rs b/rust/src/index/vector/pq.rs index e1e699e1835..ae3e7fe654c 100644 --- a/rust/src/index/vector/pq.rs +++ b/rust/src/index/vector/pq.rs @@ -137,8 +137,7 @@ impl PQIndex { let sub_vector_length = self.dimension / self.num_sub_vectors; for i in 0..self.num_sub_vectors { - let slice = key.slice(i * sub_vector_length, sub_vector_length); - let key_sub_vector: &Float32Array = as_primitive_array(slice.as_ref()); + let key_sub_vector: Float32Array = key.slice(i * sub_vector_length, sub_vector_length); let sub_vector_centroids = self.pq.centroids(i).ok_or_else(|| { Error::Index("PQIndex::cosine_scores: PQ is not initialized".to_string()) })?; @@ -501,12 +500,11 @@ impl ProductQuantizer { let code: &UInt8Array = as_primitive_array(code_arr.as_ref()); for sub_vec_id in 0..code.len() { let centroid = code.value(sub_vec_id) as usize; - let sub_vector = data.data().slice( + let sub_vector: Float32Array = data.data().slice( i * self.dimension + sub_vec_id * sub_vector_dim, sub_vector_dim, ); counts[sub_vec_id * num_centroids + centroid] += 1; - let sub_vector: &Float32Array = as_primitive_array(sub_vector.as_ref()); for k in 0..sub_vector.len() { sum[sub_vec_id * sum_stride + centroid * sub_vector_dim + k] += sub_vector.value(k); @@ -637,7 +635,7 @@ pub(crate) async fn train_pq( mod tests { use super::*; - use approx::assert_relative_eq; + use approx::relative_eq; use arrow_array::types::Float32Type; #[test] @@ -696,10 +694,14 @@ mod tests { actual_pq.train(&mat, MetricType::L2, 1).await.unwrap(); } - assert_relative_eq!( - pq.codebook.unwrap().values(), - actual_pq.codebook.unwrap().values(), - epsilon = 0.01 - ); + let result = pq.codebook.unwrap(); + let expected = actual_pq.codebook.unwrap(); + result + .values() + .iter() + .zip(expected.values()) + .for_each(|(&r, &e)| { + assert!(relative_eq!(r, e)); + }); } } diff --git a/rust/src/io/exec/knn.rs b/rust/src/io/exec/knn.rs index 22fa19244d7..a6d3d7fd33f 100644 --- a/rust/src/io/exec/knn.rs +++ b/rust/src/io/exec/knn.rs @@ -159,7 +159,7 @@ impl ExecutionPlan for KNNFlatExec { let input_schema = self.input.schema(); let mut fields = input_schema.fields().to_vec(); if !input_schema.field_with_name(SCORE_COL).is_ok() { - fields.push(Field::new(SCORE_COL, DataType::Float32, false)); + fields.push(Arc::new(Field::new(SCORE_COL, DataType::Float32, false))); } Arc::new(Schema::new_with_metadata( @@ -392,7 +392,7 @@ mod tests { ArrowField::new( "vector", DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Float32, true)), + Arc::new(ArrowField::new("item", DataType::Float32, true)), 128, ), true, @@ -476,7 +476,7 @@ mod tests { ArrowField::new( "vector", DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Float32, true)), + Arc::new(ArrowField::new("item", DataType::Float32, true)), dim as i32, ), true, @@ -504,7 +504,7 @@ mod tests { ArrowField::new( "vector", DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Float32, true)), + Arc::new(ArrowField::new("item", DataType::Float32, true)), dim as i32, ), true, diff --git a/rust/src/io/exec/planner.rs b/rust/src/io/exec/planner.rs index 7a84d4ac8a3..31167116057 100644 --- a/rust/src/io/exec/planner.rs +++ b/rust/src/io/exec/planner.rs @@ -254,7 +254,7 @@ mod tests { use arrow_array::{ ArrayRef, BooleanArray, Float32Array, Int32Array, RecordBatch, StringArray, StructArray, }; - use arrow_schema::{DataType, Field, Schema}; + use arrow_schema::{DataType, Field, Fields, Schema}; use datafusion::logical_expr::{col, lit}; #[test] @@ -264,10 +264,10 @@ mod tests { Field::new("s", DataType::Utf8, true), Field::new( "st", - DataType::Struct(vec![ + DataType::Struct(Fields::from(vec![ Field::new("x", DataType::Float32, false), Field::new("y", DataType::Float32, false), - ]), + ])), true, ), ])); diff --git a/rust/src/io/exec/scan.rs b/rust/src/io/exec/scan.rs index c56751aa8fe..49d8cdfee97 100644 --- a/rust/src/io/exec/scan.rs +++ b/rust/src/io/exec/scan.rs @@ -203,8 +203,8 @@ impl ExecutionPlan for LanceScanExec { fn schema(&self) -> SchemaRef { let schema: ArrowSchema = self.projection.as_ref().into(); if self.with_row_id { - let mut fields = schema.fields; - fields.push(Field::new(ROW_ID, DataType::UInt64, false)); + let mut fields: Vec> = schema.fields.to_vec(); + fields.push(Arc::new(Field::new(ROW_ID, DataType::UInt64, false))); Arc::new(ArrowSchema::new(fields)) } else { Arc::new(schema) diff --git a/rust/src/io/reader.rs b/rust/src/io/reader.rs index 3af643d930d..86423cba8e8 100644 --- a/rust/src/io/reader.rs +++ b/rust/src/io/reader.rs @@ -574,7 +574,7 @@ mod tests { Array, DictionaryArray, Float32Array, Int64Array, LargeListArray, ListArray, NullArray, RecordBatchReader, StringArray, StructArray, UInt32Array, UInt8Array, }; - use arrow_schema::{Field as ArrowField, Schema as ArrowSchema}; + use arrow_schema::{Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema}; use tempfile::tempdir; use crate::io::FileWriter; @@ -702,7 +702,11 @@ mod tests { async fn test_write_null_string_in_struct(field_nullable: bool) { let arrow_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( "parent", - DataType::Struct(vec![ArrowField::new("str", DataType::Utf8, field_nullable)]), + DataType::Struct(ArrowFields::from(vec![ArrowField::new( + "str", + DataType::Utf8, + field_nullable, + )])), true, )])); @@ -795,14 +799,14 @@ mod tests { DataType::Struct(subfields) => subfields .iter() .zip(expected_columns) - .map(|(f, d)| (f.clone(), d)) + .map(|(f, d)| (f.as_ref().clone(), d)) .collect::>(), _ => panic!("unexpected field"), }; let expected_struct_array = StructArray::from(expected_struct); let expected_batch = RecordBatch::from(&StructArray::from(vec![( - arrow_schema.fields[0].clone(), + arrow_schema.fields[0].as_ref().clone(), Arc::new(expected_struct_array) as ArrayRef, )])); @@ -815,23 +819,23 @@ mod tests { fn make_schema_of_list_array() -> Arc { Arc::new(ArrowSchema::new(vec![ArrowField::new( "s", - DataType::Struct(vec![ + DataType::Struct(ArrowFields::from(vec![ ArrowField::new( "li", - DataType::List(Box::new(ArrowField::new("item", DataType::Int32, true))), + DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))), true, ), ArrowField::new( "ls", - DataType::List(Box::new(ArrowField::new("item", DataType::Utf8, true))), + DataType::List(Arc::new(ArrowField::new("item", DataType::Utf8, true))), true, ), ArrowField::new( "ll", - DataType::LargeList(Box::new(ArrowField::new("item", DataType::Int32, true))), + DataType::LargeList(Arc::new(ArrowField::new("item", DataType::Int32, true))), false, ), - ]), + ])), true, )])) } @@ -857,7 +861,7 @@ mod tests { ( ArrowField::new( "li", - DataType::List(Box::new(ArrowField::new("item", DataType::Int32, true))), + DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))), true, ), Arc::new(li_builder.finish()) as ArrayRef, @@ -865,7 +869,7 @@ mod tests { ( ArrowField::new( "ls", - DataType::List(Box::new(ArrowField::new("item", DataType::Utf8, true))), + DataType::List(Arc::new(ArrowField::new("item", DataType::Utf8, true))), true, ), Arc::new(ls_builder.finish()) as ArrayRef, @@ -873,7 +877,7 @@ mod tests { ( ArrowField::new( "ll", - DataType::LargeList(Box::new(ArrowField::new("item", DataType::Int32, true))), + DataType::LargeList(Arc::new(ArrowField::new("item", DataType::Int32, true))), false, ), Arc::new(large_list_builder.finish()) as ArrayRef, @@ -887,11 +891,11 @@ mod tests { let arrow_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( "s", - DataType::Struct(vec![ArrowField::new( + DataType::Struct(ArrowFields::from(vec![ArrowField::new( "d", DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), true, - )]), + )])), true, )])); @@ -1022,12 +1026,12 @@ mod tests { let arrow_schema = ArrowSchema::new(vec![ ArrowField::new( "l", - DataType::List(Box::new(ArrowField::new("item", DataType::Int32, true))), + DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))), false, ), ArrowField::new( "ll", - DataType::LargeList(Box::new(ArrowField::new("item", DataType::Int32, true))), + DataType::LargeList(Arc::new(ArrowField::new("item", DataType::Int32, true))), false, ), ]); @@ -1092,12 +1096,12 @@ mod tests { let arrow_schema = ArrowSchema::new(vec![ ArrowField::new( "l", - DataType::List(Box::new(ArrowField::new("item", DataType::Int32, true))), + DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))), false, ), ArrowField::new( "ll", - DataType::LargeList(Box::new(ArrowField::new("item", DataType::Int32, true))), + DataType::LargeList(Arc::new(ArrowField::new("item", DataType::Int32, true))), false, ), ]); diff --git a/rust/src/io/writer.rs b/rust/src/io/writer.rs index d1c26cacbcb..004b23ff358 100644 --- a/rust/src/io/writer.rs +++ b/rust/src/io/writer.rs @@ -372,7 +372,9 @@ mod tests { TimestampSecondArray, UInt8Array, }; use arrow_buffer::i256; - use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, TimeUnit}; + use arrow_schema::{ + DataType, Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema, TimeUnit, + }; use object_store::path::Path; use crate::io::{FileReader, ObjectStore}; @@ -411,7 +413,7 @@ mod tests { ArrowField::new( "fixed_size_list", DataType::FixedSizeList( - Box::new(ArrowField::new("item", DataType::Float32, true)), + Arc::new(ArrowField::new("item", DataType::Float32, true)), 16, ), true, @@ -419,17 +421,17 @@ mod tests { ArrowField::new("fixed_size_binary", DataType::FixedSizeBinary(8), true), ArrowField::new( "l", - DataType::List(Box::new(ArrowField::new("item", DataType::Utf8, true))), + DataType::List(Arc::new(ArrowField::new("item", DataType::Utf8, true))), true, ), ArrowField::new( "large_l", - DataType::LargeList(Box::new(ArrowField::new("item", DataType::Utf8, true))), + DataType::LargeList(Arc::new(ArrowField::new("item", DataType::Utf8, true))), true, ), ArrowField::new( "l_dict", - DataType::List(Box::new(ArrowField::new( + DataType::List(Arc::new(ArrowField::new( "item", DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), true, @@ -438,7 +440,7 @@ mod tests { ), ArrowField::new( "large_l_dict", - DataType::LargeList(Box::new(ArrowField::new( + DataType::LargeList(Arc::new(ArrowField::new( "item", DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), true, @@ -447,10 +449,10 @@ mod tests { ), ArrowField::new( "s", - DataType::Struct(vec![ + DataType::Struct(ArrowFields::from(vec![ ArrowField::new("si", DataType::Int64, true), ArrowField::new("sb", DataType::Utf8, true), - ]), + ])), true, ), ]); @@ -609,10 +611,7 @@ mod tests { ), ArrowField::new( "ts_tz", - DataType::Timestamp( - TimeUnit::Microsecond, - Some("America/Los_Angeles".to_string()), - ), + DataType::Timestamp(TimeUnit::Microsecond, Some("America/Los_Angeles".into())), false, ), ])); diff --git a/rust/src/utils/kmeans.rs b/rust/src/utils/kmeans.rs index 231308f858b..1a0957be408 100644 --- a/rust/src/utils/kmeans.rs +++ b/rust/src/utils/kmeans.rs @@ -102,8 +102,8 @@ async fn kmean_plusplus( assert!(data.len() > k * dimension); let mut kmeans = KMeans::empty(k, dimension, metric_type); let first_idx = rng.gen_range(0..data.len() / dimension); - let first_vector = data.slice(first_idx * dimension, dimension); - kmeans.centroids = Arc::new(as_primitive_array(first_vector.as_ref()).clone()); + let first_vector: Float32Array = data.slice(first_idx * dimension, dimension); + kmeans.centroids = Arc::new(first_vector); let mut seen = HashSet::new(); seen.insert(first_idx); @@ -120,8 +120,7 @@ async fn kmean_plusplus( } } - let slice = data.slice(chosen * dimension, dimension); - let new_vector: &Float32Array = as_primitive_array(slice.as_ref()); + let new_vector: Float32Array = data.slice(chosen * dimension, dimension); let new_centroid_values = Float32Array::from_iter_values( kmeans @@ -208,7 +207,7 @@ impl KMeanMembership { for i in 0..cluster_ids.len() { if cluster_ids[i] as usize == cluster { sum = - add(&sum, as_primitive_array(data.slice(i * dimension, dimension).as_ref())).unwrap(); + add(&sum, &data.slice(i * dimension, dimension)).unwrap(); total += 1.0; }; } @@ -216,8 +215,7 @@ impl KMeanMembership { divide_scalar(&sum, total).unwrap() } else { eprintln!("Warning: KMean: cluster {cluster} has no value, does not change centroids."); - let prev_centroids = prev_centroids.slice(cluster * dimension, dimension); - as_primitive_array(prev_centroids.as_ref()).clone() + prev_centroids.slice(cluster * dimension, dimension) } }) .await