Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added support to read and write f16 (#1051)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Jun 13, 2022
1 parent 698c7b8 commit 76d2a39
Show file tree
Hide file tree
Showing 17 changed files with 155 additions and 11 deletions.
1 change: 0 additions & 1 deletion README.md
Expand Up @@ -18,7 +18,6 @@ documentation of each of its APIs.
## Features

* Most feature-complete implementation of Apache Arrow after the reference implementation (C++)
* Float 16 unsupported (not a Rust native type)
* Decimal 256 unsupported (not a Rust native type)
* C data interface supported for all Arrow types (read and write)
* C stream interface supported for all Arrow types (read and write)
Expand Down
3 changes: 2 additions & 1 deletion src/array/mod.rs
Expand Up @@ -203,7 +203,7 @@ macro_rules! with_match_primitive_type {(
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
use crate::types::{days_ms, months_days_ns};
use crate::types::{days_ms, months_days_ns, f16};
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Expand All @@ -216,6 +216,7 @@ macro_rules! with_match_primitive_type {(
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float16 => __with_ty__! { f16 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
Expand Down
6 changes: 5 additions & 1 deletion src/array/primitive/mod.rs
Expand Up @@ -7,7 +7,7 @@ use crate::{
datatypes::*,
error::Error,
trusted_len::TrustedLen,
types::{days_ms, months_days_ns, NativeType},
types::{days_ms, f16, months_days_ns, NativeType},
};

use super::Array;
Expand Down Expand Up @@ -468,6 +468,8 @@ pub type Int128Array = PrimitiveArray<i128>;
pub type DaysMsArray = PrimitiveArray<days_ms>;
/// A type definition [`PrimitiveArray`] for [`months_days_ns`]
pub type MonthsDaysNsArray = PrimitiveArray<months_days_ns>;
/// A type definition [`PrimitiveArray`] for `f16`
pub type Float16Array = PrimitiveArray<f16>;
/// A type definition [`PrimitiveArray`] for `f32`
pub type Float32Array = PrimitiveArray<f32>;
/// A type definition [`PrimitiveArray`] for `f64`
Expand Down Expand Up @@ -495,6 +497,8 @@ pub type Int128Vec = MutablePrimitiveArray<i128>;
pub type DaysMsVec = MutablePrimitiveArray<days_ms>;
/// A type definition [`MutablePrimitiveArray`] for [`months_days_ns`]
pub type MonthsDaysNsVec = MutablePrimitiveArray<months_days_ns>;
/// A type definition [`MutablePrimitiveArray`] for `f16`
pub type Float16Vec = MutablePrimitiveArray<f16>;
/// A type definition [`MutablePrimitiveArray`] for `f32`
pub type Float32Vec = MutablePrimitiveArray<f32>;
/// A type definition [`MutablePrimitiveArray`] for `f64`
Expand Down
2 changes: 1 addition & 1 deletion src/compute/arithmetics/mod.rs
Expand Up @@ -414,7 +414,7 @@ macro_rules! with_match_negatable {(
Int128 => __with_ty__! { i128 },
DaysMs => __with_ty__! { days_ms },
MonthDayNano => __with_ty__! { months_days_ns },
UInt8 | UInt16 | UInt32 | UInt64=> todo!(),
UInt8 | UInt16 | UInt32 | UInt64 | Float16 => todo!(),
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
Expand Down
7 changes: 7 additions & 0 deletions src/compute/cast/mod.rs
Expand Up @@ -215,6 +215,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int64, Float64) => true,
(Int64, Decimal(_, _)) => true,

(Float16, Float32) => true,

(Float32, UInt8) => true,
(Float32, UInt16) => true,
(Float32, UInt32) => true,
Expand Down Expand Up @@ -736,6 +738,11 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Int64, Float64) => primitive_to_primitive_dyn::<i64, f64>(array, to_type, as_options),
(Int64, Decimal(p, s)) => integer_to_decimal_dyn::<i64>(array, *p, *s),

(Float16, Float32) => {
let from = array.as_any().downcast_ref().unwrap();
Ok(f16_to_f32(from).boxed())
}

(Float32, UInt8) => primitive_to_primitive_dyn::<f32, u8>(array, to_type, options),
(Float32, UInt16) => primitive_to_primitive_dyn::<f32, u16>(array, to_type, options),
(Float32, UInt32) => primitive_to_primitive_dyn::<f32, u32>(array, to_type, options),
Expand Down
7 changes: 6 additions & 1 deletion src/compute/cast/primitive_to.rs
Expand Up @@ -4,7 +4,7 @@ use num_traits::{AsPrimitive, Float, ToPrimitive};

use crate::datatypes::IntervalUnit;
use crate::error::Result;
use crate::types::{days_ms, months_days_ns};
use crate::types::{days_ms, f16, months_days_ns};
use crate::{
array::*,
bitmap::Bitmap,
Expand Down Expand Up @@ -581,3 +581,8 @@ pub fn months_to_months_days_ns(from: &PrimitiveArray<i32>) -> PrimitiveArray<mo
DataType::Interval(IntervalUnit::MonthDayNano),
)
}

/// Casts f16 into f32
pub fn f16_to_f32(from: &PrimitiveArray<f16>) -> PrimitiveArray<f32> {
unary(from, |x| x.as_f32(), DataType::Float32)
}
7 changes: 5 additions & 2 deletions src/compute/comparison/mod.rs
Expand Up @@ -81,6 +81,7 @@ macro_rules! match_eq_ord {(
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float16 => todo!(),
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
Expand All @@ -91,7 +92,7 @@ macro_rules! match_eq {(
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
use crate::types::{days_ms, months_days_ns};
use crate::types::{days_ms, months_days_ns, f16};
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Expand All @@ -104,6 +105,7 @@ macro_rules! match_eq {(
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float16 => __with_ty__! { f16 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
Expand Down Expand Up @@ -487,7 +489,8 @@ fn can_partial_eq(data_type: &DataType) -> bool {
can_partial_eq_and_ord(data_type)
|| matches!(
data_type.to_logical_type(),
DataType::Interval(IntervalUnit::DayTime)
DataType::Float16
| DataType::Interval(IntervalUnit::DayTime)
| DataType::Interval(IntervalUnit::MonthDayNano)
)
}
Expand Down
4 changes: 3 additions & 1 deletion src/compute/comparison/simd/native.rs
@@ -1,7 +1,7 @@
use std::convert::TryInto;

use super::{set, Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd};
use crate::types::{days_ms, months_days_ns};
use crate::types::{days_ms, f16, months_days_ns};

simd8_native_all!(u8);
simd8_native_all!(u16);
Expand All @@ -12,6 +12,8 @@ simd8_native_all!(i16);
simd8_native_all!(i32);
simd8_native_all!(i128);
simd8_native_all!(i64);
simd8_native!(f16);
simd8_native_partial_eq!(f16);
simd8_native_all!(f32);
simd8_native_all!(f64);
simd8_native!(days_ms);
Expand Down
4 changes: 3 additions & 1 deletion src/compute/comparison/simd/packed.rs
Expand Up @@ -2,7 +2,7 @@ use std::convert::TryInto;
use std::simd::ToBitMask;

use crate::types::simd::*;
use crate::types::{days_ms, months_days_ns};
use crate::types::{days_ms, f16, months_days_ns};

use super::*;

Expand Down Expand Up @@ -71,6 +71,8 @@ simd8!(i16, i16x8);
simd8!(i32, i32x8);
simd8!(i64, i64x8);
simd8_native_all!(i128);
simd8_native!(f16);
simd8_native_partial_eq!(f16);
simd8!(f32, f32x8);
simd8!(f64, f64x8);
simd8_native!(days_ms);
Expand Down
3 changes: 2 additions & 1 deletion src/datatypes/mod.rs
Expand Up @@ -237,7 +237,7 @@ impl DataType {
UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16),
UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32),
UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64),
Float16 => unreachable!(),
Float16 => PhysicalType::Primitive(PrimitiveType::Float16),
Float32 => PhysicalType::Primitive(PrimitiveType::Float32),
Float64 => PhysicalType::Primitive(PrimitiveType::Float64),
Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs),
Expand Down Expand Up @@ -299,6 +299,7 @@ impl From<PrimitiveType> for DataType {
PrimitiveType::UInt32 => DataType::UInt32,
PrimitiveType::UInt64 => DataType::UInt64,
PrimitiveType::Int128 => DataType::Decimal(32, 32),
PrimitiveType::Float16 => DataType::Float16,
PrimitiveType::Float32 => DataType::Float32,
PrimitiveType::Float64 => DataType::Float64,
PrimitiveType::DaysMs => DataType::Interval(IntervalUnit::DayTime),
Expand Down
1 change: 1 addition & 0 deletions src/io/json_integration/read/array.rs
Expand Up @@ -289,6 +289,7 @@ pub fn to_array(
Primitive(PrimitiveType::UInt16) => Ok(Box::new(to_primitive::<u16>(json_col, data_type))),
Primitive(PrimitiveType::UInt32) => Ok(Box::new(to_primitive::<u32>(json_col, data_type))),
Primitive(PrimitiveType::UInt64) => Ok(Box::new(to_primitive::<u64>(json_col, data_type))),
Primitive(PrimitiveType::Float16) => todo!(),
Primitive(PrimitiveType::Float32) => Ok(Box::new(to_primitive::<f32>(json_col, data_type))),
Primitive(PrimitiveType::Float64) => Ok(Box::new(to_primitive::<f64>(json_col, data_type))),
Binary => Ok(to_binary::<i32>(json_col, data_type)),
Expand Down
3 changes: 3 additions & 0 deletions src/types/mod.rs
Expand Up @@ -55,6 +55,8 @@ pub enum PrimitiveType {
UInt32,
/// An unsigned 64-bit integer.
UInt64,
/// A 16-bit floating point number.
Float16,
/// A 32-bit floating point number.
Float32,
/// A 64-bit floating point number.
Expand All @@ -77,6 +79,7 @@ mod private {
impl Sealed for i32 {}
impl Sealed for i64 {}
impl Sealed for i128 {}
impl Sealed for super::f16 {}
impl Sealed for f32 {}
impl Sealed for f64 {}
impl Sealed for super::days_ms {}
Expand Down
110 changes: 110 additions & 0 deletions src/types/native.rs
Expand Up @@ -325,3 +325,113 @@ impl Neg for months_days_ns {
Self::new(-self.months(), -self.days(), -self.ns())
}
}

/// Type representation of the Float16 physical type
#[derive(Copy, Clone, Default, Zeroable, Pod)]
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct f16(pub u16);

impl PartialEq for f16 {
#[inline]
fn eq(&self, other: &f16) -> bool {
if self.is_nan() || other.is_nan() {
false
} else {
(self.0 == other.0) || ((self.0 | other.0) & 0x7FFFu16 == 0)
}
}
}

// see https://github.com/starkat99/half-rs/blob/main/src/binary16.rs
impl f16 {
#[inline]
#[must_use]
pub(crate) const fn is_nan(self) -> bool {
self.0 & 0x7FFFu16 > 0x7C00u16
}

/// Casts this `f16` to `f32`
#[inline]
pub fn as_f32(self) -> f32 {
let i = self.0;
// Check for signed zero
if i & 0x7FFFu16 == 0 {
return f32::from_bits((i as u32) << 16);
}

let half_sign = (i & 0x8000u16) as u32;
let half_exp = (i & 0x7C00u16) as u32;
let half_man = (i & 0x03FFu16) as u32;

// Check for an infinity or NaN when all exponent bits set
if half_exp == 0x7C00u32 {
// Check for signed infinity if mantissa is zero
if half_man == 0 {
let number = (half_sign << 16) | 0x7F80_0000u32;
return f32::from_bits(number);
} else {
// NaN, keep current mantissa but also set most significiant mantissa bit
let number = (half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13);
return f32::from_bits(number);
}
}

// Calculate single-precision components with adjusted exponent
let sign = half_sign << 16;
// Unbias exponent
let unbiased_exp = ((half_exp as i32) >> 10) - 15;

// Check for subnormals, which will be normalized by adjusting exponent
if half_exp == 0 {
// Calculate how much to adjust the exponent by
let e = (half_man as u16).leading_zeros() - 6;

// Rebias and adjust exponent
let exp = (127 - 15 - e) << 23;
let man = (half_man << (14 + e)) & 0x7F_FF_FFu32;
return f32::from_bits(sign | exp | man);
}

// Rebias exponent for a normalized normal
let exp = ((unbiased_exp + 127) as u32) << 23;
let man = (half_man & 0x03FFu32) << 13;
f32::from_bits(sign | exp | man)
}
}

impl std::fmt::Debug for f16 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.as_f32())
}
}

impl std::fmt::Display for f16 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_f32())
}
}

impl NativeType for f16 {
const PRIMITIVE: PrimitiveType = PrimitiveType::Float16;
type Bytes = [u8; 2];
#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
self.0.to_le_bytes()
}

#[inline]
fn to_ne_bytes(&self) -> Self::Bytes {
self.0.to_ne_bytes()
}

#[inline]
fn to_be_bytes(&self) -> Self::Bytes {
self.0.to_be_bytes()
}

#[inline]
fn from_be_bytes(bytes: Self::Bytes) -> Self {
Self(u16::from_be_bytes(bytes))
}
}
4 changes: 3 additions & 1 deletion src/types/simd/mod.rs
@@ -1,7 +1,7 @@
//! Contains traits and implementations of multi-data used in SIMD.
//! The actual representation is driven by the feature flag `"simd"`, which, if set,
//! uses [`std::simd`].
use super::{days_ms, months_days_ns};
use super::{days_ms, f16, months_days_ns};
use super::{BitChunk, BitChunkIter, NativeType};

/// Describes the ability to convert itself from a [`BitChunk`].
Expand Down Expand Up @@ -129,6 +129,7 @@ pub(super) use native_simd;
// Types do not have specific intrinsics and thus SIMD can't be specialized.
// Therefore, we can declare their MD representation as `[$t; 8]` irrespectively
// of how they are represented in the different channels.
native_simd!(f16x32, f16, 32, u32);
native_simd!(days_msx8, days_ms, 8, u8);
native_simd!(months_days_nsx8, months_days_ns, 8, u8);
native_simd!(i128x8, i128, 8, u8);
Expand Down Expand Up @@ -157,6 +158,7 @@ native!(i8, i8x64);
native!(i16, i16x32);
native!(i32, i32x16);
native!(i64, i64x8);
native!(f16, f16x32);
native!(f32, f32x16);
native!(f64, f64x8);
native!(i128, i128x8);
Expand Down
1 change: 1 addition & 0 deletions src/types/simd/native.rs
Expand Up @@ -11,5 +11,6 @@ native_simd!(i8x64, i8, 64, u64);
native_simd!(i16x32, i16, 32, u32);
native_simd!(i32x16, i32, 16, u16);
native_simd!(i64x8, i64, 8, u8);
native_simd!(f16x32, f16, 32, u32);
native_simd!(f32x16, f32, 16, u16);
native_simd!(f64x8, f64, 8, u8);
2 changes: 2 additions & 0 deletions tests/it/compute/cast.rs
Expand Up @@ -446,6 +446,7 @@ fn consistency() {
Int16,
Int32,
Int64,
Float16,
Float32,
Float64,
Timestamp(TimeUnit::Second, None),
Expand Down Expand Up @@ -778,6 +779,7 @@ fn null_array_from_and_to_others() {
typed_test!(UInt32Array, UInt32);
typed_test!(UInt64Array, UInt64);

typed_test!(Float16Array, Float16);
typed_test!(Float32Array, Float32);
typed_test!(Float64Array, Float64);
}
Expand Down
1 change: 1 addition & 0 deletions tests/it/compute/comparison.rs
Expand Up @@ -18,6 +18,7 @@ fn consistency() {
Int16,
Int32,
Int64,
Float16,
Float32,
Float64,
Interval(IntervalUnit::YearMonth),
Expand Down

0 comments on commit 76d2a39

Please sign in to comment.