From 0f6f62c612c142c2d3736998d9fe4b2a72a64f2a Mon Sep 17 00:00:00 2001
From: ardunn
Date: Fri, 19 Aug 2022 23:38:32 -0700
Subject: [PATCH] allow for automatic retries of dataset downloads, plus alter
 tests (attempt to fix hash problem that fails tests randomly)

---
 matminer/datasets/tests/test_utils.py |  8 +++-
 matminer/datasets/utils.py            | 54 ++++++++++++++++++++++-----
 2 files changed, 51 insertions(+), 11 deletions(-)

diff --git a/matminer/datasets/tests/test_utils.py b/matminer/datasets/tests/test_utils.py
index 490291d22..521205662 100644
--- a/matminer/datasets/tests/test_utils.py
+++ b/matminer/datasets/tests/test_utils.py
@@ -60,8 +60,12 @@ def test_validate_dataset(self):
             _validate_dataset(self._path, url=None, file_hash=self._hash, download_if_missing=True)
 
         with self.assertRaises(UserWarning):
-            _validate_dataset(self._path, self._url, file_hash="!@#$%^&*", download_if_missing=True)
-        os.remove(self._path)
+            _validate_dataset(self._path, self._url, file_hash="!@#$%^&*", download_if_missing=True, n_retries_allowed=0)
+        if os.path.exists(self._path):
+            os.remove(self._path)
+
+        with self.assertRaises(ValueError):
+            _validate_dataset(self._path, self._url, file_hash=self._hash, download_if_missing=True, n_retries_allowed=-1)
 
         _validate_dataset(self._path, self._url, self._hash, download_if_missing=True)
         self.assertTrue(os.path.exists(self._path))
diff --git a/matminer/datasets/utils.py b/matminer/datasets/utils.py
index 04dcf3e4f..7a8389915 100644
--- a/matminer/datasets/utils.py
+++ b/matminer/datasets/utils.py
@@ -1,6 +1,8 @@
 import hashlib
 import json
 import os
+import warnings
+import time
 
 import pandas as pd
 import requests
@@ -43,7 +45,13 @@ def _get_data_home(data_home=None):
     return data_home
 
 
-def _validate_dataset(data_path, url=None, file_hash=None, download_if_missing=True):
+def _validate_dataset(
+    data_path,
+    url=None,
+    file_hash=None,
+    download_if_missing=True,
+    n_retries_allowed=3
+):
     """
     Checks to see if a dataset is on the local machine, if not tries to
     download if download_if_missing is set to true,
@@ -62,9 +70,18 @@ def _validate_dataset(data_path, url=None, file_hash=None, download_if_missing=T
 
         download_if_missing (bool): whether or not to try downloading the
             dataset if it is not on local disk
 
+        n_retries_allowed (int): Number of retries to do before failing a dataset
+            download based on hash mismatch. Retries are spaced apart by
+            60s.
+
     Returns (None)
     """
+    DOWNLOAD_RETRY_WAIT = 60
+
+    if n_retries_allowed < 0:
+        raise ValueError("Number of retries for download cannot be less than 0.")
+
     do_download = False
     # If the file doesn't exist, download it
     if not os.path.exists(data_path):
@@ -81,15 +98,34 @@
             print(f"Making dataset storage folder at {data_home}", flush=True)
             os.makedirs(data_home)
 
-        _fetch_external_dataset(url, data_path)
+        do_download = True
 
-    # Check to see if file hash matches the expected value, if hash is provided
-    if file_hash is not None:
-        if file_hash != _get_file_sha256_hash(data_path):
-            raise UserWarning(
-                "Error, hash of downloaded file does not match that "
-                "included in metadata, the data may be corrupt or altered"
-            )
+    hash_mismatch_msg = "Error, hash of downloaded file does not match that " \
+                        "included in metadata, the data may be corrupt or altered"
+    if do_download:
+        n_retries = 0
+        while n_retries <= n_retries_allowed:
+            try:
+                _fetch_external_dataset(url, data_path)
+
+                # Check to see if file hash matches the expected value, if hash is provided
+                if file_hash is not None:
+                    if file_hash != _get_file_sha256_hash(data_path):
+                        raise UserWarning
+                break
+            except UserWarning:
+                warnings.warn(hash_mismatch_msg)
+                if n_retries < n_retries_allowed:
+                    warnings.warn(f"Waiting {DOWNLOAD_RETRY_WAIT}s and trying again...")
+                    time.sleep(DOWNLOAD_RETRY_WAIT)
+                else:
+                    raise UserWarning(
+                        f"File could not be downloaded to {data_path} after {n_retries_allowed} retries "
+                        f"due to repeated hash validation failures."
+                    )
+                if os.path.exists(data_path):
+                    os.remove(data_path)
+                n_retries += 1
 
 
 def _fetch_external_dataset(url, file_path):
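
A minimal usage sketch of the new retry behavior (the path, URL, and hash
below are hypothetical placeholders for illustration, not real matminer
dataset metadata):

    from matminer.datasets.utils import _validate_dataset

    # Hypothetical metadata; real datasets supply url/file_hash through
    # matminer's dataset metadata. With n_retries_allowed=3, a hash
    # mismatch warns, waits 60s, and re-downloads up to 3 times before
    # finally raising UserWarning.
    _validate_dataset(
        "/tmp/example_dataset.json.gz",
        url="https://example.com/example_dataset.json.gz",
        file_hash="<expected sha256 hex digest>",
        download_if_missing=True,
        n_retries_allowed=3,
    )

n_retries_allowed=0 fails on the first mismatch (as the updated test
asserts), and a negative value raises ValueError before any download is
attempted.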