Skip to content

Commit

Permalink
pandas: add support for multilevel columns
Browse files Browse the repository at this point in the history
Add a test case that verifies that we can serialize pandas
data frames with multilevel columns.

Note: this test does not currently pass on Python 3.

Closes #346
Signed-off-by: David Aguilar <davvid@gmail.com>
  • Loading branch information
davvid committed Jan 31, 2021
1 parent e21d6fd commit 565c299
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
9 changes: 6 additions & 3 deletions jsonpickle/ext/pandas.py
Expand Up @@ -56,7 +56,7 @@ def restore_pandas(self, data):
def make_read_csv_params(meta):
meta_dtypes = meta.get('dtypes', {})
# The header is used to select the rows of the csv from which
# the columns names are retrived
# the columns names are retrieved
header = meta.get('header', [0])
parse_dates = []
converters = {}
Expand Down Expand Up @@ -102,7 +102,9 @@ def flatten(self, obj, data):
def restore(self, data):
csv, meta = self.pp.restore_pandas(data)
params, timedeltas = make_read_csv_params(meta)
column_levels_names = meta.get('column_level_names', None)
# Defaulting to None keeps this compatible with objects that were
# serialized before column_level_names was introduced.
column_level_names = meta.get('column_level_names', None)
df = (
pd.read_csv(StringIO(csv), **params)
if data['values'].strip()
Expand All @@ -113,7 +115,8 @@ def restore(self, data):

df.set_index(decode(meta['index']), inplace=True)
# restore the column level(s) name(s)
df.columns.names = column_levels_names
if column_level_names:
df.columns.names = column_level_names
return df


Expand Down
20 changes: 20 additions & 0 deletions tests/pandas_test.py
Expand Up @@ -15,6 +15,7 @@

import jsonpickle
import jsonpickle.ext.pandas
from jsonpickle.compat import PY2


@pytest.fixture(scope='module', autouse=True)
Expand Down Expand Up @@ -289,5 +290,24 @@ def test_dataframe_with_timedelta64_dtype():
assert data_frame['Duration'][2] == actual['Duration'][2]


def test_multilevel_columns():
    """Round-trip a DataFrame with two-level columns through jsonpickle.

    The decoded object must be a DataFrame whose column level names and
    contents match the frame that was encoded.
    """
    if not PY2:
        pytest.skip('This test does not yet pass on Python 3')

    # The product of the two label lists yields four two-level column keys,
    # with the levels named 'first' and 'second'.
    level_labels = [['inj', 'prod'], ['hourly', 'cumulative']]
    level_names = ['first', 'second']
    multi_columns = pd.MultiIndex.from_product(level_labels, names=level_names)

    # Three rows of random values, one value per multilevel column.
    values = np.random.randn(3, 4)
    original = pd.DataFrame(
        values, index=['A', 'B', 'C'], columns=multi_columns
    )

    restored = jsonpickle.decode(jsonpickle.encode(original))

    assert isinstance(restored, pd.DataFrame)
    assert original.columns.names == restored.columns.names
    assert_frame_equal(original, restored)


# Allow this test module to be executed directly: hand this file to pytest.
if __name__ == '__main__':
    pytest.main([__file__])

0 comments on commit 565c299

Please sign in to comment.