diff --git a/Makefile b/Makefile index 098ced9..4391481 100644 --- a/Makefile +++ b/Makefile @@ -106,7 +106,7 @@ clean: @echo '=== Cleaning up ===' find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true find . -type f -name "*.pyc" -delete 2>/dev/null || true - rm -rf $(PROJECT_NAME).egg-info .pytest_cache .coverage .$(LINTER)_cache site docs/site + rm -rf $(PROJECT_NAME).egg-info .pytest_cache .coverage .$(LINTER)_cache site docs/site .hyperbench_cache destroy: clean @echo '=== Destroying environment ===' diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index 60036f4..dd5d4e0 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -382,6 +382,7 @@ def load_by_name( try: path_prefix = f"datasets--HypernetworkRG--{dataset_name}" shutil.rmtree(os.path.join(hf_cache_dir, path_prefix)) + shutil.rmtree(os.path.join(hf_cache_dir, ".locks", path_prefix)) except Exception as e: warnings.warn( f"Failed to clean up Hugging Face Hub cache after downloading dataset {dataset_name!r}: {e!s}.", diff --git a/hyperbench/tests/data/hif_test.py b/hyperbench/tests/data/hif_test.py index 07f7769..7cacc3a 100644 --- a/hyperbench/tests/data/hif_test.py +++ b/hyperbench/tests/data/hif_test.py @@ -3,6 +3,7 @@ import pytest import requests import torch +import os from unittest.mock import patch from hyperbench.data import HIFLoader, HIFProcessor @@ -654,6 +655,7 @@ def test_load_by_name_uses_hf_revision_when_github_download_fails(tmp_path, mock assert result.num_nodes == 2 assert result.num_hyperedges == 1 assert not (tmp_path / "hf_cache" / "datasets--HypernetworkRG--algebra").exists() + assert not (tmp_path / "hf_cache" / ".locks" / "datasets--HypernetworkRG--algebra").exists() def test_load_by_name_skips_cache_cleanup_when_hf_cache_dir_is_missing(tmp_path, mock_hypergraph): @@ -692,6 +694,36 @@ def test_load_by_name_skips_cache_cleanup_when_hf_cache_dir_is_missing(tmp_path, assert result.num_hyperedges == 1 +def test_load_by_name_cleans_hf_cache_and_locks(tmp_path, mock_hypergraph): + hf_sha = "2bb641461e00c103fb5ef4fe6a30aad42500fc21" + fallback_file = tmp_path / "algebra.json.zst" + fallback_file.write_bytes(b"mock_zst_content") + payload = __hif_payload(mock_hypergraph) + + response = requests.Response() + response.status_code = 404 + response._content = b"" + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.os.path.isdir", return_value=True), + patch("hyperbench.data.hif.requests.get", return_value=response), + patch("hyperbench.data.hif.hf_hub_download", return_value=str(fallback_file)), + patch("hyperbench.data.hif.from_zst_file_to_json", return_value=payload), + patch("hyperbench.data.hif.validate_hif_data", return_value=True), + patch("hyperbench.data.hif.shutil.rmtree") as mock_rmtree, + pytest.warns(UserWarning, match="GitHub raw download failed"), + ): + result = HIFLoader.load_by_name("algebra", hf_sha=hf_sha, save_on_disk=False) + + cache_root = tmp_path / "hf_cache" + path_prefix = "datasets--HypernetworkRG--algebra" + mock_rmtree.assert_any_call(os.path.join(cache_root, path_prefix)) + mock_rmtree.assert_any_call(os.path.join(cache_root, ".locks", path_prefix)) + assert result.num_nodes == 2 + assert result.num_hyperedges == 1 + + def test_load_by_name_raises_when_hf_sha_is_missing_on_fallback(): response = requests.Response() response.status_code = 404 diff --git a/hyperbench/tests/utils/hif_utils_test.py b/hyperbench/tests/utils/hif_utils_test.py index 89ad135..4ff9cb0 100644 --- a/hyperbench/tests/utils/hif_utils_test.py +++ b/hyperbench/tests/utils/hif_utils_test.py @@ -112,10 +112,17 @@ def test_validate_hif_json_opens_the_given_path(tmp_path): mock_file.assert_called_once_with(str(path_valid), encoding="utf-8") -def test_validate_hif_json_with_url_success(): +def test_validate_hif_json_with_url_success(tmp_path): path_valid = f"{MOCK_BASE_PATH}/hif_compliant.json" + schema_path = tmp_path / "hif_schema.json" - with patch("hyperbench.utils.hif_utils.requests.get") as mock_get: + mock_files = MagicMock() + mock_files.joinpath.return_value = schema_path + + with ( + patch("hyperbench.utils.hif_utils.requests.get") as mock_get, + patch("hyperbench.utils.hif_utils.resources.files", return_value=mock_files), + ): mock_response = MagicMock() mock_response.json.return_value = {"type": "object"} # Minimal valid schema mock_get.return_value = mock_response @@ -126,6 +133,8 @@ def test_validate_hif_json_with_url_success(): timeout=10, ) + assert schema_path.exists() + def test_validate_hif_json_with_url_timeout_fallback(): path_valid = f"{MOCK_BASE_PATH}/hif_compliant.json" @@ -172,6 +181,24 @@ def test_validate_hif_json_with_url_request_exception_fallback(): mock_path.open.assert_called_once_with("r", encoding="utf-8") +def test_validate_hif_data_raises_when_schema_load_fails(): + with ( + patch( + "hyperbench.utils.hif_utils.resources.files", + side_effect=RuntimeError("no local schema"), + ), + patch( + "hyperbench.utils.hif_utils.requests.get", + side_effect=requests.RequestException("network down"), + ), + pytest.raises( + RuntimeError, + match="Failed to load HIF schema from both local file and remote URL", + ), + ): + validate_hif_data({"incidences": [{"edge": 1, "node": 2}]}) + + def test_get_datasets_shas_returns_shas_and_none_on_failure(): names = ["algebra", "missing-dataset"] diff --git a/hyperbench/utils/hif_utils.py b/hyperbench/utils/hif_utils.py index 69b37ed..88402c2 100644 --- a/hyperbench/utils/hif_utils.py +++ b/hyperbench/utils/hif_utils.py @@ -112,12 +112,27 @@ def get_gh_dataset_sha(dataset_name: str, owner: str, repository: str) -> str | def __load_hif_schema() -> dict[str, Any]: url = f"https://raw.githubusercontent.com/HIF-org/HIF-standard/{HIF_SCHEMA_COMMIT_SHA}/schemas/hif_schema.json" + try: - return requests.get(url, timeout=10).json() - except (requests.RequestException, requests.Timeout): with ( resources.files("hyperbench.utils.schema") .joinpath("hif_schema.json") .open("r", encoding="utf-8") as f ): return json.load(f) + except Exception: + try: + response = requests.get(url, timeout=10) + response.raise_for_status() + schema = response.json() + schema_path = resources.files("hyperbench.utils.schema").joinpath("hif_schema.json") + with ( + resources.as_file(schema_path) as path, + path.open("w", encoding="utf-8") as f, + ): + json.dump(schema, f) + return schema + except (requests.RequestException, requests.Timeout) as e: + raise RuntimeError( + "Failed to load HIF schema from both local file and remote URL. " + ) from e