Fix downloading compressed Figshare data (#21)
* fix load_train_test() for compressed Figshare data (closes #20)

* load_train_test() now only accepts answers 'y' or 'n' (as originally intended) (closes #17)

* add test covering load_train_test() with compressed JSON file from URL

* mv run-scripts.yml test-scripts.yml

* add slow-tests.yml for running slow tests only on PR merges (to save CI budget)
janosh committed Apr 30, 2023
1 parent 0da443f commit 5d7c620
Showing 4 changed files with 106 additions and 74 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/slow-tests.yml
@@ -0,0 +1,26 @@
name: Run slow tests after PR merge

on:
  pull_request:
    types: [closed]
    branches: [main]

jobs:
  slow-tests:
    if: github.event.pull_request.merged == true
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: 3.9

      - name: Install dependencies
        run: pip install -e .[test]

      - name: Run slow tests
        run: pytest -m slow
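The final step assumes a `slow` marker is registered with pytest, since unregistered marks trigger PytestUnknownMarkWarning. A minimal sketch of that registration in a conftest.py (hypothetical; the repo may instead declare it in pyproject.toml or setup.cfg):

```python
# conftest.py -- sketch, assuming the "slow" marker isn't registered elsewhere
def pytest_configure(config):
    # register the custom marker so `pytest -m slow` selects marked tests
    # without warnings; `pytest -m "not slow"` deselects them locally
    config.addinivalue_line("markers", "slow: slow tests, e.g. large downloads")
```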
4 changes: 2 additions & 2 deletions .github/workflows/test-scripts.yml (renamed from run-scripts.yml)
@@ -18,10 +18,10 @@ jobs:
           - scripts/analyze_element_errors.py
     steps:
       - name: Check out repository
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3

       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: 3.9
51 changes: 29 additions & 22 deletions matbench_discovery/data.py
@@ -28,6 +28,18 @@
 )


+def as_dict_handler(obj: Any) -> dict[str, Any] | None:
+    """Pass this to json.dump(default=) or to pandas.to_json(default_handler=) to
+    serialize Python classes with as_dict(). Warning: objects without an as_dict()
+    method are replaced with None in the serialized data.
+    """
+    try:
+        return obj.as_dict()  # all MSONable objects implement as_dict()
+    except AttributeError:
+        return None  # replace unhandled objects with None in serialized data
+        # removes e.g. non-serializable AseAtoms from M3GNet relaxation trajectories
+
+
 def load_train_test(
     data_names: str | Sequence[str],
     version: str = figshare_versions[-1],
@@ -62,8 +74,7 @@ def load_train_test(
         ValueError: On bad version number or bad data names.

     Returns:
-        pd.DataFrame: Single dataframe or dictionary of dfs if
-            multiple data were requested.
+        pd.DataFrame: Single dataframe or dictionary of dfs if multiple data requested.
     """
     if version not in figshare_versions:
         raise ValueError(f"Unexpected {version=}. Must be one of {figshare_versions}.")
@@ -85,16 +96,24 @@
         reader = pd.read_csv if file.endswith(csv_ext) else pd.read_json

         cache_path = f"{cache_dir}/{file}"
-        if os.path.isfile(cache_path):
+        if os.path.isfile(cache_path):  # load from disk cache
             print(f"Loading {key!r} from cached file at {cache_path!r}")
             df = reader(cache_path, **kwargs)
-        else:
+        else:  # download from Figshare URL
+            # manually set compression since pandas can't infer from URL
+            if file.endswith(".gz"):
+                kwargs.setdefault("compression", "gzip")
+            elif file.endswith(".bz2"):
+                kwargs.setdefault("compression", "bz2")
             url = file_urls[key]
             print(f"Downloading {key!r} from {url}")
             try:
-                df = reader(url)
+                df = reader(url, **kwargs)
             except urllib.error.HTTPError as exc:
                 raise ValueError(f"Bad {url=}") from exc
+            except Exception:
+                print(f"\n\nvariable dump:\n{file=},\n{url=},\n{reader=},\n{kwargs=}")
+                raise
         if cache_dir and not os.path.isfile(cache_path):
             os.makedirs(os.path.dirname(cache_path), exist_ok=True)
             if ".csv" in file:
@@ -143,7 +162,7 @@ def glob_to_df(
     Returns:
         pd.DataFrame: Combined dataframe.
     """
-    reader = reader or pd.read_csv if ".csv" in pattern else pd.read_json
+    reader = reader or (pd.read_csv if ".csv" in pattern.lower() else pd.read_json)

     # prefix pattern with ROOT if not absolute path
     files = glob(pattern)
@@ -194,9 +213,9 @@ def __getattribute__(self, key: str) -> str:
         val = super().__getattribute__(key)
         if key in self and not os.path.isfile(val):
             msg = f"Warning: {val!r} associated with {key=} does not exist."
-            try:
-                self._on_not_found(key, msg)  # type: ignore[misc]
-            except TypeError:
+            if self._on_not_found:
+                self._on_not_found(key, msg)
+            else:
                 print(msg, file=sys.stderr)
         return val

@@ -214,7 +233,7 @@ def _on_not_found(self, key: str, msg: str) -> None:  # type: ignore[override]

         # default to 'y' if not in interactive session, and user can't answer
         answer = "" if sys.stdin.isatty() else "y"
-        while answer not in ("y", "n", "\x1b", ""):
+        while answer not in ("y", "n"):
             answer = input(f"{msg} [y/n] ").lower().strip()
         if answer == "y":
             load_train_test(key)  # download and cache data file
@@ -239,17 +258,5 @@ def _on_not_found(self, key: str, msg: str) -> None:  # type: ignore[override]
 DATA_FILES = DataFiles()


-def as_dict_handler(obj: Any) -> dict[str, Any] | None:
-    """Pass this to json.dump(default=) or as pandas.to_json(default_handler=) to
-    serialize Python classes with as_dict(). Warning: Objects without a as_dict() method
-    are replaced with None in the serialized data.
-    """
-    try:
-        return obj.as_dict()  # all MSONable objects implement as_dict()
-    except AttributeError:
-        return None  # replace unhandled objects with None in serialized data
-        # removes e.g. non-serializable AseAtoms from M3GNet relaxation trajectories
-
-
 df_wbm = load_train_test("wbm_summary")
 df_wbm["material_id"] = df_wbm.index
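A usage sketch for the relocated as_dict_handler(), following its docstring (the Dummy class is hypothetical, standing in for any MSONable object such as a pymatgen ComputedStructureEntry):

```python
import pandas as pd

from matbench_discovery.data import as_dict_handler

class Dummy:
    """Stand-in for an MSONable class."""

    def as_dict(self) -> dict:
        return {"@class": type(self).__name__}

df_demo = pd.DataFrame({"entry": [Dummy(), object()]})
# Dummy serializes via as_dict(); the plain object() has no as_dict() and becomes null
json_str = df_demo.to_json(default_handler=as_dict_handler)
```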
99 changes: 49 additions & 50 deletions tests/test_data.py
@@ -107,8 +107,9 @@ def test_load_train_test_raises(tmp_path: Path) -> None:

     assert (
         str(exc_info.value)
-        == f"Unexpected version='invalid-version'. Must be one of {figshare_versions}."
+        == f"Unexpected {version=}. Must be one of {figshare_versions}."
     )
+    assert os.listdir(tmp_path) == [], "cache_dir should be empty"


def test_load_train_test_doc_str() -> None:
@@ -123,68 +124,66 @@ def test_load_train_test_doc_str() -> None:
     assert os.path.isdir(f"{ROOT}/site/src/routes/{route}")


-# TODO skip this test if offline
-# @pytest.mark.skipif(online, reason="requires internet connection")
-@pytest.mark.parametrize("version", [figshare_versions[-1]])
-def test_load_train_test_no_mock_mp_refs(
-    version: str, capsys: CaptureFixture[str], tmp_path: Path
-) -> None:
-    # this function runs the download from Figshare for real hence takes some time and
-    # requires being online
-    file_key = "mp_elemental_ref_entries"
-    df = load_train_test(file_key, version=version, cache_dir=tmp_path)
-    assert df.shape == (5, 89)
-
-    stdout, stderr = capsys.readouterr()
-    assert stderr == ""
-    rel_path = getattr(type(DATA_FILES), file_key)
-    cache_path = f"{tmp_path}/{rel_path}"
-    assert (
-        f"Downloading {file_key!r} from {figshare_urls[file_key]}\nCached "
-        f"{file_key!r} to {cache_path!r}" in stdout
-    )
-
-    # test that df loaded from cache is the same as initial df
-    pd.testing.assert_frame_equal(
-        df, load_train_test(file_key, version=version, cache_dir=tmp_path)
-    )
-    stdout, stderr = capsys.readouterr()
-    assert stderr == ""
-    assert stdout == f"Loading {file_key!r} from cached file at {cache_path!r}\n"
+wbm_summary_expected_cols = {
+    "bandgap_pbe",
+    "e_form_per_atom_mp2020_corrected",
+    "e_form_per_atom_uncorrected",
+    "e_form_per_atom_wbm",
+    "e_above_hull_wbm",
+    "formula",
+    "n_sites",
+    "uncorrected_energy",
+    "uncorrected_energy_from_cse",
+    "volume",
+    "wyckoff_spglib",
+}


 # TODO skip this test if offline
-@pytest.mark.slow
-@pytest.mark.parametrize("version", [figshare_versions[-1]])
-def test_load_train_test_no_mock_wbm_summary(
-    version: str, capsys: CaptureFixture[str], tmp_path: Path
+# @pytest.mark.skipif(online, reason="requires internet connection")
+@pytest.mark.parametrize(
+    "file_key, version, expected_shape, expected_cols",
+    [
+        ("mp_elemental_ref_entries", figshare_versions[-1], (5, 89), set()),
+        pytest.param(
+            "wbm_summary",
+            figshare_versions[-1],
+            (256963, 15),
+            wbm_summary_expected_cols,
+            marks=pytest.mark.slow,  # run pytest -m 'slow' to select this marker
+        ),
+        pytest.param(
+            # large file but needed to test loading compressed JSON from URL
+            "mp_computed_structure_entries",
+            figshare_versions[-1],
+            (154718, 1),
+            {"entry"},
+            marks=pytest.mark.slow,
+        ),
+    ],
+)
+def test_load_train_test_no_mock(
+    file_key: str,
+    version: str,
+    expected_shape: tuple[int, int],
+    expected_cols: set[str],
+    capsys: CaptureFixture[str],
+    tmp_path: Path,
 ) -> None:
-    # this function runs the download from Figshare for real hence takes some time and
+    assert os.listdir(tmp_path) == [], "cache_dir should be empty"
+    # This function runs the download from Figshare for real hence takes some time and
     # requires being online
-    file_key = "wbm_summary"
     df = load_train_test(file_key, version=version, cache_dir=tmp_path)
-    assert df.shape == (256963, 15)
-    expected_cols = {
-        "bandgap_pbe",
-        "e_form_per_atom_mp2020_corrected",
-        "e_form_per_atom_uncorrected",
-        "e_form_per_atom_wbm",
-        "e_above_hull_wbm",
-        "formula",
-        "n_sites",
-        "uncorrected_energy",
-        "uncorrected_energy_from_cse",
-        "volume",
-        "wyckoff_spglib",
-    }
+    assert len(os.listdir(tmp_path)) == 1, "cache_dir should have one file"
+    assert df.shape == expected_shape
+    assert (
+        set(df) >= expected_cols
+    ), f"Loaded df missing columns {expected_cols - set(df)}"

     stdout, stderr = capsys.readouterr()
     assert stderr == ""
     rel_path = getattr(type(DATA_FILES), file_key)
-    cache_path = f"{tmp_path}/{figshare_versions[-1]}/{rel_path}"
+    cache_path = f"{tmp_path}/{rel_path}"
     assert (
         f"Downloading {file_key!r} from {figshare_urls[file_key]}\nCached "
         f"{file_key!r} to {cache_path!r}" in stdout
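For readers unfamiliar with per-case marks, a standalone sketch of the pytest.param(marks=...) pattern the new test uses (toy values):

```python
import pytest

@pytest.mark.parametrize(
    "n, expected",
    [
        (2, 4),  # unmarked case
        # marked case: `pytest -m slow` runs only this one;
        # `pytest -m "not slow"` deselects it
        pytest.param(10_000, 100_000_000, marks=pytest.mark.slow),
    ],
)
def test_square(n: int, expected: int) -> None:
    assert n * n == expected
```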
