diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index e7f8c1d13..925d542cb 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -424,6 +424,7 @@ def setup_fastmri(data_dir, src_data_dir): def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): + """Downloads and returns the download dir.""" imagenet_train_filepath = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) imagenet_val_filepath = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) @@ -506,7 +507,20 @@ def setup_imagenet_pytorch(data_dir): val_tar_file_path = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) test_dir_path = os.path.join(data_dir, 'imagenet_v2') - # Setup jax dataset dir + # Check if downloaded data has been moved + manual_download_dir = os.path.join(data_dir, 'jax', 'downloads', 'manual') + if not os.path.exists(train_tar_file_path): + if os.path.exists( + os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)): + train_tar_file_path = os.path.join(manual_download_dir, + IMAGENET_TRAIN_TAR_FILENAME) + if not os.path.exists(val_tar_file_path): + if os.path.exists( + os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)): + val_tar_file_path = os.path.join(manual_download_dir, + IMAGENET_VAL_TAR_FILENAME) + + # Setup pytorch dataset dir imagenet_pytorch_data_dir = os.path.join(data_dir, 'pytorch') os.makedirs(imagenet_pytorch_data_dir) os.makedirs(os.path.join(imagenet_pytorch_data_dir, 'train')) @@ -519,9 +533,9 @@ def setup_imagenet_pytorch(data_dir): logging.info('Moving {} to {}'.format(val_tar_file_path, imagenet_pytorch_data_dir)) shutil.move(val_tar_file_path, imagenet_pytorch_data_dir) - if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')): + if not os.path.exists(os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')): logging.info('Moving imagenet_v2 to {}'.format( - os.path.join(imagenet_jax_data_dir, 'imagenet_v2'))) + os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2'))) shutil.move(test_dir_path, os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2'))