Skip to content

Commit

Permalink
Merge pull request #1213: utils.write_json: Serialize Pandas Series
Browse files Browse the repository at this point in the history
  • Loading branch information
victorlin committed Jun 14, 2023
2 parents 7139595 + edb1e46 commit 9a0c04c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
5 changes: 5 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## __NEXT__

### Bug fixes

* utils: Serialize pandas Series in `write_json`. [#1213][] (@victorlin)

[#1213]: https://github.com/nextstrain/augur/pull/1213

## 22.0.2 (26 May 2023)

Expand Down
11 changes: 8 additions & 3 deletions augur/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,23 @@ def write_json(data, file_name, indent=(None if os.environ.get("AUGUR_MINIFY_JSO
data["generated_by"] = {"program": "augur", "version": get_augur_version()}
with open(file_name, 'w', encoding='utf-8') as handle:
sort_keys = False if isinstance(data, OrderedDict) else True
json.dump(data, handle, indent=indent, sort_keys=sort_keys, cls=NumpyJSONEncoder)
json.dump(data, handle, indent=indent, sort_keys=sort_keys, cls=AugurJSONEncoder)


class NumpyJSONEncoder(json.JSONEncoder):
"""A custom JSONEncoder subclass to serialize additional numpy data types."""
class AugurJSONEncoder(json.JSONEncoder):
    """
    A custom JSONEncoder subclass that serializes NumPy and pandas objects
    which the standard :mod:`json` encoder rejects.

    Supported conversions:

    * NumPy integer scalars -> ``int``
    * NumPy floating scalars -> ``float``
    * NumPy boolean scalars -> ``bool``
    * NumPy arrays -> ``list`` (via ``ndarray.tolist()``)
    * pandas Series -> ``list`` (via ``Series.tolist()``; the index is dropped)
    """
    def default(self, obj):
        """
        Return a JSON-serializable equivalent of *obj*.

        Called by the encoder only for objects it cannot serialize natively.

        Parameters
        ----------
        obj
            The object the base encoder could not serialize.

        Returns
        -------
        int, float, bool, or list
            A built-in Python value equivalent to *obj*.

        Raises
        ------
        TypeError
            Raised by ``json.JSONEncoder.default`` when *obj* is not one of
            the supported types.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of np.integer (nor of Python's bool), so
        # it needs its own branch; without it a bare NumPy boolean scalar
        # would fall through and raise TypeError.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, pd.Series):
            return obj.tolist()
        return super().default(obj)


Expand Down
11 changes: 7 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from pathlib import Path
from unittest.mock import patch
import pandas as pd

import pytest

Expand Down Expand Up @@ -93,18 +94,20 @@ def test_read_strains(self, tmpdir):
assert len(strains) == 3
assert "strain1" in strains

def test_write_json_numpy_types(self, tmpdir):
"""write_json should be able to serialize numpy data types."""
def test_write_json_data_types(self, tmpdir):
"""write_json should be able to serialize various data types."""
data = {
'int': np.int64(1),
'float': np.float64(2.0),
'array': np.array([3,4,5])
'array': np.array([3,4,5]),
'series': pd.Series([6,7,8])
}
file = Path(tmpdir) / Path("data.json")
utils.write_json(data, file, include_version=False)
with open(file) as f:
assert json.load(f) == {
'int': 1,
'float': 2.0,
'array': [3,4,5]
'array': [3,4,5],
'series': [6,7,8]
}

0 comments on commit 9a0c04c

Please sign in to comment.