Skip to content

Commit

Permalink
prepare for pandas extension data types (#77)
Browse files Browse the repository at this point in the history
* prepare for pandas extension data types

* update
  • Loading branch information
goodwanghan committed Nov 29, 2021
1 parent ece1754 commit 6451880
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
python-version: [3.6, 3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v2
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ pip install triad

## Release History

### 0.5.6

* Prepare to support [pandas extension data types](https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes)
* Support Python 3.9

### 0.5.5

* Change pandas_list enforce_type df construction
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3 :: Only",
],
python_requires=">=3.6",
Expand Down
8 changes: 5 additions & 3 deletions tests/collections/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import math
from collections import OrderedDict
from datetime import datetime, date
from datetime import date, datetime

import numpy as np
import pandas as pd
import pyarrow as pa
import numpy as np
from pandas.core.dtypes.common import is_integer_dtype
from pytest import raises
from triad.collections.schema import Schema, SchemaError
from triad.exceptions import InvalidOperationError, NoneArgumentError
Expand Down Expand Up @@ -50,7 +51,8 @@ def test_schema_properties():
== s.pyarrow_schema
)
assert s.pyarrow_schema == s.pyarrow_schema
assert dict(a=np.int32, b=np.dtype(str)) == s.pd_dtype
assert pd.api.types.is_integer_dtype(s.pd_dtype["a"])
assert pd.api.types.is_string_dtype(s.pd_dtype["b"])
assert s.pandas_dtype == s.pd_dtype


Expand Down
61 changes: 60 additions & 1 deletion tests/utils/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pyarrow as pa
from pytest import raises
from triad.utils.pyarrow import (
TRIAD_DEFAULT_TIMESTAMP,
SchemaedDataPartitioner,
_parse_type,
_type_to_expression,
Expand All @@ -14,8 +15,9 @@
schema_to_expression,
schemas_equal,
to_pa_datatype,
to_pandas_dtype,
to_single_pandas_dtype,
validate_column_name,
TRIAD_DEFAULT_TIMESTAMP
)


Expand Down Expand Up @@ -101,10 +103,67 @@ def test_to_pa_datatype():
assert pa.date32() == to_pa_datatype("date")
assert pa.binary() == to_pa_datatype("bytes")
assert pa.binary() == to_pa_datatype("binary")

assert pa.int8() == to_pa_datatype(pd.Int8Dtype)
assert pa.int8() == to_pa_datatype(pd.Int8Dtype())
assert pa.int16() == to_pa_datatype(pd.Int16Dtype)
assert pa.int16() == to_pa_datatype(pd.Int16Dtype())
assert pa.int32() == to_pa_datatype(pd.Int32Dtype)
assert pa.int32() == to_pa_datatype(pd.Int32Dtype())
assert pa.int64() == to_pa_datatype(pd.Int64Dtype)
assert pa.int64() == to_pa_datatype(pd.Int64Dtype())

assert pa.uint8() == to_pa_datatype(pd.UInt8Dtype)
assert pa.uint8() == to_pa_datatype(pd.UInt8Dtype())
assert pa.uint16() == to_pa_datatype(pd.UInt16Dtype)
assert pa.uint16() == to_pa_datatype(pd.UInt16Dtype())
assert pa.uint32() == to_pa_datatype(pd.UInt32Dtype)
assert pa.uint32() == to_pa_datatype(pd.UInt32Dtype())
assert pa.uint64() == to_pa_datatype(pd.UInt64Dtype)
assert pa.uint64() == to_pa_datatype(pd.UInt64Dtype())

assert pa.string() == to_pa_datatype(pd.StringDtype)
assert pa.string() == to_pa_datatype(pd.StringDtype())

assert pa.bool_() == to_pa_datatype(pd.BooleanDtype)
assert pa.bool_() == to_pa_datatype(pd.BooleanDtype())

raises(TypeError, lambda: to_pa_datatype(123))
raises(TypeError, lambda: to_pa_datatype(None))


