From 76d2a39051a797892371eccd8723406efa7e6399 Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Mon, 13 Jun 2022 07:28:39 +0200 Subject: [PATCH] Added support to read and write f16 (#1051) --- README.md | 1 - src/array/mod.rs | 3 +- src/array/primitive/mod.rs | 6 +- src/compute/arithmetics/mod.rs | 2 +- src/compute/cast/mod.rs | 7 ++ src/compute/cast/primitive_to.rs | 7 +- src/compute/comparison/mod.rs | 7 +- src/compute/comparison/simd/native.rs | 4 +- src/compute/comparison/simd/packed.rs | 4 +- src/datatypes/mod.rs | 3 +- src/io/json_integration/read/array.rs | 1 + src/types/mod.rs | 3 + src/types/native.rs | 110 ++++++++++++++++++++++++++ src/types/simd/mod.rs | 4 +- src/types/simd/native.rs | 1 + tests/it/compute/cast.rs | 2 + tests/it/compute/comparison.rs | 1 + 17 files changed, 155 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 35b42244944..07ebc960de5 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,6 @@ documentation of each of its APIs. ## Features * Most feature-complete implementation of Apache Arrow after the reference implementation (C++) - * Float 16 unsupported (not a Rust native type) * Decimal 256 unsupported (not a Rust native type) * C data interface supported for all Arrow types (read and write) * C stream interface supported for all Arrow types (read and write) diff --git a/src/array/mod.rs b/src/array/mod.rs index e004bb0e778..427dae7f651 100644 --- a/src/array/mod.rs +++ b/src/array/mod.rs @@ -203,7 +203,7 @@ macro_rules! with_match_primitive_type {( ) => ({ macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} use crate::datatypes::PrimitiveType::*; - use crate::types::{days_ms, months_days_ns}; + use crate::types::{days_ms, months_days_ns, f16}; match $key_type { Int8 => __with_ty__! { i8 }, Int16 => __with_ty__! { i16 }, @@ -216,6 +216,7 @@ macro_rules! with_match_primitive_type {( UInt16 => __with_ty__! { u16 }, UInt32 => __with_ty__! { u32 }, UInt64 => __with_ty__! { u64 }, + Float16 => __with_ty__! { f16 }, Float32 => __with_ty__! { f32 }, Float64 => __with_ty__! { f64 }, } diff --git a/src/array/primitive/mod.rs b/src/array/primitive/mod.rs index b5eee572511..de55b6bb284 100644 --- a/src/array/primitive/mod.rs +++ b/src/array/primitive/mod.rs @@ -7,7 +7,7 @@ use crate::{ datatypes::*, error::Error, trusted_len::TrustedLen, - types::{days_ms, months_days_ns, NativeType}, + types::{days_ms, f16, months_days_ns, NativeType}, }; use super::Array; @@ -468,6 +468,8 @@ pub type Int128Array = PrimitiveArray; pub type DaysMsArray = PrimitiveArray; /// A type definition [`PrimitiveArray`] for [`months_days_ns`] pub type MonthsDaysNsArray = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `f16` +pub type Float16Array = PrimitiveArray; /// A type definition [`PrimitiveArray`] for `f32` pub type Float32Array = PrimitiveArray; /// A type definition [`PrimitiveArray`] for `f64` @@ -495,6 +497,8 @@ pub type Int128Vec = MutablePrimitiveArray; pub type DaysMsVec = MutablePrimitiveArray; /// A type definition [`MutablePrimitiveArray`] for [`months_days_ns`] pub type MonthsDaysNsVec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `f16` +pub type Float16Vec = MutablePrimitiveArray; /// A type definition [`MutablePrimitiveArray`] for `f32` pub type Float32Vec = MutablePrimitiveArray; /// A type definition [`MutablePrimitiveArray`] for `f64` diff --git a/src/compute/arithmetics/mod.rs b/src/compute/arithmetics/mod.rs index 5766385d8a4..2b6bb7e6c0c 100644 --- a/src/compute/arithmetics/mod.rs +++ b/src/compute/arithmetics/mod.rs @@ -414,7 +414,7 @@ macro_rules! with_match_negatable {( Int128 => __with_ty__! { i128 }, DaysMs => __with_ty__! { days_ms }, MonthDayNano => __with_ty__! { months_days_ns }, - UInt8 | UInt16 | UInt32 | UInt64=> todo!(), + UInt8 | UInt16 | UInt32 | UInt64 | Float16 => todo!(), Float32 => __with_ty__! { f32 }, Float64 => __with_ty__! { f64 }, } diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs index de3fcff43d8..d3c37cd73fb 100644 --- a/src/compute/cast/mod.rs +++ b/src/compute/cast/mod.rs @@ -215,6 +215,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Int64, Float64) => true, (Int64, Decimal(_, _)) => true, + (Float16, Float32) => true, + (Float32, UInt8) => true, (Float32, UInt16) => true, (Float32, UInt32) => true, @@ -736,6 +738,11 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Int64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int64, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + (Float16, Float32) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(f16_to_f32(from).boxed()) + } + (Float32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Float32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), (Float32, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), diff --git a/src/compute/cast/primitive_to.rs b/src/compute/cast/primitive_to.rs index d064ab6d9b9..9fdbe2f4ebe 100644 --- a/src/compute/cast/primitive_to.rs +++ b/src/compute/cast/primitive_to.rs @@ -4,7 +4,7 @@ use num_traits::{AsPrimitive, Float, ToPrimitive}; use crate::datatypes::IntervalUnit; use crate::error::Result; -use crate::types::{days_ms, months_days_ns}; +use crate::types::{days_ms, f16, months_days_ns}; use crate::{ array::*, bitmap::Bitmap, @@ -581,3 +581,8 @@ pub fn months_to_months_days_ns(from: &PrimitiveArray) -> PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x.as_f32(), DataType::Float32) +} diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index 9ed47c16a79..8f67856b90e 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -81,6 +81,7 @@ macro_rules! match_eq_ord {( UInt16 => __with_ty__! { u16 }, UInt32 => __with_ty__! { u32 }, UInt64 => __with_ty__! { u64 }, + Float16 => todo!(), Float32 => __with_ty__! { f32 }, Float64 => __with_ty__! { f64 }, } @@ -91,7 +92,7 @@ macro_rules! match_eq {( ) => ({ macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} use crate::datatypes::PrimitiveType::*; - use crate::types::{days_ms, months_days_ns}; + use crate::types::{days_ms, months_days_ns, f16}; match $key_type { Int8 => __with_ty__! { i8 }, Int16 => __with_ty__! { i16 }, @@ -104,6 +105,7 @@ macro_rules! match_eq {( UInt16 => __with_ty__! { u16 }, UInt32 => __with_ty__! { u32 }, UInt64 => __with_ty__! { u64 }, + Float16 => __with_ty__! { f16 }, Float32 => __with_ty__! { f32 }, Float64 => __with_ty__! { f64 }, } @@ -487,7 +489,8 @@ fn can_partial_eq(data_type: &DataType) -> bool { can_partial_eq_and_ord(data_type) || matches!( data_type.to_logical_type(), - DataType::Interval(IntervalUnit::DayTime) + DataType::Float16 + | DataType::Interval(IntervalUnit::DayTime) | DataType::Interval(IntervalUnit::MonthDayNano) ) } diff --git a/src/compute/comparison/simd/native.rs b/src/compute/comparison/simd/native.rs index a4a760bf1e9..b760620a602 100644 --- a/src/compute/comparison/simd/native.rs +++ b/src/compute/comparison/simd/native.rs @@ -1,7 +1,7 @@ use std::convert::TryInto; use super::{set, Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; -use crate::types::{days_ms, months_days_ns}; +use crate::types::{days_ms, f16, months_days_ns}; simd8_native_all!(u8); simd8_native_all!(u16); @@ -12,6 +12,8 @@ simd8_native_all!(i16); simd8_native_all!(i32); simd8_native_all!(i128); simd8_native_all!(i64); +simd8_native!(f16); +simd8_native_partial_eq!(f16); simd8_native_all!(f32); simd8_native_all!(f64); simd8_native!(days_ms); diff --git a/src/compute/comparison/simd/packed.rs b/src/compute/comparison/simd/packed.rs index 9348cbed484..ea4684d1dff 100644 --- a/src/compute/comparison/simd/packed.rs +++ b/src/compute/comparison/simd/packed.rs @@ -2,7 +2,7 @@ use std::convert::TryInto; use std::simd::ToBitMask; use crate::types::simd::*; -use crate::types::{days_ms, months_days_ns}; +use crate::types::{days_ms, f16, months_days_ns}; use super::*; @@ -71,6 +71,8 @@ simd8!(i16, i16x8); simd8!(i32, i32x8); simd8!(i64, i64x8); simd8_native_all!(i128); +simd8_native!(f16); +simd8_native_partial_eq!(f16); simd8!(f32, f32x8); simd8!(f64, f64x8); simd8_native!(days_ms); diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs index e7e2c9b98e8..4be369da0a2 100644 --- a/src/datatypes/mod.rs +++ b/src/datatypes/mod.rs @@ -237,7 +237,7 @@ impl DataType { UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16), UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32), UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64), - Float16 => unreachable!(), + Float16 => PhysicalType::Primitive(PrimitiveType::Float16), Float32 => PhysicalType::Primitive(PrimitiveType::Float32), Float64 => PhysicalType::Primitive(PrimitiveType::Float64), Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs), @@ -299,6 +299,7 @@ impl From for DataType { PrimitiveType::UInt32 => DataType::UInt32, PrimitiveType::UInt64 => DataType::UInt64, PrimitiveType::Int128 => DataType::Decimal(32, 32), + PrimitiveType::Float16 => DataType::Float16, PrimitiveType::Float32 => DataType::Float32, PrimitiveType::Float64 => DataType::Float64, PrimitiveType::DaysMs => DataType::Interval(IntervalUnit::DayTime), diff --git a/src/io/json_integration/read/array.rs b/src/io/json_integration/read/array.rs index c0cc688b828..e434cbe6918 100644 --- a/src/io/json_integration/read/array.rs +++ b/src/io/json_integration/read/array.rs @@ -289,6 +289,7 @@ pub fn to_array( Primitive(PrimitiveType::UInt16) => Ok(Box::new(to_primitive::(json_col, data_type))), Primitive(PrimitiveType::UInt32) => Ok(Box::new(to_primitive::(json_col, data_type))), Primitive(PrimitiveType::UInt64) => Ok(Box::new(to_primitive::(json_col, data_type))), + Primitive(PrimitiveType::Float16) => todo!(), Primitive(PrimitiveType::Float32) => Ok(Box::new(to_primitive::(json_col, data_type))), Primitive(PrimitiveType::Float64) => Ok(Box::new(to_primitive::(json_col, data_type))), Binary => Ok(to_binary::(json_col, data_type)), diff --git a/src/types/mod.rs b/src/types/mod.rs index 93e5b0667ce..3ba86ad59ec 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -55,6 +55,8 @@ pub enum PrimitiveType { UInt32, /// An unsigned 64-bit integer. UInt64, + /// A 16-bit floating point number. + Float16, /// A 32-bit floating point number. Float32, /// A 64-bit floating point number. @@ -77,6 +79,7 @@ mod private { impl Sealed for i32 {} impl Sealed for i64 {} impl Sealed for i128 {} + impl Sealed for super::f16 {} impl Sealed for f32 {} impl Sealed for f64 {} impl Sealed for super::days_ms {} diff --git a/src/types/native.rs b/src/types/native.rs index 5a3884275ed..fbc8ac408b3 100644 --- a/src/types/native.rs +++ b/src/types/native.rs @@ -325,3 +325,113 @@ impl Neg for months_days_ns { Self::new(-self.months(), -self.days(), -self.ns()) } } + +/// Type representation of the Float16 physical type +#[derive(Copy, Clone, Default, Zeroable, Pod)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct f16(pub u16); + +impl PartialEq for f16 { + #[inline] + fn eq(&self, other: &f16) -> bool { + if self.is_nan() || other.is_nan() { + false + } else { + (self.0 == other.0) || ((self.0 | other.0) & 0x7FFFu16 == 0) + } + } +} + +// see https://github.com/starkat99/half-rs/blob/main/src/binary16.rs +impl f16 { + #[inline] + #[must_use] + pub(crate) const fn is_nan(self) -> bool { + self.0 & 0x7FFFu16 > 0x7C00u16 + } + + /// Casts this `f16` to `f32` + #[inline] + pub fn as_f32(self) -> f32 { + let i = self.0; + // Check for signed zero + if i & 0x7FFFu16 == 0 { + return f32::from_bits((i as u32) << 16); + } + + let half_sign = (i & 0x8000u16) as u32; + let half_exp = (i & 0x7C00u16) as u32; + let half_man = (i & 0x03FFu16) as u32; + + // Check for an infinity or NaN when all exponent bits set + if half_exp == 0x7C00u32 { + // Check for signed infinity if mantissa is zero + if half_man == 0 { + let number = (half_sign << 16) | 0x7F80_0000u32; + return f32::from_bits(number); + } else { + // NaN, keep current mantissa but also set most significiant mantissa bit + let number = (half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13); + return f32::from_bits(number); + } + } + + // Calculate single-precision components with adjusted exponent + let sign = half_sign << 16; + // Unbias exponent + let unbiased_exp = ((half_exp as i32) >> 10) - 15; + + // Check for subnormals, which will be normalized by adjusting exponent + if half_exp == 0 { + // Calculate how much to adjust the exponent by + let e = (half_man as u16).leading_zeros() - 6; + + // Rebias and adjust exponent + let exp = (127 - 15 - e) << 23; + let man = (half_man << (14 + e)) & 0x7F_FF_FFu32; + return f32::from_bits(sign | exp | man); + } + + // Rebias exponent for a normalized normal + let exp = ((unbiased_exp + 127) as u32) << 23; + let man = (half_man & 0x03FFu32) << 13; + f32::from_bits(sign | exp | man) + } +} + +impl std::fmt::Debug for f16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.as_f32()) + } +} + +impl std::fmt::Display for f16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_f32()) + } +} + +impl NativeType for f16 { + const PRIMITIVE: PrimitiveType = PrimitiveType::Float16; + type Bytes = [u8; 2]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + self.0.to_le_bytes() + } + + #[inline] + fn to_ne_bytes(&self) -> Self::Bytes { + self.0.to_ne_bytes() + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + self.0.to_be_bytes() + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + Self(u16::from_be_bytes(bytes)) + } +} diff --git a/src/types/simd/mod.rs b/src/types/simd/mod.rs index d12e61bcd50..3e97902dcd6 100644 --- a/src/types/simd/mod.rs +++ b/src/types/simd/mod.rs @@ -1,7 +1,7 @@ //! Contains traits and implementations of multi-data used in SIMD. //! The actual representation is driven by the feature flag `"simd"`, which, if set, //! uses [`std::simd`]. -use super::{days_ms, months_days_ns}; +use super::{days_ms, f16, months_days_ns}; use super::{BitChunk, BitChunkIter, NativeType}; /// Describes the ability to convert itself from a [`BitChunk`]. @@ -129,6 +129,7 @@ pub(super) use native_simd; // Types do not have specific intrinsics and thus SIMD can't be specialized. // Therefore, we can declare their MD representation as `[$t; 8]` irrespectively // of how they are represented in the different channels. +native_simd!(f16x32, f16, 32, u32); native_simd!(days_msx8, days_ms, 8, u8); native_simd!(months_days_nsx8, months_days_ns, 8, u8); native_simd!(i128x8, i128, 8, u8); @@ -157,6 +158,7 @@ native!(i8, i8x64); native!(i16, i16x32); native!(i32, i32x16); native!(i64, i64x8); +native!(f16, f16x32); native!(f32, f32x16); native!(f64, f64x8); native!(i128, i128x8); diff --git a/src/types/simd/native.rs b/src/types/simd/native.rs index e06d4d7548a..e393c846d8f 100644 --- a/src/types/simd/native.rs +++ b/src/types/simd/native.rs @@ -11,5 +11,6 @@ native_simd!(i8x64, i8, 64, u64); native_simd!(i16x32, i16, 32, u32); native_simd!(i32x16, i32, 16, u16); native_simd!(i64x8, i64, 8, u8); +native_simd!(f16x32, f16, 32, u32); native_simd!(f32x16, f32, 16, u16); native_simd!(f64x8, f64, 8, u8); diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs index fedc7a47dc6..79ce42c9904 100644 --- a/tests/it/compute/cast.rs +++ b/tests/it/compute/cast.rs @@ -446,6 +446,7 @@ fn consistency() { Int16, Int32, Int64, + Float16, Float32, Float64, Timestamp(TimeUnit::Second, None), @@ -778,6 +779,7 @@ fn null_array_from_and_to_others() { typed_test!(UInt32Array, UInt32); typed_test!(UInt64Array, UInt64); + typed_test!(Float16Array, Float16); typed_test!(Float32Array, Float32); typed_test!(Float64Array, Float64); } diff --git a/tests/it/compute/comparison.rs b/tests/it/compute/comparison.rs index 1e887d78062..9934781dd8f 100644 --- a/tests/it/compute/comparison.rs +++ b/tests/it/compute/comparison.rs @@ -18,6 +18,7 @@ fn consistency() { Int16, Int32, Int64, + Float16, Float32, Float64, Interval(IntervalUnit::YearMonth),