Skip to content

Commit

Permalink
Replace GDrive link by Zenodo.
Browse files Browse the repository at this point in the history
  • Loading branch information
seba-1511 committed Jun 3, 2023
1 parent 5fe670f commit 2ac9d32
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions learn2learn/vision/datasets/cu_birds200.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
import torch

from PIL import Image
from learn2learn.data.utils import download_file_from_google_drive
from learn2learn.data.utils import (
download_file_from_google_drive,
download_file,
)

DATA_DIR = 'cubirds200'
DATA_FILENAME = 'CUB_200_2011.tgz'
ARCHIVE_ID = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
ZENODO_URL = 'https://zenodo.org/record/8000562/files/CUB_200_2011.tgz'

SPLITS = {
'train': [
Expand Down Expand Up @@ -353,11 +357,18 @@ def download(self):
os.makedirs(data_path, exist_ok=True)
tar_path = os.path.join(data_path, DATA_FILENAME)
print('Downloading CUBirds200 dataset. (1.1Gb)')
download_file_from_google_drive(ARCHIVE_ID, tar_path)
tar_file = tarfile.open(tar_path)
tar_file.extractall(data_path)
tar_file.close()
os.remove(tar_path)
try:
download_file(ZENODO_URL, tar_path)
tar_file = tarfile.open(tar_path)
tar_file.extractall(data_path)
tar_file.close()
os.remove(tar_path)
except Exception:
download_file_from_google_drive(ARCHIVE_ID, tar_path)
tar_file = tarfile.open(tar_path)
tar_file.extractall(data_path)
tar_file.close()
os.remove(tar_path)

def load_data(self, mode='train'):
classes = sum(SPLITS.values(), []) if mode == 'all' else SPLITS[mode]
Expand Down Expand Up @@ -425,7 +436,7 @@ def __len__(self):
if __name__ == '__main__':
import torchvision as tv
cub = CUBirds200(
root='~/data',
root='./zenodo_data',
mode='train',
download=True,
include_imagenet_duplicates=False,
Expand Down

0 comments on commit 2ac9d32

Please sign in to comment.