def test_to_single_pandas_dtype():
    """to_single_pandas_dtype converts one pyarrow type to a pandas dtype."""
    # Without extension types, every type maps onto a plain numpy dtype.
    numpy_cases = [
        (pa.bool_(), np.bool_),
        (pa.int16(), np.int16),
        (pa.uint32(), np.uint32),
        (pa.float32(), np.float32),
        (pa.string(), np.dtype(str)),
        (pa.timestamp("ns"), np.dtype("<M8[ns]")),
    ]
    for pa_type, expected in numpy_cases:
        assert expected == to_single_pandas_dtype(pa_type, False)

    # With extension types, nullable pandas dtypes are returned where one
    # exists; float and timestamp still fall back to numpy dtypes.
    extension_cases = [
        (pa.bool_(), pd.BooleanDtype()),
        (pa.int16(), pd.Int16Dtype()),
        (pa.uint32(), pd.UInt32Dtype()),
        (pa.float32(), np.float32),
        (pa.string(), pd.StringDtype()),
        (pa.timestamp("ns"), np.dtype("<M8[ns]")),
    ]
    for pa_type, expected in extension_cases:
        assert expected == to_single_pandas_dtype(pa_type, True)


def test_to_pandas_dtype():
    """to_pandas_dtype builds a column-name -> dtype dict for a schema."""
    schema = expression_to_schema("a:bool,b:int,c:double,d:string,e:datetime")

    # Plain numpy dtypes when extension types are disabled.
    expected_numpy = {
        "a": np.bool_,
        "b": np.int32,
        "c": np.float64,
        "d": np.dtype("<U"),
        "e": np.dtype("<M8[ns]"),
    }
    res = to_pandas_dtype(schema, False)
    for col, dtype in expected_numpy.items():
        assert dtype == res[col]

    # Nullable pandas extension dtypes when enabled; float and datetime
    # columns keep their numpy dtypes.
    expected_extension = {
        "a": pd.BooleanDtype(),
        "b": pd.Int32Dtype(),
        "c": np.float64,
        "d": pd.StringDtype(),
        "e": np.dtype("<M8[ns]"),
    }
    res = to_pandas_dtype(schema, True)
    for col, dtype in expected_extension.items():
        assert dtype == res[col]


def test_is_supported():
assert is_supported(pa.int32())
assert is_supported(pa.decimal128(5, 2))
Expand Down
21 changes: 16 additions & 5 deletions triad/utils/pandas_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@

import numpy as np
import pandas as pd
import pyarrow as pa
from triad.utils.assertion import assert_or_throw
from triad.utils.pyarrow import TRIAD_DEFAULT_TIMESTAMP, apply_schema, to_pandas_dtype
from triad.utils.pyarrow import (
TRIAD_DEFAULT_TIMESTAMP,
apply_schema,
to_pandas_dtype,
to_single_pandas_dtype,
)

import pyarrow as pa

