In [None]:
import os

from PIL import Image

import numpy as np

from tqdm.notebook import tqdm

from torch.utils.data import DataLoader, Subset, random_split
import torchvision
from torchvision import transforms

In [None]:
# Open the MNIST and CIFAR-10 train/test splits stored under a common data root.

root = '/mnt/files/data'

# MNIST samples go through Resize(32) and Grayscale(3) -- presumably so the digits
# match the 32x32 RGB format of the CIFAR-10 images they are later pasted next to.
_mnist_steps = [transforms.Resize(32), transforms.Grayscale(3)]
mnist_transform = transforms.Compose(_mnist_steps)

# Train split first, then test split.
mnist_train_data, mnist_test_data = (
    torchvision.datasets.MNIST(root=root, train=is_train, transform=mnist_transform)
    for is_train in (True, False)
)

# CIFAR-10 is opened without any transform.
cifar10_train_data, cifar10_test_data = (
    torchvision.datasets.CIFAR10(root=root, train=is_train)
    for is_train in (True, False)
)

In [None]:
#prepare a non-perfectly correlated dataset
#
# Each CIFAR-10 training image is stacked below one MNIST training digit to form a
# 32x64 "domino" that is saved in the folder named after the CIFAR-10 label.
# With probability 1 - err_rate the digit has the same label as the CIFAR-10 image;
# with probability err_rate a digit with a different label is used instead.
# No MNIST image is used more than once.

seed = 42
rng = np.random.default_rng(seed)

err_rate = 0.05

# round() instead of int(): a float product such as 100*0.29 == 28.999999999999996
# would otherwise be truncated and give a folder name that is off by one.
dominos_paths = root + '/mnist-cifar10-train-' + str(100 - round(100*err_rate)) + '/{}'
for y in range(10):
    os.makedirs(dominos_paths.format(y), exist_ok=True)

dominos_files = dominos_paths + '/{}.jpg'
# class_mnist_pos[c] is the position (in shuffled order) right after the last digit
# of class c that has been used; every class-c digit before it is already taken.
class_mnist_pos = np.zeros(10, dtype=np.intc)
mnist_pos = rng.permutation(len(mnist_train_data))
# Labels in shuffled order, read once from the dataset so that scanning for a label
# does not decode and transform a whole image per probe.
mnist_labels = np.asarray(mnist_train_data.targets)[mnist_pos]
n_mnist = len(mnist_pos)
used = set()

for x, y in tqdm(cifar10_train_data):
    err = rng.binomial(1, err_rate)
    if err:
        # Mismatched pair: first unused digit whose label differs from y, scanning
        # from the lowest per-class cursor (nothing before it can still be free).
        i = int(class_mnist_pos.min())
        while i < n_mnist and (i in used or mnist_labels[i] == y):
            i += 1
    else:
        # Matched pair: next digit of class y at or after that class's cursor.
        i = int(class_mnist_pos[y])
        while i < n_mnist and mnist_labels[i] != y:
            i += 1
    if i >= n_mnist:
        raise RuntimeError(
            'ran out of unused MNIST images while pairing a CIFAR-10 image of class {}'.format(y)
        )
    used.add(i)

    # Fetch the selected sample once; it supplies both the image and its label.
    mnist_idx = int(mnist_pos[i])
    mnist_img, mnist_label = mnist_train_data[mnist_idx]
    class_mnist_pos[mnist_label] = i + 1

    domino = Image.new('RGB', (32, 64))
    domino.paste(mnist_img, (0, 0))   # MNIST digit on top
    domino.paste(x, (0, 32))          # CIFAR-10 image below
    # The file is named after the original (unshuffled) MNIST index.
    domino.save(dominos_files.format(y, mnist_idx))

In [None]:
#prepare a perfectly correlated dataset
#
# Same pairing procedure as the err_rate = 0.05 cell, run with err_rate = 0.0, so
# every CIFAR-10 training image is stacked below an MNIST training digit of the same
# label. Each 32x64 "domino" is saved in the folder named after the CIFAR-10 label
# and no MNIST image is used more than once.
# NOTE(review): this cell duplicates the err_rate = 0.05 cell; consider factoring
# the shared procedure into a single function taking err_rate as a parameter.

seed = 42
rng = np.random.default_rng(seed)

err_rate = 0.0

# round() instead of int(): a float product such as 100*0.29 == 28.999999999999996
# would otherwise be truncated and give a folder name that is off by one.
dominos_paths = root + '/mnist-cifar10-train-' + str(100 - round(100*err_rate)) + '/{}'
for y in range(10):
    os.makedirs(dominos_paths.format(y), exist_ok=True)

