Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
115 lines (84 sloc) 3.79 KB
#!/usr/bin/env python3
from __future__ import print_function

import os
import pickle

import numpy as np
import torch
import torch.utils.data as data

# NOTE(review): module path reconstructed from a garbled import; verify it matches the package layout.
from learn2learn.data.utils import download_file_from_google_drive
def download_pkl(google_drive_id, data_root, mode):
    """Download the pickled mini-ImageNet cache for one split unless it is already present.

    The file is stored at ``<data_root>/mini-imagenet-cache-<mode>.pkl``.

    Args:
        google_drive_id (str): Google Drive file id of the pickle to fetch.
        data_root (str): Directory in which the pickle is stored.
        mode (str): Split name; only used to build the file name.
    """
    file_path = os.path.join(data_root, 'mini-imagenet-cache-' + mode + '.pkl')
    if os.path.exists(file_path):
        # Nothing to do; report it so the caller knows no network access happened.
        print("Data was already downloaded")
        return
    print('Downloading:', file_path)
    download_file_from_google_drive(google_drive_id, file_path)
    print("Download finished")
def index_classes(items):
    """Assign consecutive integer ids to distinct items in first-seen order.

    Args:
        items: Iterable of hashable labels; it is consumed exactly once.

    Returns:
        dict: ``{item: id}`` where ids 0, 1, 2, ... follow the order in which
        each distinct item first appears; repeated items keep their first id.
    """
    # dict.fromkeys de-duplicates while preserving first-occurrence order.
    distinct_in_order = dict.fromkeys(items)
    return {label: position for position, label in enumerate(distinct_in_order)}
class MiniImagenet(data.Dataset):
    """
    The *mini*-ImageNet dataset was originally introduced by Vinyals et al., 2016.

    It consists of 60,000 colour images of size 84x84 pixels.
    The dataset is divided into 3 splits of 64 training, 16 validation, and 20 testing classes,
    each class containing 600 examples.
    The classes are sampled from the ImageNet dataset, and we use the splits from Ravi & Larochelle, 2017.

    **References**

    1. Vinyals et al. 2016. “Matching Networks for One Shot Learning.” NeurIPS.
    2. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR.

    **Arguments**

    * **root** (str) - Path to download the data. A leading ``~`` is expanded and
        the directory is created if it does not exist.
    * **mode** (str, *optional*, default='train') - Which split to use.
        Must be 'train', 'validation', or 'test'.
    * **transform** (Transform, *optional*, default=None) - Input pre-processing,
        applied to each image in ``__getitem__``.
    * **target_transform** (Transform, *optional*, default=None) - Target pre-processing,
        applied to each label in ``__getitem__``.

    **Raises**

    * **ValueError** - If ``mode`` is not one of the supported splits.

    **Example**

    ~~~python
    train_dataset = MiniImagenet(root='./data', mode='train')
    image, label = train_dataset[0]
    # For few-shot learning, wrap the dataset in a task sampler that draws `ways`-way tasks.
    ~~~
    """

    # Google Drive file ids of the pickled cache for each supported split.
    _GOOGLE_DRIVE_IDS = {
        'test': '1wpmY-hmiJUUlRBkO9ZDCXAcIpHEFdOhD',
        'train': '1I3itTXpXxGV68olxM5roceUMG8itH9Xj',
        'validation': '1KY5e491bkLFqJDp0-UWou3463Mo8AOco',
    }

    def __init__(self, root, mode='train', transform=None, target_transform=None):
        super().__init__()
        # Validate the split first so a bad `mode` fails before touching the filesystem.
        if mode not in self._GOOGLE_DRIVE_IDS:
            raise ValueError('Needs to be train, test or validation')
        self.root = os.path.expanduser(root)
        os.makedirs(self.root, exist_ok=True)
        self.transform = transform
        self.target_transform = target_transform
        self.mode = mode

        if not self._check_exists():
            download_pkl(self._GOOGLE_DRIVE_IDS[mode], self.root, mode)

        # NOTE(review): pickle.load can execute arbitrary code; the cache comes from a
        # remote download, so only load files from a trusted source.
        with open(self._pickle_path(), 'rb') as f:
            self.data = pickle.load(f)

        # Images are assumed stored channels-last; permute(0, 3, 1, 2) moves the last
        # axis to dim 1 (channels-first) and .float() converts to float32.
        self.x = torch.from_numpy(self.data['image_data']).permute(0, 3, 1, 2).float()
        # Float label array; any image not listed in class_dict keeps the value 1.0.
        self.y = np.ones(len(self.x))

        # class_dict maps a class name to the indices of its images; class ids follow
        # the dict's key order (keys are unique, so enumerate is a first-seen index).
        class_dict = self.data['class_dict']
        self.class_idx = {name: i for i, name in enumerate(class_dict)}
        for class_name, idxs in class_dict.items():
            for idx in idxs:
                self.y[idx] = self.class_idx[class_name]

    def __getitem__(self, idx):
        """Return ``(image, label)`` at ``idx`` with the optional transforms applied."""
        image = self.x[idx]
        if self.transform:
            image = self.transform(image)
        label = self.y[idx]
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.x)

    def _pickle_path(self):
        """Return the path of the cached pickle for the current split."""
        return os.path.join(self.root, 'mini-imagenet-cache-' + self.mode + '.pkl')

    def _check_exists(self):
        """Return True if the cached pickle for the current split is on disk."""
        return os.path.exists(self._pickle_path())
You can’t perform that action at this time.