T = TypeVar("T", bound=Any)
_DEFAULT_JOIN_KEYS: List[str] = []
Expand Down Expand Up @@ -142,7 +148,7 @@ def enforce_type( # noqa: C901
if self.empty(df):
return df
if not null_safe:
return df.astype(dtype=to_pandas_dtype(schema))
return df.astype(dtype=to_pandas_dtype(schema, use_extension_types=False))
data: Dict[str, Any] = {}
for v in schema:
s = df[v.name]
Expand All @@ -161,9 +167,14 @@ def enforce_type( # noqa: C901
s = s.mask(ns, None)
elif pa.types.is_integer(v.type):
ns = s.isnull()
s = s.fillna(0).astype(v.type.to_pandas_dtype()).mask(ns, None)
s = (
s.fillna(0)
.astype(int)
.astype(to_single_pandas_dtype(v.type))
.mask(ns, None)
)
elif not pa.types.is_struct(v.type) and not pa.types.is_list(v.type):
s = s.astype(v.type.to_pandas_dtype())
s = s.astype(to_single_pandas_dtype(v.type))
data[v.name] = s
return pd.DataFrame(data)

Expand Down
70 changes: 67 additions & 3 deletions triad/utils/pyarrow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
import pickle
from datetime import date, datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type

import numpy as np
import pandas as pd
import pyarrow as pa
from pandas.core.dtypes.base import ExtensionDtype
from triad.utils.assertion import assert_or_throw
from triad.utils.convert import as_type
from triad.utils.iter import EmptyAwareIterable, Slicer
Expand Down Expand Up @@ -67,6 +68,31 @@
pa.binary(): "bytes",
}

# Maps pandas nullable extension dtype *classes* to their pyarrow equivalents.
# Covers the nullable integer, string and boolean dtypes only; float and
# timestamp types have no entry here.
_PANDAS_EXTENSION_TYPE_TO_PA_MAP: Dict[Type[ExtensionDtype], pa.DataType] = {
    pd.Int8Dtype: pa.int8(),
    pd.UInt8Dtype: pa.uint8(),
    pd.Int16Dtype: pa.int16(),
    pd.UInt16Dtype: pa.uint16(),
    pd.Int32Dtype: pa.int32(),
    pd.UInt32Dtype: pa.uint32(),
    pd.Int64Dtype: pa.int64(),
    pd.UInt64Dtype: pa.uint64(),
    pd.StringDtype: pa.string(),
    pd.BooleanDtype: pa.bool_(),
}

# Inverse mapping: pyarrow data types to *instances* of the corresponding
# pandas nullable extension dtypes. Types with no entry fall back to the
# default numpy-based conversion where this map is consulted.
_PA_TO_PANDAS_EXTENSION_TYPE_MAP: Dict[pa.DataType, ExtensionDtype] = {
    pa.int8(): pd.Int8Dtype(),
    pa.uint8(): pd.UInt8Dtype(),
    pa.int16(): pd.Int16Dtype(),
    pa.uint16(): pd.UInt16Dtype(),
    pa.int32(): pd.Int32Dtype(),
    pa.uint32(): pd.UInt32Dtype(),
    pa.int64(): pd.Int64Dtype(),
    pa.uint64(): pd.UInt64Dtype(),
    pa.string(): pd.StringDtype(),
    pa.bool_(): pd.BooleanDtype(),
}

_SPECIAL_TOKENS: Set[str] = {",", "{", "}", "[", "]", ":"}

Expand Down Expand Up @@ -123,7 +149,7 @@ def schema_to_expression(schema: pa.Schema) -> pa.Schema:
return ",".join(_field_to_expression(x) for x in list(schema))


def to_pa_datatype(obj: Any) -> pa.DataType:
def to_pa_datatype(obj: Any) -> pa.DataType: # noqa: C901
"""Convert an object to pyarrow DataType
:param obj: any object
Expand All @@ -142,17 +168,55 @@ def to_pa_datatype(obj: Any) -> pa.DataType:
return pa.string()
if isinstance(obj, str):
return _parse_type(obj)
if isinstance(obj, ExtensionDtype) or issubclass(obj, ExtensionDtype):
pt = obj if not isinstance(obj, ExtensionDtype) else type(obj)
if pt in _PANDAS_EXTENSION_TYPE_TO_PA_MAP:
return _PANDAS_EXTENSION_TYPE_TO_PA_MAP[pt]
if issubclass(obj, datetime):
return TRIAD_DEFAULT_TIMESTAMP
if issubclass(obj, date):
return pa.date32()
return pa.from_numpy_dtype(np.dtype(obj))


def to_pandas_dtype(schema: pa.Schema) -> Dict[str, np.dtype]:
def to_single_pandas_dtype(
    pa_type: pa.DataType, use_extension_types: bool = False
) -> Any:
    """Convert a single pyarrow data type to the corresponding pandas
    data type.

    Currently, struct type is not supported

    :param pa_type: the pyarrow data type to convert
    :param use_extension_types: whether to use pandas extension
      data types (e.g. ``pd.Int16Dtype()``) where one exists,
      defaults to False
    :return: the pandas data type — a numpy dtype/type, or a pandas
      ``ExtensionDtype`` instance when ``use_extension_types`` is True
      and the type has an extension equivalent
    """
    if use_extension_types:
        if pa_type in _PA_TO_PANDAS_EXTENSION_TYPE_MAP:
            return _PA_TO_PANDAS_EXTENSION_TYPE_MAP[pa_type]
        # Types without an extension equivalent (e.g. float, timestamp)
        # fall back to the default pyarrow -> numpy conversion.
        return pa_type.to_pandas_dtype()
    # String is special-cased to numpy's str dtype instead of pyarrow's
    # default conversion.
    return np.dtype(str) if pa.types.is_string(pa_type) else pa_type.to_pandas_dtype()


def to_pandas_dtype(
schema: pa.Schema, use_extension_types: bool = False
) -> Dict[str, np.dtype]:
"""convert as `dtype` dict for pandas dataframes.
Currently, struct type is not supported
:param schema: the pyarrow schema
:param use_extension_types: whether to use pandas extension
data types, defaults to False
:return: the pandas data type dictionary
"""
if use_extension_types:
return {
f.name: _PA_TO_PANDAS_EXTENSION_TYPE_MAP[f.type]
if f.type in _PA_TO_PANDAS_EXTENSION_TYPE_MAP
else f.type.to_pandas_dtype()
for f in schema
}
return {
f.name: np.dtype(str)
if pa.types.is_string(f.type)
Expand Down
2 changes: 1 addition & 1 deletion triad_version/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.5.5"
__version__ = "0.5.6"

0 comments on commit 6451880

Please sign in to comment.