Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Simplified Primitive and Boolean scalar. (#648)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored Dec 1, 2021
1 parent 23c7d8a commit 8300684
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 61 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/security.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ jobs:
toolchain: nightly-2021-10-24
override: true
- uses: Swatinem/rust-cache@v1
with:
key: key1
- name: Install Miri
run: |
rustup component add miri
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ jobs:
toolchain: nightly-2021-10-24
override: true
- uses: Swatinem/rust-cache@v1
with:
key: key1
- name: Install Miri
run: |
rustup component add miri
Expand All @@ -96,6 +98,8 @@ jobs:
toolchain: nightly-2021-10-24
override: true
- uses: Swatinem/rust-cache@v1
with:
key: key1
- name: Install Miri
run: |
rustup component add miri
Expand Down
25 changes: 13 additions & 12 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,58 +240,59 @@ macro_rules! compare_scalar {
Boolean => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<BooleanScalar>().unwrap();
boolean::$op(lhs, rhs.value())
// validity checked above
boolean::$op(lhs, rhs.value().unwrap())
}
Int8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i8>>().unwrap();
primitive::$op::<i8>(lhs, rhs.value())
primitive::$op::<i8>(lhs, rhs.value().unwrap())
}
Int16 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i16>>().unwrap();
primitive::$op::<i16>(lhs, rhs.value())
primitive::$op::<i16>(lhs, rhs.value().unwrap())
}
Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i32>>().unwrap();
primitive::$op::<i32>(lhs, rhs.value())
primitive::$op::<i32>(lhs, rhs.value().unwrap())
}
Int64 | Timestamp(_, _) | Date64 | Time64(_) | Duration(_) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
primitive::$op::<i64>(lhs, rhs.value())
primitive::$op::<i64>(lhs, rhs.value().unwrap())
}
UInt8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u8>>().unwrap();
primitive::$op::<u8>(lhs, rhs.value())
primitive::$op::<u8>(lhs, rhs.value().unwrap())
}
UInt16 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u16>>().unwrap();
primitive::$op::<u16>(lhs, rhs.value())
primitive::$op::<u16>(lhs, rhs.value().unwrap())
}
UInt32 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u32>>().unwrap();
primitive::$op::<u32>(lhs, rhs.value())
primitive::$op::<u32>(lhs, rhs.value().unwrap())
}
UInt64 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<u64>>().unwrap();
primitive::$op::<u64>(lhs, rhs.value())
primitive::$op::<u64>(lhs, rhs.value().unwrap())
}
Float16 => unreachable!(),
Float32 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<f32>>().unwrap();
primitive::$op::<f32>(lhs, rhs.value())
primitive::$op::<f32>(lhs, rhs.value().unwrap())
}
Float64 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<f64>>().unwrap();
primitive::$op::<f64>(lhs, rhs.value())
primitive::$op::<f64>(lhs, rhs.value().unwrap())
}
Utf8 => {
let lhs = lhs.as_any().downcast_ref().unwrap();
Expand All @@ -309,7 +310,7 @@ macro_rules! compare_scalar {
.as_any()
.downcast_ref::<PrimitiveScalar<i128>>()
.unwrap();
primitive::$op::<i128>(lhs, rhs.value())
primitive::$op::<i128>(lhs, rhs.value().unwrap())
}
Binary => {
let lhs = lhs.as_any().downcast_ref().unwrap();
Expand Down
25 changes: 7 additions & 18 deletions src/scalar/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,21 @@ use crate::datatypes::DataType;
use super::Scalar;

/// The [`Scalar`] implementation of a boolean.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct BooleanScalar {
value: bool,
is_valid: bool,
}

impl PartialEq for BooleanScalar {
fn eq(&self, other: &Self) -> bool {
self.is_valid == other.is_valid && ((!self.is_valid) | (self.value == other.value))
}
value: Option<bool>,
}

impl BooleanScalar {
/// Returns a new [`BooleanScalar`]
#[inline]
pub fn new(v: Option<bool>) -> Self {
let is_valid = v.is_some();
Self {
value: v.unwrap_or_default(),
is_valid,
}
pub fn new(value: Option<bool>) -> Self {
Self { value }
}

/// The value irrespectively of the validity
/// The value
#[inline]
pub fn value(&self) -> bool {
pub fn value(&self) -> Option<bool> {
self.value
}
}
Expand All @@ -41,7 +30,7 @@ impl Scalar for BooleanScalar {

#[inline]
fn is_valid(&self) -> bool {
self.is_valid
self.value.is_some()
}

#[inline]
Expand Down
36 changes: 8 additions & 28 deletions src/scalar/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,16 @@ use super::Scalar;

/// The implementation of [`Scalar`] for primitive, semantically equivalent to [`Option<T>`]
/// with [`DataType`].
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct PrimitiveScalar<T: NativeType> {
// Not Option<T> because this offers a stabler pointer offset on the struct
value: T,
is_valid: bool,
value: Option<T>,
data_type: DataType,
}

impl<T: NativeType> PartialEq for PrimitiveScalar<T> {
fn eq(&self, other: &Self) -> bool {
self.data_type == other.data_type
&& self.is_valid == other.is_valid
&& ((!self.is_valid) | (self.value == other.value))
}
}

impl<T: NativeType> PrimitiveScalar<T> {
/// Returns a new [`PrimitiveScalar`].
#[inline]
pub fn new(data_type: DataType, v: Option<T>) -> Self {
pub fn new(data_type: DataType, value: Option<T>) -> Self {
if !T::is_valid(&data_type) {
Err(ArrowError::InvalidArgumentError(format!(
"Type {} does not support logical type {}",
Expand All @@ -36,30 +26,20 @@ impl<T: NativeType> PrimitiveScalar<T> {
)))
.unwrap()
}
let is_valid = v.is_some();
Self {
value: v.unwrap_or_default(),
is_valid,
data_type,
}
Self { value, data_type }
}

/// Returns the value irrespectively of its validity.
/// Returns the optional value.
#[inline]
pub fn value(&self) -> T {
pub fn value(&self) -> Option<T> {
self.value
}

/// Returns a new `PrimitiveScalar` with the same value but different [`DataType`]
/// # Panic
/// This function panics if the `data_type` is not valid for self's physical type `T`.
pub fn to(self, data_type: DataType) -> Self {
let v = if self.is_valid {
Some(self.value)
} else {
None
};
Self::new(data_type, v)
Self::new(data_type, self.value)
}
}

Expand All @@ -78,7 +58,7 @@ impl<T: NativeType> Scalar for PrimitiveScalar<T> {

#[inline]
fn is_valid(&self) -> bool {
self.is_valid
self.value.is_some()
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion tests/it/scalar/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn equal() {
fn basics() {
let a = BooleanScalar::new(Some(true));

assert!(a.value());
assert_eq!(a.value(), Some(true));
assert_eq!(a.data_type(), &DataType::Boolean);
assert!(a.is_valid());

Expand Down
3 changes: 1 addition & 2 deletions tests/it/scalar/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ fn equal() {
fn basics() {
let a = PrimitiveScalar::from(Some(2i32));

assert_eq!(a.value(), 2i32);
assert_eq!(a.value(), Some(2i32));
assert_eq!(a.data_type(), &DataType::Int32);
assert!(a.is_valid());

let a = a.to(DataType::Date32);
assert_eq!(a.data_type(), &DataType::Date32);
Expand Down

0 comments on commit 8300684

Please sign in to comment.