Fix dataframe vstacking, use built-in polars parquet writer (#266)
* feat(nox-ecs): bump polars for Array support

Pick up pola-rs/polars#14943 so that we can
write Parquet files directly.

Signed-off-by: Akhil Velagapudi <4@4khil.com>

* fix(nox-ecs): actually aggregate the dataframes

Signed-off-by: Akhil Velagapudi <4@4khil.com>

* feat(nox-ecs): use built-in parquet writer

Signed-off-by: Akhil Velagapudi <4@4khil.com>

---------

Signed-off-by: Akhil Velagapudi <4@4khil.com>
akhilles committed Mar 20, 2024
1 parent e66aca9 commit 9459a3f
Showing 3 changed files with 9 additions and 168 deletions.
7 changes: 3 additions & 4 deletions libs/nox-ecs/Cargo.toml

@@ -46,12 +46,11 @@ tracing = "0.1"
 tracing-subscriber = "0.3"
 
 # serialize
-polars.version = "0.37"
+polars.version = "0.38"
 polars.features = ["parquet", "dtype-array", "lazy"]
-polars-arrow.version = "0.37"
-arrow.version = "50.0"
+polars-arrow.version = "0.38"
+arrow.version = "51.0"
 arrow.features = ["ffi"]
-parquet = "50.0.0"
 serde.version = "1.0"
 serde_json = "1.0"
 postcard.version = "1.0.8"
2 changes: 1 addition & 1 deletion libs/nox-ecs/src/history.rs

@@ -25,7 +25,7 @@ impl History {
             .zip(final_world.archetypes.values_mut())
         {
            add_time(tick_df, time)?;
-            final_df.vstack(tick_df)?;
+            final_df.vstack_mut(tick_df)?;
         }
     }
     Ok(Some(final_world))
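This one-line change is the actual aggregation fix: `DataFrame::vstack` returns the stacked frame and leaves the receiver untouched, so the old code built each tick's rows and then discarded them, while `vstack_mut` appends in place. A minimal sketch of the difference, using made-up column names rather than the real archetype schemas:

use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let mut final_df = df!["tick" => [0i64], "x" => [1.0f64]]?;
    let tick_df = df!["tick" => [1i64], "x" => [2.0f64]]?;

    // `vstack` returns a new stacked DataFrame; ignoring the result,
    // as the old code effectively did, leaves `final_df` unchanged.
    let _stacked = final_df.vstack(&tick_df)?;
    assert_eq!(final_df.height(), 1);

    // `vstack_mut` appends the rows of `tick_df` to `final_df` in place.
    final_df.vstack_mut(&tick_df)?;
    assert_eq!(final_df.height(), 2);
    Ok(())
}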
168 changes: 5 additions & 163 deletions libs/nox-ecs/src/polars.rs
@@ -1,11 +1,7 @@
-use arrow::array::{ArrayData, LargeListArray, ListArray, MapArray, StructArray, UnionArray};
-use arrow::datatypes::{Field, Schema};
+use arrow::array::ArrayData;
 use arrow::ffi::FFI_ArrowArray;
-use arrow::record_batch::RecordBatch;
 use conduit::{ComponentId, ComponentType, EntityId, PrimitiveTy};
-use parquet::arrow::ArrowWriter;
-use parquet::file::properties::WriterProperties;
-use polars::prelude::SerReader;
+use polars::prelude::*;
 use polars::{frame::DataFrame, series::Series};
 use polars_arrow::{
     array::{Array, PrimitiveArray},
@@ -14,8 +10,6 @@ use polars_arrow::{
 use serde::{Deserialize, Serialize};
 use std::borrow::Cow;
 use std::collections::HashMap;
-use std::marker::PhantomData;
-use std::sync::Arc;
 use std::{collections::BTreeMap, fs::File, path::Path};
 
 use crate::{
@@ -61,13 +55,9 @@ impl PolarsWorld {
         for (archetype_id, df) in &mut self.archetypes {
             let path = path.join(format!("{}.parquet", archetype_id.to_raw()));
             let file = std::fs::File::create(&path)?;
-            let props = WriterProperties::default();
-            let record_batch = df.to_record_batch()?;
-            let mut writer =
-                ArrowWriter::try_new(file, record_batch.record_batch().schema(), Some(props))
-                    .unwrap();
-            writer.write(record_batch.record_batch()).unwrap();
-            writer.close().unwrap();
+            ParquetWriter::new(file)
+                .with_row_group_size(Some(1000))
+                .finish(df)?;
         }
         let path = path.join("assets.bin");
         let file = std::fs::File::create(path)?;
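With polars 0.38 picking up pola-rs/polars#14943, `ParquetWriter` can serialize `Array`-typed columns directly, which is what lets this commit drop the hand-rolled arrow/parquet bridge (the `RecordBatchRef`/`DataFrameConv` machinery removed in the hunk below). A standalone sketch of the new write path, assuming the `parquet` feature from the Cargo.toml above and a throwaway frame and file name:

use polars::prelude::*;
use std::fs::File;

fn main() -> PolarsResult<()> {
    let mut df = df!["tick" => [0i64, 1, 2], "x" => [1.0f64, 2.0, 3.0]]?;

    // `ParquetWriter` accepts any `Write` sink; `finish` takes the frame
    // by mutable reference, and the row-group size of 1000 mirrors the diff.
    let file = File::create("archetype.parquet")?;
    ParquetWriter::new(file)
        .with_row_group_size(Some(1000))
        .finish(&mut df)?;
    Ok(())
}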
@@ -309,154 +299,6 @@ fn tensor_array(ty: &ComponentType, inner: Box<dyn Array>) -> Box<dyn Array> {
     // (data_type, Some(metadata))
 }
 
-pub struct RecordBatchRef<'a> {
-    phantom_data: PhantomData<&'a ()>,
-    record_batch: arrow::record_batch::RecordBatch,
-}
-
-impl<'a> RecordBatchRef<'a> {
-    fn record_batch<'b>(&'b self) -> &'a arrow::record_batch::RecordBatch
-    where
-        'b: 'a,
-    {
-        &self.record_batch
-    }
-}
-
-pub trait DataFrameConv {
-    fn to_record_batch(&self) -> Result<RecordBatchRef<'_>, Error>;
-}
-
-impl DataFrameConv for DataFrame {
-    fn to_record_batch(&self) -> Result<RecordBatchRef<'_>, Error> {
-        let mut fields = vec![];
-        let mut columns = vec![];
-        for series in self.iter() {
-            let name = series.name();
-            // safety: `to_array_data` is unsafe because it creates a unlifetimed
-            // reference to `Series`, using `RecordBatchRef` we ensure
-            // that Series's lifetime is tied to the RecordBatch lifetime,
-            // so the `Series` will always be alive while the `RecordBatch` is
-            let array_data = unsafe { series.to_array_data() };
-            let array: Arc<dyn arrow::array::Array> = match array_data.data_type() {
-                arrow::datatypes::DataType::Null => {
-                    Arc::new(arrow::array::NullArray::from(array_data))
-                }
-                arrow::datatypes::DataType::Boolean => {
-                    Arc::new(arrow::array::BooleanArray::from(array_data))
-                }
-                arrow::datatypes::DataType::Int8 => {
-                    Arc::new(arrow::array::Int8Array::from(array_data))
-                }
-                arrow::datatypes::DataType::Int16 => {
-                    Arc::new(arrow::array::Int16Array::from(array_data))
-                }
-                arrow::datatypes::DataType::Int32 => {
-                    Arc::new(arrow::array::Int32Array::from(array_data))
-                }
-                arrow::datatypes::DataType::Int64 => {
-                    Arc::new(arrow::array::Int64Array::from(array_data))
-                }
-                arrow::datatypes::DataType::UInt8 => {
-                    Arc::new(arrow::array::UInt8Array::from(array_data))
-                }
-                arrow::datatypes::DataType::UInt16 => {
-                    Arc::new(arrow::array::UInt16Array::from(array_data))
-                }
-                arrow::datatypes::DataType::UInt32 => {
-                    Arc::new(arrow::array::UInt32Array::from(array_data))
-                }
-                arrow::datatypes::DataType::UInt64 => {
-                    Arc::new(arrow::array::UInt64Array::from(array_data))
-                }
-                arrow::datatypes::DataType::Float16 => {
-                    Arc::new(arrow::array::Float16Array::from(array_data))
-                }
-                arrow::datatypes::DataType::Float32 => {
-                    Arc::new(arrow::array::Float32Array::from(array_data))
-                }
-                arrow::datatypes::DataType::Float64 => {
-                    Arc::new(arrow::array::Float64Array::from(array_data))
-                }
-                arrow::datatypes::DataType::Timestamp(_, _) => todo!(),
-                arrow::datatypes::DataType::Date32 => {
-                    Arc::new(arrow::array::Date32Array::from(array_data))
-                }
-                arrow::datatypes::DataType::Date64 => {
-                    Arc::new(arrow::array::Date64Array::from(array_data))
-                }
-                arrow::datatypes::DataType::Time32(u) => match u {
-                    arrow::datatypes::TimeUnit::Second => {
-                        Arc::new(arrow::array::Time32SecondArray::from(array_data))
-                    }
-                    arrow::datatypes::TimeUnit::Millisecond => {
-                        Arc::new(arrow::array::Time32MillisecondArray::from(array_data))
-                    }
-                    arrow::datatypes::TimeUnit::Microsecond => {
-                        unimplemented!()
-                    }
-                    arrow::datatypes::TimeUnit::Nanosecond => {
-                        unimplemented!()
-                    }
-                },
-                arrow::datatypes::DataType::Time64(u) => match u {
-                    arrow::datatypes::TimeUnit::Second => {
-                        todo!()
-                    }
-                    arrow::datatypes::TimeUnit::Millisecond => {
-                        todo!()
-                    }
-                    arrow::datatypes::TimeUnit::Microsecond => {
-                        Arc::new(arrow::array::Time64MicrosecondArray::from(array_data))
-                    }
-                    arrow::datatypes::TimeUnit::Nanosecond => {
-                        Arc::new(arrow::array::Time64NanosecondArray::from(array_data))
-                    }
-                },
-                arrow::datatypes::DataType::Duration(_) => todo!(),
-                arrow::datatypes::DataType::Interval(_) => todo!(),
-                arrow::datatypes::DataType::Binary => {
-                    Arc::new(arrow::array::BinaryArray::from(array_data))
-                }
-                arrow::datatypes::DataType::FixedSizeBinary(_) => {
-                    Arc::new(arrow::array::FixedSizeBinaryArray::from(array_data))
-                }
-                arrow::datatypes::DataType::LargeBinary => {
-                    Arc::new(arrow::array::LargeBinaryArray::from(array_data))
-                }
-                arrow::datatypes::DataType::Utf8 => todo!(),
-                arrow::datatypes::DataType::LargeUtf8 => todo!(),
-                arrow::datatypes::DataType::List(_) => Arc::new(ListArray::from(array_data)),
-                arrow::datatypes::DataType::FixedSizeList(_, _) => {
-                    Arc::new(arrow::array::FixedSizeListArray::from(array_data))
-                }
-                arrow::datatypes::DataType::LargeList(_) => {
-                    Arc::new(LargeListArray::from(array_data))
-                }
-                arrow::datatypes::DataType::Struct(_) => Arc::new(StructArray::from(array_data)),
-                arrow::datatypes::DataType::Union(_, _) => Arc::new(UnionArray::from(array_data)),
-                arrow::datatypes::DataType::Dictionary(_, _) => {
-                    todo!()
-                }
-                arrow::datatypes::DataType::Decimal128(_, _) => todo!(),
-                arrow::datatypes::DataType::Decimal256(_, _) => todo!(),
-                arrow::datatypes::DataType::Map(_, _) => Arc::new(MapArray::from(array_data)),
-                arrow::datatypes::DataType::RunEndEncoded(_, _) => todo!(),
-            };
-
-            let field = Field::new(name, array.data_type().clone(), false);
-            fields.push(field);
-            columns.push(array);
-        }
-        let schema = Arc::new(Schema::new(fields));
-        let batch = RecordBatch::try_new(schema, columns)?;
-        Ok(RecordBatchRef {
-            phantom_data: PhantomData,
-            record_batch: batch,
-        })
-    }
-}
-
 pub trait SeriesExt {
     fn to_bytes(&self) -> Vec<u8>;
     unsafe fn to_array_data(&self) -> ArrayData;
