#!/usr/bin/env python3

import os
import pickle
import subprocess
import tarfile

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader

# NOTE(review): the original line read "from import download_file" — the source
# module was lost in transit. Restored to the project's data utilities; confirm
# against upstream (download_file is not referenced in the visible code).
from learn2learn.data.utils import download_file
DATA_DIR = 'quickdraw'
GCLOUD_BUCKET = 'gs://quickdraw_dataset/full/numpy_bitmap/'
'train': [
'police car',
'alarm clock',
'school bus',
'The Eiffel Tower',
'ceiling fan',
'light bulb',
'flip flops',
'cruise ship',
'string bean',
'paint can',
'stop sign',
'The Great Wall of China',
'hot tub',
'baseball bat',
'traffic light',
'The Mona Lisa',
'sea turtle',
'sleeping bag',
'hot air balloon',
'house plant',
'see saw',
'floor lamp',
'wine glass',
'remote control',
'picture frame',
'hockey puck',
'flying saucer',
'washing machine',
'swing set',
'aircraft carrier',
'cell phone',
'garden hose',
'birthday cake',
'pickup truck',
'soccer ball',
'power outlet',
'test': [
'animal migration',
'diving board',
'fire hydrant',
'frying pan',
'hockey stick',
'palm tree',
'wine bottle',
'validation': [
'coffee cup',
'golf club',
'hot dog',
'ice cream',
'paper clip',
'roller coaster',
'smiley face',
'tennis racquet',
class Quickdraw(Dataset):
The Quickdraw dataset was originally introduced by Google Creative Lab in 2017 and then re-purposed for few-shot learning in Triantafillou et al., 2020.
See Ha and Heck, 2017 for more information.
The dataset consists of roughly 50M drawing images of 345 objects.
Each image was hand-drawn by human annotators and is represented as black-and-white 28x28 pixel array.
We follow the train-validation-test splits of Triantafillou et al., 2020.
(241 classes for train, 52 for validation, and 52 for test.)
1. [](
2. Ha, David, and Douglas Eck. 2017. "A Neural Representation of Sketch Drawings." ArXiv '17.
3. Triantafillou et al. 2020. "Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few Examples." ICLR '20.
* **root** (str) - Path to download the data.
* **mode** (str, *optional*, default='train') - Which split to use.
Must be 'train', 'validation', or 'test'.
* **transform** (Transform, *optional*, default=None) - Input pre-processing.
* **target_transform** (Transform, *optional*, default=None) - Target pre-processing.
* **download** (bool, *optional*, default=False) - Whether to download the dataset.
train_dataset ='./data', mode='train')
train_dataset =
train_generator =, num_tasks=1000)
def __init__(
root = os.path.expanduser(root)
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self._bookkeeping_path = os.path.join(self.root, 'quickdraw-' + mode + '-bookkeeping.pkl')
if not self._check_exists() and download:
def _check_exists(self):
if not os.path.exists(self.root):
return False
data_path = os.path.join(self.root, DATA_DIR)
if not os.path.exists(data_path):
return False
all_classes = sum(SPLITS.values(), [])
for cls_name in all_classes:
cls_path = os.path.join(data_path, cls_name + '.npy')
if not os.path.exists(cls_path):
return False
return True
def download(self):
if not os.path.exists(self.root):
data_path = os.path.join(self.root, DATA_DIR)
if not os.path.exists(data_path):
print('Downloading Quickdraw dataset (50Gb)')
all_classes = sum(SPLITS.values(), [])
gcloud_url = GCLOUD_BUCKET + '*.npy'
cmd = ['gsutil', '-m', 'cp', gcloud_url, data_path]
def load_bookkeeping(self, mode='all'):
# We do manual bookkeeping, because the size of the dataset.
if not os.path.exists(self._bookkeeping_path):
# create bookkeeping
data_path = os.path.join(self.root, DATA_DIR)
splits = sum(SPLITS.values(), []) if mode == 'all' else SPLITS[mode]
labels = list(range(len(splits)))
indices_to_labels = {}
labels_to_indices = {}
offsets = []
index_counter = 0
for cls_idx, cls_name in enumerate(splits):
cls_path = os.path.join(data_path, cls_name + '.npy')
cls_data = np.load(cls_path, mmap_mode='r')
num_samples = cls_data.shape[0]
labels_to_indices[cls_idx] = list(range(index_counter, index_counter + num_samples))
for i in range(num_samples):
indices_to_labels[index_counter + i] = cls_idx
index_counter += num_samples
bookkeeping = {
'labels_to_indices': labels_to_indices,
'indices_to_labels': indices_to_labels,
'labels': labels,
'offsets': offsets,
# Save bookkeeping to disk
with open(self._bookkeeping_path, 'wb') as f:
pickle.dump(bookkeeping, f, protocol=-1)
with open(self._bookkeeping_path, 'rb') as f:
bookkeeping = pickle.load(f)
self._bookkeeping = bookkeeping
self.labels_to_indices = bookkeeping['labels_to_indices']
self.indices_to_labels = bookkeeping['indices_to_labels']
self.labels = bookkeeping['labels']
self.offsets = bookkeeping['offsets']
def load_data(self, mode='train'):
data_path = os.path.join(self.root, DATA_DIR)
splits = sum(SPLITS.values(), []) if mode == 'all' else SPLITS[mode] = []
for cls_name in splits:
cls_path = os.path.join(data_path, cls_name + '.npy'), mmap_mode='r'))
def __getitem__(self, i):
label = self.indices_to_labels[i]
cls_data =[label]
offset = self.offsets[label]
image = cls_data[i - offset].reshape(28, 28)
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
def __len__(self):
return len(self.indices_to_labels)
if __name__ == '__main__':
qd = Quickdraw(root='/ssd/seba-1511/data', download=True)
img, lab = qd[len(qd) - 1]