Skip to content

Commit

Permalink
Deduplicate serialization for accessors (#334)
Browse files Browse the repository at this point in the history
Previously, we had two copies of serialization that were doing
essentially the same thing.

This deduplicates accessor serialization to allow any
"json-serializable" value through (a superset technically; whatever
ipywidgets default serialization supports) but then otherwise assume the
input is a pyarrow column and serialize that to a parquet file.
  • Loading branch information
kylebarron committed Jan 26, 2024
1 parent de30e82 commit f8e057e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 28 deletions.
26 changes: 7 additions & 19 deletions lonboard/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def serialize_table_to_parquet(table: pa.Table, *, max_chunksize: int) -> List[b
buffers: List[bytes] = []
# NOTE: passing `max_chunksize=0` creates an infinite loop
# https://github.com/apache/arrow/issues/39788
assert max_chunksize > 0

for record_batch in table.to_batches(max_chunksize=max_chunksize):
with BytesIO() as bio:
with pq.ParquetWriter(
Expand Down Expand Up @@ -51,25 +53,13 @@ def serialize_pyarrow_column(data: pa.Array, *, max_chunksize: int) -> List[byte
return serialize_table_to_parquet(pyarrow_table, max_chunksize=max_chunksize)


def serialize_color_accessor(
data: Union[List[int], Tuple[int], NDArray[np.uint8]], obj
):
if data is None:
return None

if isinstance(data, (list, tuple)):
return data

assert isinstance(data, (pa.ChunkedArray, pa.Array))
validate_accessor_length_matches_table(data, obj.table)
return serialize_pyarrow_column(data, max_chunksize=obj._rows_per_chunk)


def serialize_float_accessor(data: Union[int, float, NDArray[np.floating]], obj):
def serialize_accessor(data: Union[List[int], Tuple[int], NDArray[np.uint8]], obj):
if data is None:
return None

if isinstance(data, (str, int, float)):
# We assume data has already been validated to the right type for this accessor
# Allow any json-serializable type through
if isinstance(data, (str, int, float, list, tuple, bytes)):
return data

assert isinstance(data, (pa.ChunkedArray, pa.Array))
Expand Down Expand Up @@ -98,7 +88,5 @@ def validate_accessor_length_matches_table(accessor, table):
raise TraitError("accessor must have same length as table")


COLOR_SERIALIZATION = {"to_json": serialize_color_accessor}
# TODO: rename as it's used for text as well
FLOAT_SERIALIZATION = {"to_json": serialize_float_accessor}
ACCESSOR_SERIALIZATION = {"to_json": serialize_accessor}
TABLE_SERIALIZATION = {"to_json": serialize_table}
6 changes: 2 additions & 4 deletions lonboard/experimental/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
import pyarrow as pa
from traitlets.traitlets import TraitType

from lonboard._serialization import (
COLOR_SERIALIZATION,
)
from lonboard._serialization import ACCESSOR_SERIALIZATION
from lonboard.traits import FixedErrorTraitType


Expand Down Expand Up @@ -39,7 +37,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.tag(sync=True, **COLOR_SERIALIZATION)
self.tag(sync=True, **ACCESSOR_SERIALIZATION)

def validate(
self, obj, value
Expand Down
9 changes: 4 additions & 5 deletions lonboard/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from typing_extensions import Self

from lonboard._serialization import (
COLOR_SERIALIZATION,
FLOAT_SERIALIZATION,
ACCESSOR_SERIALIZATION,
TABLE_SERIALIZATION,
)

Expand Down Expand Up @@ -206,7 +205,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.tag(sync=True, **COLOR_SERIALIZATION)
self.tag(sync=True, **ACCESSOR_SERIALIZATION)

def validate(
self, obj, value
Expand Down Expand Up @@ -332,7 +331,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.tag(sync=True, **FLOAT_SERIALIZATION)
self.tag(sync=True, **ACCESSOR_SERIALIZATION)

def validate(self, obj, value) -> Union[float, pa.ChunkedArray, pa.DoubleArray]:
if isinstance(value, (int, float)):
Expand Down Expand Up @@ -402,7 +401,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.tag(sync=True, **FLOAT_SERIALIZATION)
self.tag(sync=True, **ACCESSOR_SERIALIZATION)

def validate(self, obj, value) -> Union[float, pa.ChunkedArray, pa.DoubleArray]:
if isinstance(value, str):
Expand Down

0 comments on commit f8e057e

Please sign in to comment.