Skip to content

Commit

Permalink
Preserve column order when writing dataframes.
Browse files Browse the repository at this point in the history
  • Loading branch information
mtth committed Feb 4, 2017
1 parent 314d9c2 commit 99e1a85
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 6 deletions.
2 changes: 1 addition & 1 deletion hdfs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging as lg


__version__ = '2.0.15'
__version__ = '2.0.16'
__license__ = 'MIT'


Expand Down
14 changes: 11 additions & 3 deletions hdfs/ext/avro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ class AvroReader(object):

def __init__(self, client, hdfs_path, parts=None):
self.content = client.content(hdfs_path) #: Content summary of Avro file.
self.metadata = None #: Avro header metadata.
self._schema = None
if self.content['directoryCount']:
# This is a folder.
Expand Down Expand Up @@ -196,12 +197,12 @@ def _reader():
if not self._schema:
schema = reader.schema
_logger.debug('Read schema from %r.', path)
yield schema
yield (schema, reader.metadata)
for record in reader:
yield record

self._records = _reader()
self._schema = next(self._records) # Prime generator to get schema.
self._schema, self.metadata = next(self._records)
return self

def __exit__(self, exc_type, exc_value, traceback):
Expand Down Expand Up @@ -239,6 +240,8 @@ class AvroWriter(object):
:param sync_interval: Number of bytes after which a block will be written.
:param sync_marker: 16 byte tag used for synchronization. If not specified,
one will be generated at random.
:param metadata: Additional metadata to include in the container file's
header. Keys starting with `'avro.'` are reserved.
:param \*\*kwargs: Keyword arguments forwarded to
:meth:`hdfs.client.Client.write`.
Expand All @@ -253,13 +256,14 @@ class AvroWriter(object):
"""

def __init__(self, client, hdfs_path, schema=None, codec=None,
sync_interval=None, sync_marker=None, **kwargs):
sync_interval=None, sync_marker=None, metadata=None, **kwargs):
self._hdfs_path = hdfs_path
self._fo = client.write(hdfs_path, **kwargs)
self._schema = schema
self._codec = codec or 'null'
self._sync_interval = sync_interval or 1000 * fastavro._writer.SYNC_SIZE
self._sync_marker = sync_marker or os.urandom(fastavro._writer.SYNC_SIZE)
self._metadata = metadata
self._writer = None
_logger.info('Instantiated %r.', self)

Expand Down Expand Up @@ -323,6 +327,10 @@ def dump_header():
'avro.codec': self._codec,
'avro.schema': dumps(self._schema),
}
if self._metadata:
for key, value in self._metadata.items():
# Don't overwrite the codec or schema.
metadata.setdefault(key, value)
fastavro._writer.write_header(fo, metadata, self._sync_marker)
_logger.debug('Wrote header. Sync marker: %r', self._sync_marker)
fastavro._writer.acquaint_schema(self._schema)
Expand Down
10 changes: 8 additions & 2 deletions hdfs/ext/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""

from .avro import AvroReader, AvroWriter
import json
import pandas as pd


Expand All @@ -25,7 +26,11 @@ def read_dataframe(client, hdfs_path):
"""
with AvroReader(client, hdfs_path) as reader:
# Hack-ish, but loading all elements in memory first to get length.
return pd.DataFrame.from_records(list(reader))
if 'pandas.columns' in reader.metadata:
columns = json.loads(reader.metadata['pandas.columns'])
else:
columns = None
return pd.DataFrame.from_records(list(reader), columns=columns)


def write_dataframe(client, hdfs_path, df, **kwargs):
Expand All @@ -38,6 +43,7 @@ def write_dataframe(client, hdfs_path, df, **kwargs):
:class:`hdfs.ext.avro.AvroWriter`.
"""
with AvroWriter(client, hdfs_path, **kwargs) as writer:
metadata = {'pandas.columns': json.dumps(df.columns.tolist())}
with AvroWriter(client, hdfs_path, metadata=metadata, **kwargs) as writer:
for _, row in df.iterrows():
writer.write(row.to_dict())
12 changes: 12 additions & 0 deletions test/test_ext_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,15 @@ def test_write(self):
write_dataframe(self.client, 'weather.avro', self.df)
with AvroReader(self.client, 'weather.avro') as reader:
eq_(list(reader), self.records)


class TestReadWriteDataFrame(_DataFrameIntegrationTest):

  def test_column_order(self):
    """Round-trip a dataframe whose columns are deliberately reordered.

    Avro records are dicts, so without the ``'pandas.columns'`` metadata
    written by ``write_dataframe`` the reader would presumably recover the
    columns in key order rather than the caller's order; this verifies the
    original (non-alphabetical) ordering survives a write/read cycle.
    """
    reordered = self.df[['temp', 'station', 'time']]
    write_dataframe(self.client, 'weather-ordered.avro', reordered)
    round_tripped = read_dataframe(self.client, 'weather-ordered.avro')
    assert_frame_equal(round_tripped, reordered)

0 comments on commit 99e1a85

Please sign in to comment.