In [1]:
from PIL import Image
import matplotlib.pyplot as plt
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

In [2]:
# Use the GPU when CUDA is available; otherwise everything stays on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
class MNISTDataset(Dataset):
    """MNIST samples read from a CSV file and kept as tensors on ``device``.

    The CSV must have a ``label`` column; all remaining columns are taken as
    pixel values and each row is reshaped into a ``(1, 28, 28)`` image.
    """

    def __init__(self, dataset_path: str):
        """Load ``dataset_path`` with pandas and build the image/label tensors."""
        frame = pd.read_csv(dataset_path)

        pixel_values = frame.drop(columns=['label']).values
        label_column = frame['label']

        # X: (N, 1, 28, 28) images; Y: (N,) integer class labels.
        self.X = torch.Tensor(pixel_values).reshape((-1, 1, 28, 28)).to(device)
        self.Y = torch.LongTensor(label_column).to(device)

    def __len__(self):
        """Number of samples (size of the leading dimension of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx) -> tuple:
        """Return ``(image_tensor, label)``; the label is a plain Python int."""
        image = self.X[idx]
        label = self.Y[idx].item()
        return image, label

    @property
    def shape(self):
        """Pair of ``torch.Size`` objects: ``(X.shape, Y.shape)``."""
        return (self.X.size(), self.Y.size())

    @property
    def data_shape(self):
        """Shape of one sample, i.e. ``X.shape`` without the leading dimension."""
        return self.X.size()[1:]

In [4]:
# Load each split from its CSV (paths are relative to the notebook's working
# directory), then report the tensor shapes followed by the per-sample shape.
mnist_train = MNISTDataset("../../data/mnist_train.csv")
print(mnist_train.shape, mnist_train.data_shape, sep="\n")

mnist_test = MNISTDataset("../../data/mnist_test.csv")
print(mnist_test.shape, mnist_test.data_shape, sep="\n")

(torch.Size([60000, 1, 28, 28]), torch.Size([60000]))
torch.Size([1, 28, 28])
(torch.Size([10000, 1, 28, 28]), torch.Size([10000]))
torch.Size([1, 28, 28])


In [31]:
def color_reverse(ndarray: np.ndarray) -> np.ndarray:
    """Invert intensities by subtracting every element from 255.

    A new array is returned; the input is not modified.
    """
    white_level = 255
    inverted = white_level - ndarray
    return inverted

def to_image(ndarray: np.ndarray) -> Image:
    """Build a Pillow 'LA' (luminance + alpha) image from ``ndarray[0]``.

    Only the first entry along the leading axis is used; the notebook passes
    channel-first samples such as ``(1, 28, 28)``. Values are cast to int16
    before being handed to Pillow.
    """
    first_plane = ndarray[0]
    as_int16 = first_plane.astype(np.int16)
    picture = Image.fromarray(as_int16)
    return picture.convert('LA')

def white_to_transparent(img: Image):
    """Make every pure-white pixel of a two-band (luminance, alpha) image transparent.

    Pixels whose luminance is exactly 255 are rewritten as ``(255, 0)``; all
    other pixels are left alone. The image is modified in place and the same
    object is returned for convenience.
    """
    # Column-major walk over every pixel coordinate.
    coordinates = ((col, row) for col in range(img.width) for row in range(img.height))

    for xy in coordinates:
        luminance, _alpha = img.getpixel(xy)
        if luminance != 255:
            continue
        img.putpixel(xy, (255, 0))

    return img

def paste(canvas: Image.Image, *images: tuple[Image.Image, tuple[int, int]]) -> None:
    """Stamp each image onto ``canvas`` in place, skipping transparent pixels.

    Args:
        canvas: Two-band (luminance, alpha) image that receives the pixels.
        *images: ``(image, (x, y))`` pairs, where ``(x, y)`` is the canvas
            coordinate of the image's top-left corner. Later pairs are drawn
            over earlier ones.

    A source pixel is copied verbatim (luminance and alpha) when its alpha is
    non-zero; pixels with alpha 0 leave the canvas untouched. Parts of an image
    that fall outside the canvas are skipped rather than raising.
    """
    for image, (pos_x, pos_y) in images:
        # Clip the image's footprint to the canvas so that overhanging or
        # negatively-offset images never index outside the canvas.
        x_start = max(pos_x, 0)
        y_start = max(pos_y, 0)
        x_stop = min(pos_x + image.width, canvas.width)
        y_stop = min(pos_y + image.height, canvas.height)

        for x in range(x_start, x_stop):
            for y in range(y_start, y_stop):
                lum, alpha = image.getpixel((x - pos_x, y - pos_y))

                if alpha != 0:  # if not transparent
                    canvas.putpixel((x, y), (lum, alpha))
                

In [47]:
count = 3
# MNISTDataset stores its tensors on `device`. `.numpy()` only works on CPU
# tensors, so copy back to host memory first (`.cpu()` is a no-op on CPU).
samples = mnist_train.X[:count].cpu().numpy()
height, width = samples[0].shape[1:]

# Opaque white 'LA' (luminance + alpha) canvas, `count` digits wide and tall.
canvas = Image.new('LA', size=(count * width, count * height), color=(255, 255))
# Invert each sample (255 - value) and turn it into an 'LA' image.
images = [to_image(color_reverse(sample)) for sample in samples]

In [48]:
# Pair every digit (white background made transparent, in place) with a random
# top-left corner chosen so the whole digit fits inside the canvas.
image_and_pos = []

for image in images:
    sprite = white_to_transparent(image)
    max_x = canvas.width - image.width
    max_y = canvas.height - image.height
    position = (random.randint(0, max_x), random.randint(0, max_y))
    image_and_pos.append((sprite, position))


In [49]:
# Draw the digits onto the canvas (mutates `canvas` in place).
paste(canvas, *image_and_pos)
# Leave the image as the cell's last expression so Jupyter renders it inline;
# `canvas.show()` would instead hand a temporary file to an external viewer.
canvas