Skip to content

Commit

Permalink
fix to pa type bug (#78)
Browse files Browse the repository at this point in the history
* prepare for pandas extension data types

* update

* update to_pa_datatype
  • Loading branch information
goodwanghan committed Dec 1, 2021
1 parent 6451880 commit ce7031b
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 43 deletions.
4 changes: 4 additions & 0 deletions README.md
Expand Up @@ -20,6 +20,10 @@ pip install triad

## Release History

### 0.5.7

* Fix pandas extension data types bug

### 0.5.6

* Prepare to support [pandas extension data types](https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes)
Expand Down
36 changes: 13 additions & 23 deletions tests/utils/test_pyarrow.py
Expand Up @@ -104,29 +104,19 @@ def test_to_pa_datatype():
assert pa.binary() == to_pa_datatype("bytes")
assert pa.binary() == to_pa_datatype("binary")

assert pa.int8() == to_pa_datatype(pd.Int8Dtype)
assert pa.int8() == to_pa_datatype(pd.Int8Dtype())
assert pa.int16() == to_pa_datatype(pd.Int16Dtype)
assert pa.int16() == to_pa_datatype(pd.Int16Dtype())
assert pa.int32() == to_pa_datatype(pd.Int32Dtype)
assert pa.int32() == to_pa_datatype(pd.Int32Dtype())
assert pa.int64() == to_pa_datatype(pd.Int64Dtype)
assert pa.int64() == to_pa_datatype(pd.Int64Dtype())

assert pa.uint8() == to_pa_datatype(pd.UInt8Dtype)
assert pa.uint8() == to_pa_datatype(pd.UInt8Dtype())
assert pa.uint16() == to_pa_datatype(pd.UInt16Dtype)
assert pa.uint16() == to_pa_datatype(pd.UInt16Dtype())
assert pa.uint32() == to_pa_datatype(pd.UInt32Dtype)
assert pa.uint32() == to_pa_datatype(pd.UInt32Dtype())
assert pa.uint64() == to_pa_datatype(pd.UInt64Dtype)
assert pa.uint64() == to_pa_datatype(pd.UInt64Dtype())

assert pa.string() == to_pa_datatype(pd.StringDtype)
assert pa.string() == to_pa_datatype(pd.StringDtype())

assert pa.bool_() == to_pa_datatype(pd.BooleanDtype)
assert pa.bool_() == to_pa_datatype(pd.BooleanDtype())
assert pa.int8() == to_pa_datatype(pd.Series([1]).astype("Int8").dtype)
assert pa.int16() == to_pa_datatype(pd.Series([1]).astype("Int16").dtype)
assert pa.int32() == to_pa_datatype(pd.Series([1]).astype("Int32").dtype)
assert pa.int64() == to_pa_datatype(pd.Series([1]).astype("Int64").dtype)

assert pa.uint8() == to_pa_datatype(pd.Series([1]).astype("UInt8").dtype)
assert pa.uint16() == to_pa_datatype(pd.Series([1]).astype("UInt16").dtype)
assert pa.uint32() == to_pa_datatype(pd.Series([1]).astype("UInt32").dtype)
assert pa.uint64() == to_pa_datatype(pd.Series([1]).astype("UInt64").dtype)

assert pa.string() == to_pa_datatype(pd.Series(["x"]).astype("string").dtype)

assert pa.bool_() == to_pa_datatype(pd.Series([True]).astype("boolean").dtype)

raises(TypeError, lambda: to_pa_datatype(123))
raises(TypeError, lambda: to_pa_datatype(None))
Expand Down
40 changes: 21 additions & 19 deletions triad/utils/pyarrow.py
@@ -1,18 +1,19 @@
import json
import pickle
from datetime import date, datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
import pyarrow as pa
from pandas.core.dtypes.base import ExtensionDtype
from triad.utils.assertion import assert_or_throw
from triad.utils.convert import as_type
from triad.utils.iter import EmptyAwareIterable, Slicer
from triad.utils.json import loads_no_dup
from triad.utils.string import validate_triad_var_name

import pyarrow as pa

TRIAD_DEFAULT_TIMESTAMP = pa.timestamp("us")

_TYPE_EXPRESSION_MAPPING: Dict[str, pa.DataType] = {
Expand Down Expand Up @@ -68,17 +69,17 @@
pa.binary(): "bytes",
}

_PANDAS_EXTENSION_TYPE_TO_PA_MAP: Dict[Type[ExtensionDtype], pa.DataType] = {
pd.Int8Dtype: pa.int8(),
pd.UInt8Dtype: pa.uint8(),
pd.Int16Dtype: pa.int16(),
pd.UInt16Dtype: pa.uint16(),
pd.Int32Dtype: pa.int32(),
pd.UInt32Dtype: pa.uint32(),
pd.Int64Dtype: pa.int64(),
pd.UInt64Dtype: pa.uint64(),
pd.StringDtype: pa.string(),
pd.BooleanDtype: pa.bool_(),
_PANDAS_EXTENSION_TYPE_TO_PA_MAP: Dict[ExtensionDtype, pa.DataType] = {
pd.Int8Dtype(): pa.int8(),
pd.UInt8Dtype(): pa.uint8(),
pd.Int16Dtype(): pa.int16(),
pd.UInt16Dtype(): pa.uint16(),
pd.Int32Dtype(): pa.int32(),
pd.UInt32Dtype(): pa.uint32(),
pd.Int64Dtype(): pa.int64(),
pd.UInt64Dtype(): pa.uint64(),
pd.StringDtype(): pa.string(),
pd.BooleanDtype(): pa.bool_(),
}

_PA_TO_PANDAS_EXTENSION_TYPE_MAP: Dict[pa.DataType, ExtensionDtype] = {
Expand Down Expand Up @@ -156,6 +157,8 @@ def to_pa_datatype(obj: Any) -> pa.DataType: # noqa: C901
:raises TypeError: if unable to convert
:return: an instance of pa.DataType
"""
if obj is None:
raise TypeError("obj can't be None")
if isinstance(obj, pa.DataType):
return obj
if obj is bool:
Expand All @@ -168,13 +171,12 @@ def to_pa_datatype(obj: Any) -> pa.DataType: # noqa: C901
return pa.string()
if isinstance(obj, str):
return _parse_type(obj)
if isinstance(obj, ExtensionDtype) or issubclass(obj, ExtensionDtype):
pt = obj if not isinstance(obj, ExtensionDtype) else type(obj)
if pt in _PANDAS_EXTENSION_TYPE_TO_PA_MAP:
return _PANDAS_EXTENSION_TYPE_TO_PA_MAP[pt]
if issubclass(obj, datetime):
if isinstance(obj, ExtensionDtype):
if obj in _PANDAS_EXTENSION_TYPE_TO_PA_MAP:
return _PANDAS_EXTENSION_TYPE_TO_PA_MAP[obj]
if type(obj) == type and issubclass(obj, datetime):
return TRIAD_DEFAULT_TIMESTAMP
if issubclass(obj, date):
if type(obj) == type and issubclass(obj, date):
return pa.date32()
return pa.from_numpy_dtype(np.dtype(obj))

Expand Down
2 changes: 1 addition & 1 deletion triad_version/__init__.py
@@ -1 +1 @@
__version__ = "0.5.6"
__version__ = "0.5.7"

0 comments on commit ce7031b

Please sign in to comment.