Skip to content

Commit

Permalink
allow for automatic retries of dataset downloads, plus alter tests (a…
Browse files Browse the repository at this point in the history
…ttempt to fix hash problem that fails tests randomly)
  • Loading branch information
ardunn committed Aug 20, 2022
1 parent 4a86c0a commit 0f6f62c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 11 deletions.
8 changes: 6 additions & 2 deletions matminer/datasets/tests/test_utils.py
Expand Up @@ -60,8 +60,12 @@ def test_validate_dataset(self):
_validate_dataset(self._path, url=None, file_hash=self._hash, download_if_missing=True)

with self.assertRaises(UserWarning):
_validate_dataset(self._path, self._url, file_hash="!@#$%^&*", download_if_missing=True)
os.remove(self._path)
_validate_dataset(self._path, self._url, file_hash="!@#$%^&*", download_if_missing=True, n_retries_allowed=0)
if os.path.exists(self._path):
os.remove(self._path)

with self.assertRaises(ValueError):
_validate_dataset(self._path, self._url, file_hash=self._hash, download_if_missing=True, n_retries_allowed=-1)

_validate_dataset(self._path, self._url, self._hash, download_if_missing=True)
self.assertTrue(os.path.exists(self._path))
Expand Down
54 changes: 45 additions & 9 deletions matminer/datasets/utils.py
@@ -1,6 +1,8 @@
import hashlib
import json
import os
import warnings
import time

import pandas as pd
import requests
Expand Down Expand Up @@ -43,7 +45,13 @@ def _get_data_home(data_home=None):
return data_home


def _validate_dataset(data_path, url=None, file_hash=None, download_if_missing=True):
def _validate_dataset(
data_path,
url=None,
file_hash=None,
download_if_missing=True,
n_retries_allowed=3
):
"""
Checks to see if a dataset is on the local machine,
if not tries to download if download_if_missing is set to true,
Expand All @@ -62,9 +70,18 @@ def _validate_dataset(data_path, url=None, file_hash=None, download_if_missing=T
download_if_missing (bool): whether or not to try downloading the
dataset if it is not on local disk
n_retries_allowed (int): Number of retries to do before failing a dataset
download based on hash mismatch. Retries are spaced apart by
60s.
Returns (None)
"""
DOWNLOAD_RETRY_WAIT = 60

if n_retries_allowed < 0:
raise ValueError("Number of retries for download cannot be less than 0.")

do_download = False
# If the file doesn't exist, download it
if not os.path.exists(data_path):

Expand All @@ -81,15 +98,34 @@ def _validate_dataset(data_path, url=None, file_hash=None, download_if_missing=T
print(f"Making dataset storage folder at {data_home}", flush=True)
os.makedirs(data_home)

_fetch_external_dataset(url, data_path)
do_download = True

# Check to see if file hash matches the expected value, if hash is provided
if file_hash is not None:
if file_hash != _get_file_sha256_hash(data_path):
raise UserWarning(
"Error, hash of downloaded file does not match that "
"included in metadata, the data may be corrupt or altered"
)
hash_mismatch_msg = "Error, hash of downloaded file does not match that " \
"included in metadata, the data may be corrupt or altered"
if do_download:
n_retries = 0
while n_retries <= n_retries_allowed:
try:
_fetch_external_dataset(url, data_path)

# Check to see if file hash matches the expected value, if hash is provided
if file_hash is not None:
if file_hash != _get_file_sha256_hash(data_path):
raise UserWarning
break
except UserWarning:
warnings.warn(hash_mismatch_msg)
if n_retries < n_retries_allowed:
warnings.warn(f"Waiting {DOWNLOAD_RETRY_WAIT}s and trying again...")
time.sleep(DOWNLOAD_RETRY_WAIT)
else:
raise UserWarning(
f"File could not be downloaded to {data_path} after {n_retries_allowed} retries"
f"due to repeated hash validation failures."
)
if os.path.exists(data_path):
os.remove(data_path)
n_retries += 1


def _fetch_external_dataset(url, file_path):
Expand Down

0 comments on commit 0f6f62c

Please sign in to comment.