In [12]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

In [13]:
import os
import shutil
import numpy as np
import torch
import torchvision.datasets as datasets
from PIL import Image

In [14]:
def save_img(img, path):
    """Save ``img`` to ``path`` as an image file.

    Args:
        img: A ``torch.Tensor``, ``np.ndarray`` or ``PIL.Image.Image``.
            Tensors are detached, moved to CPU and converted to NumPy first.
            2-D arrays are treated as single-channel and replicated into
            three identical channels; arrays of any other rank are handed to
            ``Image.fromarray`` unchanged. PIL images are saved as they are.
        path: Destination file path, passed straight to ``Image.save``.

    Unsupported input types are not saved; ``'fail'`` and the offending type
    are printed instead of raising.
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    if isinstance(img, np.ndarray):
        if img.ndim == 2:
            # (H, W) -> (H, W, 1) -> (H, W, 3). Only done for single-channel
            # input: an array that already has a channel axis must not be
            # expanded again, or Image.fromarray rejects the resulting shape.
            img = np.repeat(np.expand_dims(img, -1), 3, 2)
        img = Image.fromarray(img)
    if isinstance(img, Image.Image):
        img.save(path)
    else:
        print('fail', type(img))

# Load MNIST through torchvision, rooted at /tmp/MNIST/ with downloading
# enabled; transform=None so no transform is applied to the samples.
mnist = datasets.MNIST(
    root='/tmp/MNIST/',
    download=True,
    transform=None,
)

In [15]:
# Output directory for the exported MNIST images
data_dir = 'data/MNIST/data'
# data_dir is a nested path, so create any missing parent directories too;
# a plain os.mkdir raises when 'data/MNIST' does not exist yet.
if not os.path.isdir(data_dir):
    os.makedirs(data_dir)

# Take the classes list from data, keeping only the last space-separated
# token of each entry in mnist.classes
classes = [c.split(' ')[-1] for c in mnist.classes]

# Create a new classes text file, one class name per line
with open(os.path.join(data_dir, 'classes.txt'), 'w') as f:
    f.writelines(c + '\n' for c in classes)

In [16]:
# Create a directory for saving each class
for target in classes:
    class_dir = os.path.join(data_dir, target)
    if not os.path.exists(class_dir):
        os.mkdir(class_dir)

# Save data as images.
# img_count[y] is the number of images already written for class index y and
# doubles as the next file name; it is sized from the class list rather than
# a hard-coded 10 so it always matches the directories created above.
img_count = [0] * len(classes)
for i in range(len(mnist.data)):
    x = mnist.data[i]
    y = mnist.targets[i].item()  # tensor scalar -> plain Python number for indexing
    save_path = os.path.join(data_dir, classes[y], '{}.jpg'.format(img_count[y]))
    save_img(x, save_path)
    img_count[y] += 1

In [17]:
# CIFAR-10 supplies the images used for the random concepts.
# Loaded through torchvision, rooted at /tmp/CIFAR10/ with downloading
# enabled; transform=None so no transform is applied to the samples.
cifar10 = datasets.CIFAR10(
    root='/tmp/CIFAR10/',
    download=True,
    transform=None,
)

Files already downloaded and verified


In [18]:
# Root directory for the random concept folders
concept_dir = 'data/MNIST/concepts'
n_random_concepts = 100   # number of concept folders to create
n_exp_p_concept = 500     # number of images written into each folder

# Each concept takes its own consecutive, non-overlapping slice of the
# dataset, so the dataset must hold at least this many images.
assert n_random_concepts * n_exp_p_concept <= len(cifar10), \
    'cifar10 has too few images for the requested concepts'

for concept in range(n_random_concepts):
    # NOTE: the folder name hard-codes "500" independently of n_exp_p_concept.
    _concept = 'random500_{}'.format(concept)
    _dir = os.path.join(concept_dir, _concept)
    # Start from an empty folder so files left by an earlier run are removed
    if os.path.exists(_dir):
        shutil.rmtree(_dir)
    # makedirs (not mkdir): concept_dir itself may not exist yet, and mkdir
    # cannot create missing parent directories.
    os.makedirs(_dir)
    offset = n_exp_p_concept * concept  # first dataset index of this concept
    for i in range(n_exp_p_concept):
        # Indexing the dataset gives a pair; element 0 is the image
        img = cifar10[offset + i][0]
        save_path = os.path.join(_dir, '{}.jpg'.format(i))
        save_img(img, save_path)