Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
No unsafe in IO (#749)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Jan 10, 2022
1 parent 3f6d522 commit 2493f7d
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 79 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Expand Up @@ -16,6 +16,7 @@ bench = false

[dependencies]
num-traits = "0.2"
bytemuck = { version = "1", features = ["derive"] }
chrono = { version = "0.4", default_features = false, features = ["std"] }
chrono-tz = { version = "0.6", optional = true }
# To efficiently cast numbers to strings
Expand Down
1 change: 0 additions & 1 deletion src/io/avro/mod.rs
@@ -1,5 +1,4 @@
#![deny(missing_docs)]
#![forbid(unsafe_code)]
//! Read and write from and to Apache Avro

pub mod read;
Expand Down
1 change: 0 additions & 1 deletion src/io/csv/mod.rs
@@ -1,5 +1,4 @@
#![deny(missing_docs)]
#![forbid(unsafe_code)]
//! Convert data between the Arrow and CSV (comma-separated values).

use crate::error::ArrowError;
Expand Down
19 changes: 3 additions & 16 deletions src/io/ipc/read/read_basic.rs
Expand Up @@ -72,14 +72,8 @@ fn read_uncompressed_buffer<T: NativeType, R: Read + Seek>(

if is_native_little_endian() == is_little_endian {
// fast case where we can just copy the contents as is
unsafe {
// transmute T to bytes.
let slice = std::slice::from_raw_parts_mut(
buffer.as_mut_ptr() as *mut u8,
length * std::mem::size_of::<T>(),
);
reader.read_exact(slice)?;
}
let slice = bytemuck::cast_slice_mut(&mut buffer);
reader.read_exact(slice)?;
} else {
read_swapped(reader, length, &mut buffer, is_little_endian)?;
}
Expand Down Expand Up @@ -108,14 +102,7 @@ fn read_compressed_buffer<T: NativeType, R: Read + Seek>(
let mut slice = vec![0u8; buffer_length];
reader.read_exact(&mut slice)?;

// Safety:
// This is safe because T is NativeType, which by definition can be transmuted to u8
let out_slice = unsafe {
std::slice::from_raw_parts_mut(
buffer.as_mut_ptr() as *mut u8,
length * std::mem::size_of::<T>(),
)
};
let out_slice = bytemuck::cast_slice_mut(&mut buffer);

match compression.codec() {
CompressionType::LZ4_FRAME => {
Expand Down
14 changes: 2 additions & 12 deletions src/io/ipc/write/serialize.rs
Expand Up @@ -765,12 +765,7 @@ fn _write_compressed_buffer_from_iter<T: NativeType, I: TrustedLen<Item = T>>(
fn _write_buffer<T: NativeType>(buffer: &[T], arrow_data: &mut Vec<u8>, is_little_endian: bool) {
if is_little_endian == is_native_little_endian() {
// in native endianness we can use the bytes directly.
let buffer = unsafe {
std::slice::from_raw_parts(
buffer.as_ptr() as *const u8,
buffer.len() * std::mem::size_of::<T>(),
)
};
let buffer = bytemuck::cast_slice(buffer);
arrow_data.extend_from_slice(buffer);
} else {
_write_buffer_from_iter(buffer.iter().copied(), arrow_data, is_little_endian)
Expand All @@ -784,12 +779,7 @@ fn _write_compressed_buffer<T: NativeType>(
compression: Compression,
) {
if is_little_endian == is_native_little_endian() {
let bytes = unsafe {
std::slice::from_raw_parts(
buffer.as_ptr() as *const u8,
buffer.len() * std::mem::size_of::<T>(),
)
};
let bytes = bytemuck::cast_slice(buffer);
arrow_data.extend_from_slice(&(bytes.len() as i64).to_le_bytes());
match compression {
Compression::LZ4 => {
Expand Down
1 change: 0 additions & 1 deletion src/io/json/mod.rs
@@ -1,5 +1,4 @@
#![deny(missing_docs)]
#![forbid(unsafe_code)]
//! Convert data between the Arrow memory format and JSON line-delimited records.

pub mod read;
Expand Down
1 change: 1 addition & 0 deletions src/io/mod.rs
@@ -1,3 +1,4 @@
#![forbid(unsafe_code)]
//! Contains modules to interface with other formats such as [`csv`],
//! [`parquet`], [`json`], [`ipc`], [`mod@print`] and [`avro`].
#[cfg(any(
Expand Down
6 changes: 3 additions & 3 deletions src/io/parquet/read/primitive/basic.rs
Expand Up @@ -5,7 +5,7 @@ use parquet2::{
};

use super::super::utils as other_utils;
use super::utils::ExactChunksIter;
use super::utils::chunks;
use super::ColumnDescriptor;
use crate::{
bitmap::{utils::BitmapIter, MutableBitmap},
Expand Down Expand Up @@ -110,7 +110,7 @@ fn read_nullable<T, A, F>(
F: Fn(T) -> A,
{
let length = additional + values.len();
let mut chunks = ExactChunksIter::<T>::new(values_buffer);
let mut chunks = chunks(values_buffer);

let validity_iterator = hybrid_rle::Decoder::new(validity_buffer, 1);

Expand Down Expand Up @@ -153,7 +153,7 @@ where
F: Fn(T) -> A,
{
assert_eq!(values_buffer.len(), additional * std::mem::size_of::<T>());
let iterator = ExactChunksIter::<T>::new(values_buffer);
let iterator = chunks(values_buffer);

let iterator = iterator.map(op);

Expand Down
4 changes: 2 additions & 2 deletions src/io/parquet/read/primitive/nested.rs
Expand Up @@ -7,7 +7,7 @@ use parquet2::{

use super::super::nested_utils::extend_offsets;
use super::ColumnDescriptor;
use super::{super::utils, utils::ExactChunksIter, Nested};
use super::{super::utils, utils::chunks, Nested};
use crate::{
bitmap::MutableBitmap, error::Result, trusted_len::TrustedLen,
types::NativeType as ArrowNativeType,
Expand Down Expand Up @@ -66,7 +66,7 @@ fn read<T, A, F>(
A: ArrowNativeType,
F: Fn(T) -> A,
{
let new_values = ExactChunksIter::<T>::new(values_buffer);
let new_values = chunks(values_buffer);

let max_rep_level = rep_level_encoding.1 as u32;
let max_def_level = def_level_encoding.1 as u32;
Expand Down
51 changes: 12 additions & 39 deletions src/io/parquet/read/primitive/utils.rs
@@ -1,44 +1,17 @@
use crate::trusted_len::TrustedLen;

use std::{convert::TryInto, hint::unreachable_unchecked};
use std::convert::TryInto;

use parquet2::types::NativeType;

pub struct ExactChunksIter<'a, T: NativeType> {
chunks: std::slice::ChunksExact<'a, u8>,
phantom: std::marker::PhantomData<T>,
}

impl<'a, T: NativeType> ExactChunksIter<'a, T> {
#[inline]
pub fn new(slice: &'a [u8]) -> Self {
assert_eq!(slice.len() % std::mem::size_of::<T>(), 0);
let chunks = slice.chunks_exact(std::mem::size_of::<T>());
Self {
chunks,
phantom: std::marker::PhantomData,
}
}
}

impl<'a, T: NativeType> Iterator for ExactChunksIter<'a, T> {
type Item = T;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.chunks.next().map(|chunk| {
let chunk: <T as NativeType>::Bytes = match chunk.try_into() {
Ok(v) => v,
Err(_) => unsafe { unreachable_unchecked() },
};
T::from_le_bytes(chunk)
})
}
use crate::trusted_len::TrustedLen;

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
self.chunks.size_hint()
}
/// Decodes `bytes` into an iterator of `T` values.
///
/// The input is walked in consecutive `size_of::<T>()`-byte windows and every
/// window is handed to `T::from_le_bytes`.
///
/// # Panics
///
/// Panics when `bytes.len()` is not a multiple of `size_of::<T>()`.
pub fn chunks<T: NativeType>(bytes: &[u8]) -> impl TrustedLen<Item = T> + '_ {
    assert_eq!(bytes.len() % std::mem::size_of::<T>(), 0);
    bytes
        .chunks_exact(std::mem::size_of::<T>())
        .map(|raw| {
            // `chunks_exact` only yields windows of exactly `size_of::<T>()` bytes;
            // a failed conversion into `T::Bytes` is treated as unreachable.
            let raw: <T as NativeType>::Bytes =
                raw.try_into().unwrap_or_else(|_| unreachable!());
            T::from_le_bytes(raw)
        })
}

unsafe impl<'a, T: NativeType> TrustedLen for ExactChunksIter<'a, T> {}
10 changes: 6 additions & 4 deletions src/types/native.rs
@@ -1,22 +1,23 @@
use std::convert::TryFrom;
use std::ops::Neg;

use bytemuck::{Pod, Zeroable};

use super::PrimitiveType;

/// Sealed trait implemented by all physical types that can be allocated,
/// serialized and deserialized by this crate.
/// All O(N) allocations in this crate are done for this trait alone.
pub trait NativeType:
super::private::Sealed
+ Pod
+ Send
+ Sync
+ Sized
+ Copy
+ std::fmt::Debug
+ std::fmt::Display
+ PartialEq
+ Default
+ 'static
{
/// The corresponding variant of [`PrimitiveType`].
const PRIMITIVE: PrimitiveType;
Expand Down Expand Up @@ -84,8 +85,9 @@ native_type!(f64, PrimitiveType::Float64);
native_type!(i128, PrimitiveType::Int128);

/// The in-memory representation of the DayMillisecond variant of arrow's "Interval" logical type.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash)]
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash, Zeroable, Pod)]
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct days_ms([i32; 2]);

impl days_ms {
Expand Down Expand Up @@ -176,7 +178,7 @@ impl NativeType for days_ms {
}

/// The in-memory representation of the MonthDayNano variant of the "Interval" logical type.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash)]
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash, Zeroable, Pod)]
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct months_days_ns(i32, i32, i64);
Expand Down

0 comments on commit 2493f7d

Please sign in to comment.