From 42c4481b059be3661243a2f26c325fe7ab87592d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 14 Nov 2023 22:08:04 +0000 Subject: [PATCH 1/3] imagenet pytorch donwload fixes --- datasets/dataset_setup.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index e7f8c1d13..8c696e060 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -424,6 +424,7 @@ def setup_fastmri(data_dir, src_data_dir): def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): + """Downloads and returns the download dir.""" imagenet_train_filepath = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) imagenet_val_filepath = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) @@ -506,7 +507,16 @@ def setup_imagenet_pytorch(data_dir): val_tar_file_path = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) test_dir_path = os.path.join(data_dir, 'imagenet_v2') - # Setup jax dataset dir + # Check if downloaded data has been moved + manual_download_dir = os.path.join(data_dir, 'jax', 'downloads', 'manual') + if not os.path.exists(train_tar_file_path): + if os.path.exists(os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)) + train_tar_file_path = os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME) + if not os.path.exists(val_tar_file_path): + if os.path.exists(os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)) + val_tar_file_path = os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME) + + # Setup pytorch dataset dir imagenet_pytorch_data_dir = os.path.join(data_dir, 'pytorch') os.makedirs(imagenet_pytorch_data_dir) os.makedirs(os.path.join(imagenet_pytorch_data_dir, 'train')) @@ -519,9 +529,9 @@ def setup_imagenet_pytorch(data_dir): logging.info('Moving {} to {}'.format(val_tar_file_path, imagenet_pytorch_data_dir)) shutil.move(val_tar_file_path, imagenet_pytorch_data_dir) - if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')): + if not os.path.exists(os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')): logging.info('Moving imagenet_v2 to {}'.format( - os.path.join(imagenet_jax_data_dir, 'imagenet_v2'))) + os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2'))) shutil.move(test_dir_path, os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')) From 7b2c8a77cfbb90264afe705ea61551a4e0465cb6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 14 Nov 2023 22:18:01 +0000 Subject: [PATCH 2/3] syntax fix --- datasets/dataset_setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 8c696e060..160088cb1 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -510,10 +510,10 @@ def setup_imagenet_pytorch(data_dir): # Check if downloaded data has been moved manual_download_dir = os.path.join(data_dir, 'jax', 'downloads', 'manual') if not os.path.exists(train_tar_file_path): - if os.path.exists(os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)) + if os.path.exists(os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)): train_tar_file_path = os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME) if not os.path.exists(val_tar_file_path): - if os.path.exists(os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)) + if os.path.exists(os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)): val_tar_file_path = os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME) # Setup pytorch dataset dir From 4c58be8e4bbbf22a9b754de90a078cb1c14f66d4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 15 Nov 2023 16:48:57 +0000 Subject: [PATCH 3/3] formatting --- datasets/dataset_setup.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 160088cb1..925d542cb 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -507,14 +507,18 @@ def setup_imagenet_pytorch(data_dir): val_tar_file_path = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) test_dir_path = os.path.join(data_dir, 'imagenet_v2') - # Check if downloaded data has been moved + # Check if downloaded data has been moved manual_download_dir = os.path.join(data_dir, 'jax', 'downloads', 'manual') if not os.path.exists(train_tar_file_path): - if os.path.exists(os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)): - train_tar_file_path = os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME) + if os.path.exists( + os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)): + train_tar_file_path = os.path.join(manual_download_dir, + IMAGENET_TRAIN_TAR_FILENAME) if not os.path.exists(val_tar_file_path): - if os.path.exists(os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)): - val_tar_file_path = os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME) + if os.path.exists( + os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)): + val_tar_file_path = os.path.join(manual_download_dir, + IMAGENET_VAL_TAR_FILENAME) # Setup pytorch dataset dir imagenet_pytorch_data_dir = os.path.join(data_dir, 'pytorch')