diff --git a/src/nested_pandas/series/dtype.py b/src/nested_pandas/series/dtype.py index c0b8c9c1..c5495a40 100644 --- a/src/nested_pandas/series/dtype.py +++ b/src/nested_pandas/series/dtype.py @@ -28,9 +28,36 @@ class NestedDtype(ExtensionDtype): Parameters ---------- - pyarrow_dtype : pyarrow.StructType or pd.ArrowDtype - The pyarrow data type to use for the nested type. It must be a struct - type where all fields are list types. + pyarrow_dtype : pyarrow.StructType, pd.ArrowDtype, or Mapping[str, pa.DataType] + The pyarrow data type to use for the nested type. It may be provided as + a pyarrow.StructType, a pandas.ArrowDtype, or a mapping of column names to + pyarrow data types (such as a dictionary). + + Examples + -------- + >>> import pyarrow as pa + >>> from nested_pandas import NestedDtype + + From pa.StructType: + + >>> dtype = NestedDtype(pa.struct([pa.field("a", pa.list_(pa.int64())), + ... pa.field("b", pa.list_(pa.float64()))])) + >>> dtype + nested<a: [int64], b: [double]> + + From pd.ArrowDtype: + + >>> import pandas as pd + >>> dtype = NestedDtype(pd.ArrowDtype(pa.struct([pa.field("a", pa.list_(pa.int64())), + ... 
pa.field("b", pa.list_(pa.float64()))]))) + >>> dtype + nested<a: [int64], b: [double]> + + From mapping of column names to pyarrow data types: + + >>> dtype = NestedDtype({"a": pa.int64(), "b": pa.float64()}) + >>> dtype + nested<a: [int64], b: [double]> """ # ExtensionDtype overrides # @@ -160,6 +187,15 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> ExtensionArray: pyarrow_dtype: pa.StructType def __init__(self, pyarrow_dtype: pa.DataType) -> None: + # Allow pd.ArrowDtypes on init + if isinstance(pyarrow_dtype, pd.ArrowDtype): + pyarrow_dtype = pyarrow_dtype.pyarrow_dtype + + # Allow from_columns-style mapping inputs + if isinstance(pyarrow_dtype, Mapping): + pyarrow_dtype = pa.struct({col: pa.list_(pa_type) for col, pa_type in pyarrow_dtype.items()}) + pyarrow_dtype = cast(pa.StructType, pyarrow_dtype) + self.pyarrow_dtype, self.list_struct_pa_dtype = self._validate_dtype(pyarrow_dtype) @property diff --git a/tests/nested_pandas/series/test_dtype.py b/tests/nested_pandas/series/test_dtype.py index fbc19160..df68e1ed 100644 --- a/tests/nested_pandas/series/test_dtype.py +++ b/tests/nested_pandas/series/test_dtype.py @@ -83,6 +83,14 @@ def test_from_pandas_arrow_dtype(): assert dtype_from_list.pyarrow_dtype == pa.struct([pa.field("a", pa.list_(pa.int64()))]) +def test_init_from_pandas_arrow_dtype(): + """Test that we can construct NestedDtype from pandas.ArrowDtype in __init__.""" + dtype_from_struct = NestedDtype(pd.ArrowDtype(pa.struct([pa.field("a", pa.list_(pa.int64()))]))) + assert dtype_from_struct.pyarrow_dtype == pa.struct([pa.field("a", pa.list_(pa.int64()))]) + dtype_from_list = NestedDtype(pd.ArrowDtype(pa.list_(pa.struct([pa.field("a", pa.int64())])))) + assert dtype_from_list.pyarrow_dtype == pa.struct([pa.field("a", pa.list_(pa.int64()))]) + + def test_to_pandas_list_struct_arrow_dtype(): """Test that NestedDtype.to_pandas_arrow_dtype(list_struct=True) returns the correct pyarrow type.""" dtype = NestedDtype.from_columns({"a": pa.list_(pa.int64()), "b": pa.float64()}) @@ -100,6 
+108,15 @@ def test_from_columns(): ) +def test_init_from_columns(): + """Test NestedDtype.__init__ with columns dict.""" + columns = {"a": pa.int64(), "b": pa.float64()} + dtype = NestedDtype(columns) + assert dtype.pyarrow_dtype == pa.struct( + [pa.field("a", pa.list_(pa.int64())), pa.field("b", pa.list_(pa.float64()))] + ) + + def test_na_value(): """Test that NestedDtype.na_value is a singleton instance of NAType.""" dtype = NestedDtype(pa.struct([pa.field("a", pa.list_(pa.int64()))]))