Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ clean:
@echo '=== Cleaning up ==='
find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
find . -type f -name "*.pyc" -delete 2>/dev/null || true
rm -rf $(PROJECT_NAME).egg-info .pytest_cache .coverage .$(LINTER)_cache site docs/site
rm -rf $(PROJECT_NAME).egg-info .pytest_cache .coverage .$(LINTER)_cache site docs/site .hyperbench_cache

destroy: clean
@echo '=== Destroying environment ==='
Expand Down
1 change: 1 addition & 0 deletions hyperbench/data/hif.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def load_by_name(
try:
path_prefix = f"datasets--HypernetworkRG--{dataset_name}"
shutil.rmtree(os.path.join(hf_cache_dir, path_prefix))
shutil.rmtree(os.path.join(hf_cache_dir, ".locks", path_prefix))
except Exception as e:
warnings.warn(
f"Failed to clean up Hugging Face Hub cache after downloading dataset {dataset_name!r}: {e!s}.",
Expand Down
32 changes: 32 additions & 0 deletions hyperbench/tests/data/hif_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import requests
import torch
import os

from unittest.mock import patch
from hyperbench.data import HIFLoader, HIFProcessor
Expand Down Expand Up @@ -654,6 +655,7 @@ def test_load_by_name_uses_hf_revision_when_github_download_fails(tmp_path, mock
assert result.num_nodes == 2
assert result.num_hyperedges == 1
assert not (tmp_path / "hf_cache" / "datasets--HypernetworkRG--algebra").exists()
assert not (tmp_path / "hf_cache" / ".locks" / "datasets--HypernetworkRG--algebra").exists()


def test_load_by_name_skips_cache_cleanup_when_hf_cache_dir_is_missing(tmp_path, mock_hypergraph):
Expand Down Expand Up @@ -692,6 +694,36 @@ def test_load_by_name_skips_cache_cleanup_when_hf_cache_dir_is_missing(tmp_path,
assert result.num_hyperedges == 1


def test_load_by_name_cleans_hf_cache_and_locks(tmp_path, mock_hypergraph):
    """Check that loading via the Hugging Face fallback removes the dataset cache and its lock directory."""
    # File handed back by the mocked Hugging Face download.
    archive = tmp_path / "algebra.json.zst"
    archive.write_bytes(b"mock_zst_content")
    hif_json = __hif_payload(mock_hypergraph)
    revision = "2bb641461e00c103fb5ef4fe6a30aad42500fc21"

    # Empty 404 reply: the warning asserted below shows the GitHub raw download is treated as failed.
    not_found = requests.Response()
    not_found._content = b""
    not_found.status_code = 404

    with (
        patch("hyperbench.data.hif.os.path.exists", return_value=False),
        patch("hyperbench.data.hif.os.path.isdir", return_value=True),
        patch("hyperbench.data.hif.requests.get", return_value=not_found),
        patch("hyperbench.data.hif.hf_hub_download", return_value=str(archive)),
        patch("hyperbench.data.hif.from_zst_file_to_json", return_value=hif_json),
        patch("hyperbench.data.hif.validate_hif_data", return_value=True),
        patch("hyperbench.data.hif.shutil.rmtree") as rmtree_mock,
        pytest.warns(UserWarning, match="GitHub raw download failed"),
    ):
        loaded = HIFLoader.load_by_name("algebra", hf_sha=revision, save_on_disk=False)

    # Both the dataset directory and its ".locks" sibling must have been passed to rmtree.
    dataset_dir = "datasets--HypernetworkRG--algebra"
    hf_cache = tmp_path / "hf_cache"
    expected_targets = (
        os.path.join(hf_cache, dataset_dir),
        os.path.join(hf_cache, ".locks", dataset_dir),
    )
    for target in expected_targets:
        rmtree_mock.assert_any_call(target)

    assert loaded.num_nodes == 2
    assert loaded.num_hyperedges == 1


def test_load_by_name_raises_when_hf_sha_is_missing_on_fallback():
response = requests.Response()
response.status_code = 404
Expand Down
31 changes: 29 additions & 2 deletions hyperbench/tests/utils/hif_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,17 @@ def test_validate_hif_json_opens_the_given_path(tmp_path):
mock_file.assert_called_once_with(str(path_valid), encoding="utf-8")


def test_validate_hif_json_with_url_success():
def test_validate_hif_json_with_url_success(tmp_path):
path_valid = f"{MOCK_BASE_PATH}/hif_compliant.json"
schema_path = tmp_path / "hif_schema.json"

with patch("hyperbench.utils.hif_utils.requests.get") as mock_get:
mock_files = MagicMock()
mock_files.joinpath.return_value = schema_path

with (
patch("hyperbench.utils.hif_utils.requests.get") as mock_get,
patch("hyperbench.utils.hif_utils.resources.files", return_value=mock_files),
):
mock_response = MagicMock()
mock_response.json.return_value = {"type": "object"} # Minimal valid schema
mock_get.return_value = mock_response
Expand All @@ -126,6 +133,8 @@ def test_validate_hif_json_with_url_success():
timeout=10,
)

assert schema_path.exists()


def test_validate_hif_json_with_url_timeout_fallback():
path_valid = f"{MOCK_BASE_PATH}/hif_compliant.json"
Expand Down Expand Up @@ -172,6 +181,24 @@ def test_validate_hif_json_with_url_request_exception_fallback():
mock_path.open.assert_called_once_with("r", encoding="utf-8")


def test_validate_hif_data_raises_when_schema_load_fails():
    """Expect a RuntimeError from validation when both the local schema lookup and the remote request fail."""
    # Local lookup: resources.files raises as soon as it is called.
    local_failure = patch(
        "hyperbench.utils.hif_utils.resources.files",
        side_effect=RuntimeError("no local schema"),
    )
    # Remote lookup: the HTTP request itself raises.
    remote_failure = patch(
        "hyperbench.utils.hif_utils.requests.get",
        side_effect=requests.RequestException("network down"),
    )
    expected_error = pytest.raises(
        RuntimeError,
        match="Failed to load HIF schema from both local file and remote URL",
    )
    sample = {"incidences": [{"edge": 1, "node": 2}]}

    with local_failure, remote_failure, expected_error:
        validate_hif_data(sample)


def test_get_datasets_shas_returns_shas_and_none_on_failure():
names = ["algebra", "missing-dataset"]

Expand Down
19 changes: 17 additions & 2 deletions hyperbench/utils/hif_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,27 @@ def get_gh_dataset_sha(dataset_name: str, owner: str, repository: str) -> str |

def __load_hif_schema() -> dict[str, Any]:
    """Load the HIF JSON schema, preferring the local copy over the network.

    The schema is first read from the ``hif_schema.json`` resource of the
    ``hyperbench.utils.schema`` package. If that fails for any reason, it is
    downloaded from the HIF-standard GitHub repository at the commit pinned by
    ``HIF_SCHEMA_COMMIT_SHA``, and a best-effort attempt is made to write the
    downloaded schema back to the resource location.

    Returns:
        The parsed schema.

    Raises:
        RuntimeError: If the local copy cannot be read and the remote
            download (or decoding of its JSON body) fails as well.
    """
    # Function-scope import so this block does not depend on the module's import list.
    import warnings

    url = f"https://raw.githubusercontent.com/HIF-org/HIF-standard/{HIF_SCHEMA_COMMIT_SHA}/schemas/hif_schema.json"

    try:
        with (
            resources.files("hyperbench.utils.schema")
            .joinpath("hif_schema.json")
            .open("r", encoding="utf-8") as f
        ):
            return json.load(f)
    except Exception:
        # Deliberately broad: any problem with the local copy (missing package,
        # missing file, malformed JSON, ...) should fall through to the remote
        # download instead of aborting.
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            schema = response.json()
        except (requests.RequestException, ValueError) as e:
            # requests.Timeout is a RequestException subclass, so it is covered.
            # ValueError covers JSON decoding errors that are not
            # RequestException subclasses on older requests releases.
            raise RuntimeError(
                "Failed to load HIF schema from both local file and remote URL. "
            ) from e

        # Writing the schema back is an optimisation only: a read-only install
        # or a missing schema package must not discard a schema that was
        # fetched successfully.
        try:
            schema_path = resources.files("hyperbench.utils.schema").joinpath("hif_schema.json")
            with (
                resources.as_file(schema_path) as path,
                path.open("w", encoding="utf-8") as f,
            ):
                json.dump(schema, f)
        except (OSError, ImportError) as cache_error:
            warnings.warn(
                f"Could not store the downloaded HIF schema locally: {cache_error!s}.",
                stacklevel=2,
            )

        return schema
Loading