diff --git a/learn2learn/vision/datasets/cu_birds200.py b/learn2learn/vision/datasets/cu_birds200.py index de7fab4b..b9acef30 100644 --- a/learn2learn/vision/datasets/cu_birds200.py +++ b/learn2learn/vision/datasets/cu_birds200.py @@ -5,11 +5,15 @@ import torch from PIL import Image -from learn2learn.data.utils import download_file_from_google_drive +from learn2learn.data.utils import ( + download_file_from_google_drive, + download_file, +) DATA_DIR = 'cubirds200' DATA_FILENAME = 'CUB_200_2011.tgz' ARCHIVE_ID = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45' +ZENODO_URL = 'https://zenodo.org/record/8000562/files/CUB_200_2011.tgz' SPLITS = { 'train': [ @@ -353,11 +357,18 @@ def download(self): os.makedirs(data_path, exist_ok=True) tar_path = os.path.join(data_path, DATA_FILENAME) print('Downloading CUBirds200 dataset. (1.1Gb)') - download_file_from_google_drive(ARCHIVE_ID, tar_path) - tar_file = tarfile.open(tar_path) - tar_file.extractall(data_path) - tar_file.close() - os.remove(tar_path) + try: + download_file(ZENODO_URL, tar_path) + tar_file = tarfile.open(tar_path) + tar_file.extractall(data_path) + tar_file.close() + os.remove(tar_path) + except Exception: + download_file_from_google_drive(ARCHIVE_ID, tar_path) + tar_file = tarfile.open(tar_path) + tar_file.extractall(data_path) + tar_file.close() + os.remove(tar_path) def load_data(self, mode='train'): classes = sum(SPLITS.values(), []) if mode == 'all' else SPLITS[mode] @@ -425,7 +436,7 @@ def __len__(self): if __name__ == '__main__': import torchvision as tv cub = CUBirds200( - root='~/data', + root='./zenodo_data', mode='train', download=True, include_imagenet_duplicates=False,