dominos_files = dominos_paths + '/{}.jpg'
# class_mnist_pos[c] is the position (in shuffled order) right after the last digit
# of class c that has been used; every class-c digit before it is already taken.
class_mnist_pos = np.zeros(10, dtype=np.intc)
mnist_pos = rng.permutation(len(mnist_train_data))
# Labels in shuffled order, read once from the dataset so that scanning for a label
# does not decode and transform a whole image per probe.
mnist_labels = np.asarray(mnist_train_data.targets)[mnist_pos]
n_mnist = len(mnist_pos)
used = set()

for x, y in tqdm(cifar10_train_data):
    # Still drawn every iteration so the random stream is consumed exactly as in
    # the noisy variant; with err_rate = 0.0 the draw is always 0.
    err = rng.binomial(1, err_rate)
    if err:
        # Mismatched pair: first unused digit whose label differs from y, scanning
        # from the lowest per-class cursor (nothing before it can still be free).
        i = int(class_mnist_pos.min())
        while i < n_mnist and (i in used or mnist_labels[i] == y):
            i += 1
    else:
        # Matched pair: next digit of class y at or after that class's cursor.
        i = int(class_mnist_pos[y])
        while i < n_mnist and mnist_labels[i] != y:
            i += 1
    if i >= n_mnist:
        raise RuntimeError(
            'ran out of unused MNIST images while pairing a CIFAR-10 image of class {}'.format(y)
        )
    used.add(i)

    # Fetch the selected sample once; it supplies both the image and its label.
    mnist_idx = int(mnist_pos[i])
    mnist_img, mnist_label = mnist_train_data[mnist_idx]
    class_mnist_pos[mnist_label] = i + 1

    domino = Image.new('RGB', (32, 64))
    domino.paste(mnist_img, (0, 0))   # MNIST digit on top
    domino.paste(x, (0, 32))          # CIFAR-10 image below
    # The file is named after the original (unshuffled) MNIST index.
    domino.save(dominos_files.format(y, mnist_idx))

In [None]:
#prepare an uncorrelated dataset
#
# Each CIFAR-10 test image is stacked below a randomly chosen MNIST test digit
# (labels are ignored when pairing) and the 32x64 "domino" is saved in the folder
# named after the CIFAR-10 label.

seed = 42
rng = np.random.default_rng(seed)

dominos_paths = root + '/mnist-cifar10-test/{}'
for y in range(10):
    os.makedirs(dominos_paths.format(y), exist_ok=True)

dominos_files = dominos_paths + '/{}.jpg'
mnist_pos = rng.permutation(len(mnist_test_data))

# zip() has no length, so tqdm needs an explicit total to show progress; zip stops
# at the shorter of the two inputs.
n_pairs = min(len(cifar10_test_data), len(mnist_pos))
for (x, y), i in tqdm(zip(cifar10_test_data, mnist_pos), total=n_pairs):
    # NOTE: i is already an entry of mnist_pos, so mnist_pos[i] applies the
    # permutation twice. That is still a one-to-one random assignment, and it is
    # kept as-is so the generated files stay identical to earlier runs.
    mnist_idx = int(mnist_pos[i])
    domino = Image.new('RGB', (32, 64))
    domino.paste(mnist_test_data[mnist_idx][0], (0, 0))   # MNIST digit on top
    domino.paste(x, (0, 32))                              # CIFAR-10 image below
    # The file is named after the MNIST index that was used.
    domino.save(dominos_files.format(y, mnist_idx))

In [None]:
#prepare a reversed dataset
#
# Each CIFAR-10 test image is stacked below a randomly chosen MNIST test digit, as
# in the uncorrelated test set, but the 32x64 "domino" is saved in the folder named
# after the MNIST label instead of the CIFAR-10 label.

seed = 42
rng = np.random.default_rng(seed)

dominos_paths = root + '/mnist-cifar10-test-rev/{}'
for y in range(10):
    os.makedirs(dominos_paths.format(y), exist_ok=True)

dominos_files = dominos_paths + '/{}.jpg'
mnist_pos = rng.permutation(len(mnist_test_data))

# zip() has no length, so tqdm needs an explicit total to show progress; zip stops
# at the shorter of the two inputs.
n_pairs = min(len(cifar10_test_data), len(mnist_pos))
for (x, y), i in tqdm(zip(cifar10_test_data, mnist_pos), total=n_pairs):
    # NOTE: i is already an entry of mnist_pos, so mnist_pos[i] applies the
    # permutation twice. That is still a one-to-one random assignment, and it is
    # kept as-is so the generated files stay identical to earlier runs.
    mnist_idx = int(mnist_pos[i])
    # Load (and transform) the MNIST sample once for both its image and its label.
    mnist_img, y_mnist = mnist_test_data[mnist_idx]
    domino = Image.new('RGB', (32, 64))
    domino.paste(mnist_img, (0, 0))   # MNIST digit on top
    domino.paste(x, (0, 32))          # CIFAR-10 image below
    # Folder is the MNIST label; the file is named after the MNIST index.
    domino.save(dominos_files.format(y_mnist, mnist_idx))