diff --git a/.github/workflows/slow-tests.yml b/.github/workflows/slow-tests.yml
new file mode 100644
index 00000000..77fd3b3c
--- /dev/null
+++ b/.github/workflows/slow-tests.yml
@@ -0,0 +1,26 @@
+name: Run slow tests after PR merge
+
+on:
+  pull_request:
+    types: [closed]
+    branches: [main]
+
+jobs:
+  slow-tests:
+    if: github.event.pull_request.merged == true
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.9
+
+      - name: Install dependencies
+        run: pip install -e .[test]
+
+      - name: Run slow tests
+        run: pytest -m slow
diff --git a/.github/workflows/run-scripts.yml b/.github/workflows/test-scripts.yml
similarity index 89%
rename from .github/workflows/run-scripts.yml
rename to .github/workflows/test-scripts.yml
index f51e2c8d..64b713b5 100644
--- a/.github/workflows/run-scripts.yml
+++ b/.github/workflows/test-scripts.yml
@@ -18,10 +18,10 @@ jobs:
           - scripts/analyze_element_errors.py
     steps:
       - name: Check out repository
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
 
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
          python-version: 3.9
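Note: `pytest -m slow` in the new workflow only selects tests carrying the `slow` marker, which pytest expects to be registered somewhere. A minimal sketch of one way to register it — the repo may already declare it in `pyproject.toml` under `[tool.pytest.ini_options]` or in `setup.cfg`, neither of which appears in this diff:

```python
# conftest.py -- hypothetical sketch; skip if the marker is already registered
import pytest


def pytest_configure(config: pytest.Config) -> None:
    # register the marker so `pytest -m slow` selects it without warnings
    config.addinivalue_line("markers", "slow: long-running tests, e.g. real downloads")
```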
""" - reader = reader or pd.read_csv if ".csv" in pattern else pd.read_json + reader = reader or pd.read_csv if ".csv" in pattern.lower() else pd.read_json # prefix pattern with ROOT if not absolute path files = glob(pattern) @@ -194,9 +213,9 @@ def __getattribute__(self, key: str) -> str: val = super().__getattribute__(key) if key in self and not os.path.isfile(val): msg = f"Warning: {val!r} associated with {key=} does not exist." - try: - self._on_not_found(key, msg) # type: ignore[misc] - except TypeError: + if self._on_not_found: + self._on_not_found(key, msg) + else: print(msg, file=sys.stderr) return val @@ -214,7 +233,7 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override] # default to 'y' if not in interactive session, and user can't answer answer = "" if sys.stdin.isatty() else "y" - while answer not in ("y", "n", "\x1b", ""): + while answer not in ("y", "n"): answer = input(f"{msg} [y/n] ").lower().strip() if answer == "y": load_train_test(key) # download and cache data file @@ -239,17 +258,5 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override] DATA_FILES = DataFiles() -def as_dict_handler(obj: Any) -> dict[str, Any] | None: - """Pass this to json.dump(default=) or as pandas.to_json(default_handler=) to - serialize Python classes with as_dict(). Warning: Objects without a as_dict() method - are replaced with None in the serialized data. - """ - try: - return obj.as_dict() # all MSONable objects implement as_dict() - except AttributeError: - return None # replace unhandled objects with None in serialized data - # removes e.g. non-serializable AseAtoms from M3GNet relaxation trajectories - - df_wbm = load_train_test("wbm_summary") df_wbm["material_id"] = df_wbm.index diff --git a/tests/test_data.py b/tests/test_data.py index 4fb9ce1d..410ce36b 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -107,8 +107,9 @@ def test_load_train_test_raises(tmp_path: Path) -> None: assert ( str(exc_info.value) - == f"Unexpected version='invalid-version'. Must be one of {figshare_versions}." + == f"Unexpected {version=}. Must be one of {figshare_versions}." 
diff --git a/tests/test_data.py b/tests/test_data.py
index 4fb9ce1d..410ce36b 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -107,8 +107,9 @@ def test_load_train_test_raises(tmp_path: Path) -> None:
 
     assert (
         str(exc_info.value)
-        == f"Unexpected version='invalid-version'. Must be one of {figshare_versions}."
+        == f"Unexpected {version=}. Must be one of {figshare_versions}."
     )
+    assert os.listdir(tmp_path) == [], "cache_dir should be empty"
 
 
 def test_load_train_test_doc_str() -> None:
@@ -123,60 +124,58 @@
         assert os.path.isdir(f"{ROOT}/site/src/routes/{route}")
 
 
-# TODO skip this test if offline
-# @pytest.mark.skipif(online, reason="requires internet connection")
-@pytest.mark.parametrize("version", [figshare_versions[-1]])
-def test_load_train_test_no_mock_mp_refs(
-    version: str, capsys: CaptureFixture[str], tmp_path: Path
-) -> None:
-    # this function runs the download from Figshare for real hence takes some time and
-    # requires being online
-    file_key = "mp_elemental_ref_entries"
-    df = load_train_test(file_key, version=version, cache_dir=tmp_path)
-    assert df.shape == (5, 89)
-
-    stdout, stderr = capsys.readouterr()
-    assert stderr == ""
-    rel_path = getattr(type(DATA_FILES), file_key)
-    cache_path = f"{tmp_path}/{rel_path}"
-    assert (
-        f"Downloading {file_key!r} from {figshare_urls[file_key]}\nCached "
-        f"{file_key!r} to {cache_path!r}" in stdout
-    )
-
-    # test that df loaded from cache is the same as initial df
-    pd.testing.assert_frame_equal(
-        df, load_train_test(file_key, version=version, cache_dir=tmp_path)
-    )
-    stdout, stderr = capsys.readouterr()
-    assert stderr == ""
-    assert stdout == f"Loading {file_key!r} from cached file at {cache_path!r}\n"
+wbm_summary_expected_cols = {
+    "bandgap_pbe",
+    "e_form_per_atom_mp2020_corrected",
+    "e_form_per_atom_uncorrected",
+    "e_form_per_atom_wbm",
+    "e_above_hull_wbm",
+    "formula",
+    "n_sites",
+    "uncorrected_energy",
+    "uncorrected_energy_from_cse",
+    "volume",
+    "wyckoff_spglib",
+}
 
 
 # TODO skip this test if offline
-@pytest.mark.slow
-@pytest.mark.parametrize("version", [figshare_versions[-1]])
-def test_load_train_test_no_mock_wbm_summary(
-    version: str, capsys: CaptureFixture[str], tmp_path: Path
+# @pytest.mark.skipif(online, reason="requires internet connection")
+@pytest.mark.parametrize(
+    "file_key, version, expected_shape, expected_cols",
+    [
+        ("mp_elemental_ref_entries", figshare_versions[-1], (5, 89), set()),
+        pytest.param(
+            "wbm_summary",
+            figshare_versions[-1],
+            (256963, 15),
+            wbm_summary_expected_cols,
+            marks=pytest.mark.slow,  # run pytest -m 'slow' to select this marker
+        ),
+        pytest.param(
+            # large file but needed to test loading compressed JSON from URL
+            "mp_computed_structure_entries",
+            figshare_versions[-1],
+            (154718, 1),
+            {"entry"},
+            marks=pytest.mark.slow,
+        ),
+    ],
+)
+def test_load_train_test_no_mock(
+    file_key: str,
+    version: str,
+    expected_shape: tuple[int, int],
+    expected_cols: set[str],
+    capsys: CaptureFixture[str],
+    tmp_path: Path,
 ) -> None:
-    # this function runs the download from Figshare for real hence takes some time and
+    assert os.listdir(tmp_path) == [], "cache_dir should be empty"
+    # This function runs the download from Figshare for real, hence takes some time and
     # requires being online
-    file_key = "wbm_summary"
     df = load_train_test(file_key, version=version, cache_dir=tmp_path)
-    assert df.shape == (256963, 15)
-    expected_cols = {
-        "bandgap_pbe",
-        "e_form_per_atom_mp2020_corrected",
-        "e_form_per_atom_uncorrected",
-        "e_form_per_atom_wbm",
-        "e_above_hull_wbm",
-        "formula",
-        "n_sites",
-        "uncorrected_energy",
-        "uncorrected_energy_from_cse",
-        "volume",
-        "wyckoff_spglib",
-    }
+    assert len(os.listdir(tmp_path)) == 1, "cache_dir should have one file"
+    assert df.shape == expected_shape
     assert (
         set(df) >= expected_cols
     ), f"Loaded df missing columns {expected_cols - set(df)}"
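Note on the pattern used above: `pytest.param(..., marks=pytest.mark.slow)` attaches the marker to a single parametrize case, so `pytest -m slow` (as invoked by the new workflow) picks up only the big-download cases while the cheap `mp_elemental_ref_entries` case stays in the regular suite. A self-contained sketch of the same mechanism with hypothetical values:

```python
import pytest


@pytest.mark.parametrize(
    "n_rows",
    [
        10,  # cheap case: runs in the default suite, excluded by `pytest -m slow`
        pytest.param(10**6, marks=pytest.mark.slow),  # only under `pytest -m slow`
    ],
)
def test_row_count(n_rows: int) -> None:
    assert n_rows > 0
```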
@@ -184,7 +183,7 @@ def test_load_train_test_no_mock_wbm_summary(
     stdout, stderr = capsys.readouterr()
     assert stderr == ""
     rel_path = getattr(type(DATA_FILES), file_key)
-    cache_path = f"{tmp_path}/{figshare_versions[-1]}/{rel_path}"
+    cache_path = f"{tmp_path}/{rel_path}"
     assert (
         f"Downloading {file_key!r} from {figshare_urls[file_key]}\nCached "
         f"{file_key!r} to {cache_path!r}" in stdout
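One more note, on the data.py change further up: the explicit compression handling in `load_train_test` exists because, per the comment in the diff, pandas can't infer compression when reading from a URL the way it does from a local path suffix. A hedged sketch of the behavior being worked around, with a hypothetical URL:

```python
import pandas as pd

# hypothetical Figshare-style download URL with no .gz suffix to infer from
url = "https://figshare.com/ndownloader/files/00000000"
# the default compression="infer" keys off the path suffix, so an extension-less
# URL serving gzipped JSON needs the kwarg set explicitly -- which is what
# kwargs.setdefault("compression", "gzip") accomplishes in load_train_test
df = pd.read_json(url, compression="gzip")
```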