Skip to content

Commit

Permalink
add download tests for CIFAR (pytorch#2747)
Browse files Browse the repository at this point in the history
* add download tests for CIFAR

* fix tests in case of bad request
  • Loading branch information
pmeier authored and bryant1410 committed Nov 22, 2020
1 parent 7552ac8 commit 130a6c8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests-schedule.yml
Expand Up @@ -35,7 +35,7 @@ jobs:
run: pip install pytest

- name: Run tests
run: pytest --durations=20 -ra test/test_datasets_download.py
run: pytest -ra -v test/test_datasets_download.py

- uses: JasonEtco/create-an-issue@v2.4.0
name: Create issue if download tests failed
Expand Down
58 changes: 34 additions & 24 deletions test/test_datasets_download.py
Expand Up @@ -4,6 +4,7 @@
import unittest.mock
from datetime import datetime
from os import path
from urllib.error import HTTPError
from urllib.parse import urlparse
from urllib.request import urlopen, Request

Expand Down Expand Up @@ -86,25 +87,26 @@ def retry(fn, times=1, wait=5.0):
)


@contextlib.contextmanager
def assert_server_response_ok():
    """Context manager that reports an ``HTTPError`` as a failed assertion.

    Code run inside the ``with`` body is executed unchanged. If it raises
    ``urllib.error.HTTPError``, an ``AssertionError`` carrying the status code
    and reason is raised instead, chained to the original error. Any other
    exception propagates untouched.
    """
    try:
        yield
    except HTTPError as http_error:
        message = f"The server returned {http_error.code}: {http_error.reason}."
        raise AssertionError(message) from http_error


def assert_url_is_accessible(url):
    """Assert that ``url`` answers an HTTP ``HEAD`` request without an HTTP error.

    Args:
        url: The URL to probe.

    Raises:
        AssertionError: Raised through ``assert_server_response_ok`` when the
            server answers with an HTTP error status.
    """
    # The HTTP verb has to be given through the ``method`` keyword of
    # ``Request``. Placing it in ``headers`` merely sends a header literally
    # named "method" and the request is still issued as a GET.
    request = Request(url, method="HEAD")
    with assert_server_response_ok():
        # Only the status matters; close the response right away so the
        # connection is not left open.
        urlopen(request).close()


def assert_file_downloads_correctly(url, md5, chunk_size=1024 * 1024):
    """Download ``url`` into a temporary directory and verify its MD5 checksum.

    Args:
        url: The URL of the file to download.
        md5: The expected MD5 checksum, forwarded to ``check_integrity``.
        chunk_size: Number of bytes read from the response per iteration.

    Raises:
        AssertionError: If the server answers with an HTTP error status (via
            ``assert_server_response_ok``) or if the checksum does not match.
    """
    with get_tmp_dir() as root:
        file = path.join(root, path.basename(url))
        with assert_server_response_ok():
            with urlopen(url) as response, open(file, "wb") as fh:
                # Stream the payload in chunks instead of buffering the whole
                # archive in memory; dataset archives can be very large.
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    fh.write(chunk)

        # Must run while the temporary directory still exists.
        assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"

Expand All @@ -125,6 +127,16 @@ def make_download_configs(urls_and_md5s, name=None):
]


def collect_download_configs(dataset_loader, name):
    """Run ``dataset_loader`` while logging download attempts and build configs.

    Args:
        dataset_loader: Zero-argument callable that triggers the downloads.
        name: Passed through to ``make_download_configs``.

    Returns:
        The result of ``make_download_configs`` for the logged URL/MD5 pairs.
    """
    # Any exception raised while loading is deliberately ignored; only the
    # download attempts recorded up to that point are of interest here.
    with contextlib.suppress(Exception):
        with log_download_attempts() as urls_and_md5s:
            dataset_loader()

    return make_download_configs(urls_and_md5s, name)


def places365():
with log_download_attempts(patch=False) as urls_and_md5s:
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
Expand All @@ -137,23 +149,19 @@ def places365():


def caltech101():
    """Return the download configs logged while instantiating ``Caltech101``."""

    def load():
        datasets.Caltech101(".", download=True)

    return collect_download_configs(load, "Caltech101")


def caltech256():
    """Return the download configs logged while instantiating ``Caltech256``."""

    def load():
        datasets.Caltech256(".", download=True)

    return collect_download_configs(load, "Caltech256")


def cifar10():
    """Return the download configs logged while instantiating ``CIFAR10``."""

    def load():
        datasets.CIFAR10(".", download=True)

    return collect_download_configs(load, "CIFAR10")


return make_download_configs(urls_and_md5s, "Caltech256")
def cifar100():
    """Return the download configs logged while instantiating ``CIFAR100``."""
    # Instantiate CIFAR100 itself: loading CIFAR10 here would only re-test the
    # CIFAR10 resources under a "CIFAR100" label.
    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), "CIFAR100")


def make_parametrize_kwargs(download_configs):
Expand All @@ -166,7 +174,9 @@ def make_parametrize_kwargs(download_configs):
return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)


@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            places365(),
            caltech101(),
            caltech256(),
            cifar10(),
            cifar100(),
        )
    )
)
def test_url_is_accessible(url, md5):
    """Check, with retries via ``retry``, that ``url`` is accessible."""
    # ``md5`` is supplied by the parametrization but is not used by this check.
    retry(lambda: assert_url_is_accessible(url))

Expand Down

0 comments on commit 130a6c8

Please sign in to comment.