Skip to content

Commit

Permalink
Merge pull request #200 from TomAugspurger/categorical-dtype-refactor
Browse files Browse the repository at this point in the history
COMPAT: for new pandas CategoricalDtype
  • Loading branch information
martindurant committed Aug 25, 2017
2 parents 923c202 + bb2d54f commit 99cd107
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 8 deletions.
4 changes: 3 additions & 1 deletion fastparquet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def read_col(column, schema_helper, infile, use_cat=False,
if grab_dict:
return dic
if use_cat:
catdef._categories = pd.Index(dic)
# fastpath skips the check the number of categories hasn't changed.
# In this case, they may change, if the default RangeIndex was used.
catdef._set_categories(pd.Index(dic), fastpath=True)
if np.iinfo(assign.dtype).max < len(dic):
raise RuntimeError('Assigned array dtype (%s) cannot accommodate '
'number of category labels (%i)' %
Expand Down
7 changes: 6 additions & 1 deletion fastparquet/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from pandas.core.frame import DataFrame
from pandas.core.index import RangeIndex, Index
from pandas.core.categorical import Categorical, CategoricalDtype
try:
from pandas.api.types import is_categorical_dtype
except ImportError:
# Pandas <= 0.18.1
from pandas.core.common import is_categorical_dtype
from .util import STR_TYPE


Expand Down Expand Up @@ -110,7 +115,7 @@ def empty(types, size, cats=None, cols=None, index_type=None, index_name=None):
inds = list(range(inds.start, inds.stop, inds.step))
for i, ind in enumerate(inds):
col = df.columns[ind]
if str(dtype) == 'category':
if is_categorical_dtype(dtype):
views[col] = block.values._codes
views[col+'-catdef'] = block.values
else:
Expand Down
17 changes: 15 additions & 2 deletions fastparquet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
import sys
import six
from thrift.protocol.TBase import TBase
try:
from pandas.api.types import is_categorical_dtype
except ImportError:
# Pandas <= 0.18.1
from pandas.core.common import is_categorical_dtype


PY2 = six.PY2
PY3 = six.PY3
Expand Down Expand Up @@ -260,7 +266,7 @@ def get_column_metadata(column, name):
inferred_dtype = infer_dtype(column)
dtype = column.dtype

if str(dtype) == 'category':
if is_categorical_dtype(dtype):
extra_metadata = {
'num_categories': len(column.cat.categories),
'ordered': column.cat.ordered,
Expand Down Expand Up @@ -289,6 +295,13 @@ def get_column_metadata(column, name):
'integer': str(dtype),
'floating': str(dtype),
}.get(inferred_dtype, inferred_dtype),
'numpy_type': str(dtype),
'numpy_type': get_numpy_type(dtype),
'metadata': extra_metadata,
}


def get_numpy_type(dtype):
if is_categorical_dtype(dtype):
return 'category'
else:
return str(dtype)
13 changes: 9 additions & 4 deletions fastparquet/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated as TCompactProtocol
from thrift.protocol.TProtocol import TProtocolException
try:
from pandas.api.types import is_categorical_dtype
except ImportError:
# Pandas <= 0.18.1
from pandas.core.common import is_categorical_dtype
from .thrift_structures import parquet_thrift
from .compression import compress_data, decompress_data
from .converted_types import tobson
Expand Down Expand Up @@ -454,14 +459,14 @@ def write_column(f, data, selement, compression=None):
encoding = "PLAIN"

if has_nulls:
if str(data.dtype) == 'category':
if is_categorical_dtype(data.dtype):
num_nulls = (data.cat.codes == -1).sum()
elif data.dtype.kind in ['i', 'b']:
num_nulls = 0
else:
num_nulls = len(data) - data.count()
definition_data, data = make_definitions(data, num_nulls == 0)
if data.dtype.kind == "O" and str(data.dtype) != 'category':
if data.dtype.kind == "O" and not is_categorical_dtype(data.dtype):
if selement.type == parquet_thrift.Type.INT64:
data = data.astype(int)
elif selement.type == parquet_thrift.Type.BOOLEAN:
Expand All @@ -478,7 +483,7 @@ def write_column(f, data, selement, compression=None):
diff = 0
max, min = None, None

if str(data.dtype) == 'category':
if is_categorical_dtype(data.dtype):
dph = parquet_thrift.DictionaryPageHeader(
num_values=len(data.cat.categories),
encoding=parquet_thrift.Encoding.PLAIN)
Expand Down Expand Up @@ -673,7 +678,7 @@ def make_metadata(data, has_nulls=True, ignore_columns=[], fixed_text=None,
oencoding = (object_encoding if isinstance(object_encoding, STR_TYPE)
else object_encoding.get(column, None))
fixed = None if fixed_text is None else fixed_text.get(column, None)
if str(data[column].dtype) == 'category':
if is_categorical_dtype(data[column].dtype):
se, type = find_type(data[column].cat.categories,
fixed_text=fixed, object_encoding=oencoding)
se.name = column
Expand Down

0 comments on commit 99cd107

Please sign in to comment.