Grab dict (#28)
* plain stats encoding

* Short-cut function to grab dictionary values

Needed by dask-dataframe if we are to have any categorical columns

* Remember to convert dict, if necessary

* Add cats test

* Rename parameter
martindurant committed Nov 17, 2016
1 parent 9726080 commit be061d1
Showing 3 changed files with 54 additions and 6 deletions.
34 changes: 32 additions & 2 deletions fastparquet/api.py
@@ -112,6 +112,38 @@ def read_row_group_file(self, rg, columns, categories):
         with self.open(ofname, 'rb') as f:
             return self.read_row_group(rg, columns, categories, infile=f)
 
+    def grab_cats(self, columns, row_group_index=0):
+        """ Read dictionaries of first row_group
+
+        Used to correctly create metadata for categorical dask dataframes.
+        Could be used to check that the same dictionary is used throughout
+        the data.
+
+        Parameters
+        ----------
+        columns: list
+            Column names to load
+        row_group_index: int (0)
+            Row group to load from
+
+        Returns
+        -------
+        {column: [list, of, values]}
+        """
+        rg = self.row_groups[row_group_index]
+        ofname = self.sep.join([os.path.dirname(self.fn),
+                                rg.columns[0].file_path])
+        out = {}
+
+        with self.open(ofname, 'rb') as f:
+            for column in rg.columns:
+                name = ".".join(column.meta_data.path_in_schema)
+                if name not in columns:
+                    continue
+                out[name] = core.read_col(column, self.helper, f,
+                                          grab_dict=True)
+        return out
+
     def read_row_group(self, rg, columns, categories, infile=None):
         """
         Access row-group in a file and read some columns into a data-frame.
@@ -120,8 +152,6 @@ def read_row_group(self, rg, columns, categories, infile=None):
 
         for column in rg.columns:
             name = ".".join(column.meta_data.path_in_schema)
-            se = self.helper.schema_element(name)
-            use = name in categories if categories is not None else False
             if name not in columns:
                 continue
 
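For context, a usage sketch of the new method, mirroring the test added in this commit (the '/tmp/cats.parq' path is illustrative). grab_cats reads only the dictionary pages of the first row-group, so a caller such as dask can obtain category values without decoding any row data:

import pandas as pd
import fastparquet

# Write a small dataset with a categorical column as a multi-file
# ("hive") dataset, as the new test does.
s = pd.Series(['a', 'c', 'b'] * 20)
df = pd.DataFrame({'plain': s, 'cat': s.astype('category')})
fastparquet.write('/tmp/cats.parq', df, file_scheme='hive')

pf = fastparquet.ParquetFile('/tmp/cats.parq')
# Only the dictionary page of each requested column is decoded; for a
# pandas categorical column this holds the categories, so the result
# should match df['cat'].cat.categories.
cats = pf.grab_cats(['cat'])
print(cats['cat'])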
15 changes: 11 additions & 4 deletions fastparquet/core.py
@@ -136,7 +136,8 @@ def read_dictionary_page(file_obj, schema_helper, page_header, column_metadata):
     return values
 
 
-def read_col(column, schema_helper, infile, use_cat=False):
+def read_col(column, schema_helper, infile, use_cat=False,
+             grab_dict=False):
     """Using the given metadata, read one column in one row-group.
 
     Parameters
@@ -150,10 +151,12 @@ def read_col(column, schema_helper, infile, use_cat=False):
     use_cat: bool (False)
         If this column is encoded throughout with dict encoding, give back
         a pandas categorical column; otherwise, decode to values
+    grab_dict: bool (False)
+        Short-cut mode to return the dictionary values only - skips the actual
+        data.
     """
     cmd = column.meta_data
-    name = ".".join(cmd.path_in_schema)
-    rows = cmd.num_values
+    se = schema_helper.schema_element(cmd.path_in_schema[-1])
     off = min((cmd.dictionary_page_offset or cmd.data_page_offset,
                cmd.data_page_offset))
 
@@ -164,6 +167,11 @@
     if ph.type == parquet_thrift.PageType.DICTIONARY_PAGE:
         dic = np.array(read_dictionary_page(infile, schema_helper, ph, cmd))
         ph = read_thrift(infile, parquet_thrift.PageHeader)
+    if grab_dict:
+        return convert(pd.Series(dic), se)
+
+    name = ".".join(cmd.path_in_schema)
+    rows = cmd.num_values
 
     out = []
     num = 0
@@ -229,7 +237,6 @@ def read_col(column, schema_helper, infile, use_cat=False):
             final[start:start+l] = val
             start += l
 
-    se = schema_helper.schema_element(cmd.path_in_schema[-1])
     if all_dict:
         if se.converted_type is not None:
             dic = convert(pd.Series(dic), se)
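To make the short-circuit in read_col concrete, a toy sketch of the control flow (plain Python tuples stand in for the real Thrift page headers and encoded pages; this is not the fastparquet implementation). The dictionary page, when present, precedes the data pages, so grab_dict can return as soon as it is decoded:

def read_column(pages, grab_dict=False):
    """pages: list of (kind, values) pairs, kind being 'dict' or 'data'.
    A 'dict' page, if present, comes first; 'data' pages hold indices into it.
    """
    dic = None
    if pages and pages[0][0] == 'dict':
        dic = pages[0][1]              # decode the dictionary page
        pages = pages[1:]
    if grab_dict:                      # short-cut: skip all data pages
        return dic
    out = []
    for kind, values in pages:         # normal path: decode every data page
        if dic is not None:
            out.extend(dic[i] for i in values)   # indices -> dictionary values
        else:
            out.extend(values)                   # plain-encoded values
    return out

pages = [('dict', ['a', 'b', 'c']), ('data', [0, 2, 2, 1])]
print(read_column(pages, grab_dict=True))   # ['a', 'b', 'c']
print(read_column(pages))                   # ['a', 'c', 'c', 'b']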
11 changes: 11 additions & 0 deletions fastparquet/test/test_read.py
@@ -270,3 +270,14 @@ def test_statistics(tempdir):
     assert stat['min']['b'] == [None]
     assert stat['max']['c'] == [b'c']
     assert stat['min']['c'] == [b'a']
+
+
+def test_grab_cats(tempdir):
+    s = pd.Series(['a', 'c', 'b']*20)
+    df = pd.DataFrame({'a': s, 'b': s.astype('category'),
+                       'c': s.astype('category').cat.as_ordered()})
+    fastparquet.write(tempdir, df, file_scheme='hive')
+    pf = fastparquet.ParquetFile(tempdir)
+    cats = pf.grab_cats(['b', 'c'])
+    assert (cats['b'] == df.b.cat.categories).all()
+    assert (cats['c'] == df.c.cat.categories).all()
