Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions src/nested_pandas/series/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,36 @@ class NestedDtype(ExtensionDtype):

Parameters
----------
pyarrow_dtype : pyarrow.StructType or pd.ArrowDtype
The pyarrow data type to use for the nested type. It must be a struct
type where all fields are list types.
pyarrow_dtype : pyarrow.StructType, pd.ArrowDtype, or Mapping[str, pa.DataType]
The pyarrow data type to use for the nested type. It may be provided as
a pyarrow.StructType, a pandas.ArrowDtype, or a mapping of column names to
pyarrow data types (such as a dictionary).

Examples
--------
>>> import pyarrow as pa
>>> from nested_pandas import NestedDtype

From pa.StructType:

>>> dtype = NestedDtype(pa.struct([pa.field("a", pa.list_(pa.int64())),
... pa.field("b", pa.list_(pa.float64()))]))
>>> dtype
nested<a: [int64], b: [double]>

From pd.ArrowDtype:

>>> import pandas as pd
>>> dtype = NestedDtype(pd.ArrowDtype(pa.struct([pa.field("a", pa.list_(pa.int64())),
... pa.field("b", pa.list_(pa.float64()))])))
>>> dtype
nested<a: [int64], b: [double]>

From mapping of column names to pyarrow data types:

>>> dtype = NestedDtype({"a": pa.int64(), "b": pa.float64()})
>>> dtype
nested<a: [int64], b: [double]>
"""

# ExtensionDtype overrides #
Expand Down Expand Up @@ -160,6 +187,15 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> ExtensionArray:
pyarrow_dtype: pa.StructType

def __init__(self, pyarrow_dtype: pa.DataType) -> None:
# Allow pd.ArrowDtypes on init
if isinstance(pyarrow_dtype, pd.ArrowDtype):
pyarrow_dtype = pyarrow_dtype.pyarrow_dtype

# Allow from_columns-style mapping inputs
if isinstance(pyarrow_dtype, Mapping):
pyarrow_dtype = pa.struct({col: pa.list_(pa_type) for col, pa_type in pyarrow_dtype.items()})
pyarrow_dtype = cast(pa.StructType, pyarrow_dtype)

self.pyarrow_dtype, self.list_struct_pa_dtype = self._validate_dtype(pyarrow_dtype)

@property
Expand Down
17 changes: 17 additions & 0 deletions tests/nested_pandas/series/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def test_from_pandas_arrow_dtype():
assert dtype_from_list.pyarrow_dtype == pa.struct([pa.field("a", pa.list_(pa.int64()))])


def test_init_from_pandas_arrow_dtype():
"""Test that we can construct NestedDtype from pandas.ArrowDtype in __init__."""
dtype_from_struct = NestedDtype(pd.ArrowDtype(pa.struct([pa.field("a", pa.list_(pa.int64()))])))
assert dtype_from_struct.pyarrow_dtype == pa.struct([pa.field("a", pa.list_(pa.int64()))])
dtype_from_list = NestedDtype(pd.ArrowDtype(pa.list_(pa.struct([pa.field("a", pa.int64())]))))
assert dtype_from_list.pyarrow_dtype == pa.struct([pa.field("a", pa.list_(pa.int64()))])


def test_to_pandas_list_struct_arrow_dtype():
"""Test that NestedDtype.to_pandas_arrow_dtype(list_struct=True) returns the correct pyarrow type."""
dtype = NestedDtype.from_columns({"a": pa.list_(pa.int64()), "b": pa.float64()})
Expand All @@ -100,6 +108,15 @@ def test_from_columns():
)


def test_init_from_columns():
"""Test NestedDtype.__init__ with columns dict."""
columns = {"a": pa.int64(), "b": pa.float64()}
dtype = NestedDtype(columns)
assert dtype.pyarrow_dtype == pa.struct(
[pa.field("a", pa.list_(pa.int64())), pa.field("b", pa.list_(pa.float64()))]
)


def test_na_value():
"""Test that NestedDtype.na_value is a singleton instance of NAType."""
dtype = NestedDtype(pa.struct([pa.field("a", pa.list_(pa.int64()))]))
Expand Down