diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 925d542cb..2f808b64b 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -386,11 +386,11 @@ def download_fastmri(data_dir, return data_dir -def extract(source, dest): +def extract(source, dest, mode='r:xz'): if not os.path.exists(dest): os.path.makedirs(dest) logging.info(f'Extracting {source} to {dest}') - tar = tarfile.open(source, 'r:xz') + tar = tarfile.open(source, mode) logging.info('Opened tar') tar.extractall(dest) @@ -543,7 +543,8 @@ def setup_imagenet_pytorch(data_dir): logging.info('Extracting imagenet train data') extract( os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME), - os.path.join(imagenet_pytorch_data_dir, 'train')) + os.path.join(imagenet_pytorch_data_dir, 'train'), + mode='r:') train_tar_filenames = os.listdir( os.path.join(imagenet_pytorch_data_dir, 'train')) @@ -552,13 +553,15 @@ def setup_imagenet_pytorch(data_dir): dir_name = tar_filename[:-4] extract( os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME), - os.path.join(imagenet_pytorch_data_dir, 'train', dir_name)) + os.path.join(imagenet_pytorch_data_dir, 'train', dir_name), + mode='r:') # Extract val data logging.info('Extracting imagenet val data') extract( os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME), - os.path.join(imagenet_pytorch_data_dir, 'val')) + os.path.join(imagenet_pytorch_data_dir, 'val'), + mode='r:') valprep_command = [ 'wget',