In [1]:
import torch
import torchvision
from torchvision.datasets import MNIST

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from random import sample, random
import itertools

from sklearn.model_selection import train_test_split

import os

In [2]:
# Location of the MNIST CSV files. Override with the MNIST_CSV_DIR environment variable;
# the default is the original author's local checkout.
DATA_DIR = os.environ.get("MNIST_CSV_DIR", "/Users/lfeng/Projects/11-785-Project/dataset/MNIST_CSV")

train_set = pd.read_csv(os.path.join(DATA_DIR, "mnist_train.csv"))
# NOTE(review): despite its name, `test_images` is the whole test CSV; presumably it has the
# same layout as the train CSV (label in column 0, pixels after) — confirm.
test_images = pd.read_csv(os.path.join(DATA_DIR, "mnist_test.csv"))

In [3]:
# Hold out 30% of the training CSV for validation. Column 0 is the digit label, the rest are pixels.
# A fixed random_state keeps the split identical across Restart & Run All; stratifying on the
# label keeps the per-digit class balance the same in both halves.
SPLIT_SEED = 42
train_images, val_images, train_labels, val_labels = train_test_split(
    train_set.iloc[:, 1:],
    train_set.iloc[:, 0],
    test_size=0.3,
    random_state=SPLIT_SEED,
    stratify=train_set.iloc[:, 0],
)

# Reset to positional 0..n-1 indices so images and labels line up for later positional lookups.
# Reassignment (instead of inplace=True) keeps the cell's effect explicit.
train_images = train_images.reset_index(drop=True)
val_images = val_images.reset_index(drop=True)
train_labels = train_labels.reset_index(drop=True)
val_labels = val_labels.reset_index(drop=True)

In [14]:
# Convert the splits to NumPy arrays for the Dataset / DataLoader cells below.
# np.asarray (rather than .to_numpy()) keeps this cell idempotent: `train_labels` and
# `val_labels` are rebound to ndarrays here, and a re-run would otherwise call
# .to_numpy() on an ndarray and raise AttributeError.
train_data = np.asarray(train_images)
train_labels = np.asarray(train_labels)

val_data = np.asarray(val_images)
val_labels = np.asarray(val_labels)

In [None]:
# 3x3 preview of training digits with their labels; panel (row, col) shows sample row + 5*col.
fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(8,8))
fig.subplots_adjust(hspace=.3)
for row, col in itertools.product(range(3), range(3)):
    sample_idx = row + 5 * col
    panel = ax[row][col]
    panel.axis('off')
    pixels = train_images.iloc[[sample_idx], :].to_numpy()
    panel.imshow(pixels.astype(np.uint8).reshape(28, 28), cmap='gray')
    panel.set_title(train_labels[sample_idx])

In [None]:
# Visualise how a digit is cut for the jigsaw task: the first training image split into a
# 4x4 grid of 7x7 tiles, shown in row-major tile order (tile i*4+j = rows 7i.., cols 7j..).
sample_tiles = train_data[0].reshape(4, 7, 4, 7).transpose(0, 2, 1, 3).reshape(16, 7, 7)

fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(8,8))
fig.subplots_adjust(hspace=.3)
fig.suptitle(f"Tiles of training sample 0 (label {train_labels[0]})")

for i in range(4):
    for j in range(4):
        ax[i][j].axis('off')
        ax[i][j].imshow(sample_tiles[i * 4 + j], cmap='gray')

In [18]:
class JigsawDataset(torch.utils.data.Dataset):
    """Jigsaw-puzzle pretext dataset over flattened square images (e.g. MNIST rows).

    Each image is cut into a ``grid_size x grid_size`` grid of equal square tiles.
    A fixed set of ``permutations`` distinct tile orderings is drawn once at
    construction. Every ``__getitem__`` call picks one ordering at random,
    shuffles the tiles with it and stitches the scrambled image back together.
    The index of the chosen ordering is the self-supervised target.

    Args:
        images: array-like, one flattened ``side*side`` image per row
            (a DataFrame is accepted; it is converted with ``np.asarray``).
        labels: array-like of per-image class labels, same length as ``images``.
        permutations: size of the permutation set (number of puzzle classes).
        img_transformer: optional callable applied to the assembled puzzle tensor.
        grid_size: tiles per side; must divide the image side length.
        seed: seed for drawing the permutation set. With the fixed default, every
            dataset built with the same ``permutations``/``grid_size`` shares the
            same set, so a permutation index means the same ordering in the
            train, validation and test splits. ``None`` draws a fresh set.

    ``__getitem__`` returns ``(puzzle, order, label)``: a float32 tensor of shape
    ``(1, side, side)``, the int index into ``permutation_set`` that was applied,
    and the int class label.

    Raises:
        ValueError: if ``labels`` and ``images`` differ in length, ``grid_size``
            is below 1, or more permutations are requested than
            ``(grid_size**2)!`` exist.
    """

    def __init__(self, images, labels, permutations=10, img_transformer=None,
                 grid_size=4, seed=0):
        import math  # only needed for the permutation-count bound below

        self.images = np.asarray(images)
        self.labels = np.asarray(labels)
        self.permutations = permutations
        self.img_transformer = img_transformer

        self.N = len(self.images)
        self.grid_size = grid_size

        if len(self.labels) != self.N:
            raise ValueError(f"got {self.N} images but {len(self.labels)} labels")
        if grid_size < 1:
            raise ValueError(f"grid_size must be >= 1, got {grid_size}")
        n_tiles = grid_size * grid_size
        if not 1 <= permutations <= math.factorial(n_tiles):
            raise ValueError(
                f"permutations must be in [1, {n_tiles}!], got {permutations}")

        self._rng = np.random.default_rng(seed)
        # Drawn once: the target returned by __getitem__ is an index into this list,
        # so the list must stay fixed for the lifetime of the dataset.
        self.permutation_set = self.__retrieve_permutations()

    def __retrieve_permutations(self):
        """Draw ``self.permutations`` distinct orderings of the grid's tiles."""
        n_tiles = self.grid_size * self.grid_size
        seen = set()
        permutations = []
        # Rejection-sample duplicates so every puzzle class is a distinct ordering.
        while len(permutations) < self.permutations:
            perm = self._rng.permutation(n_tiles)
            key = tuple(perm.tolist())
            if key not in seen:
                seen.add(key)
                permutations.append(perm)
        return permutations

    def __get_image(self, index):
        """Return the flattened pixel row for sample ``index``."""
        return self.images[index]

    def __get_tiles(self, index):
        """Cut image ``index`` into ``grid_size**2`` square tiles, in row-major order.

        Tile ``i * grid_size + j`` is the block at tile-row ``i``, tile-column ``j``.
        Returns a float64 array of shape ``(grid_size**2, t, t)``.
        """
        flat = np.asarray(self.__get_image(index))
        side = int(round(np.sqrt(flat.size)))
        g = self.grid_size
        if side * side != flat.size or side % g != 0:
            raise ValueError(
                f"image {index} has {flat.size} pixels; expected a square image "
                f"whose side is divisible by grid_size={g}")
        t = side // g
        # (side, side) -> (g, t, g, t) -> (g, g, t, t): element [i, j] is the tile
        # img[i*t:(i+1)*t, j*t:(j+1)*t].
        return flat.reshape(g, t, g, t).swapaxes(1, 2).reshape(g * g, t, t).astype(np.float64)

    def make_grid(self, x):
        """Stitch ``grid_size**2`` tiles (row-major) back into one single-channel image.

        Args:
            x: array/tensor of shape ``(grid_size**2, th, tw)``.

        Returns:
            Tensor of shape ``(1, grid_size*th, grid_size*tw)``.
        """
        x = torch.as_tensor(x)
        g = self.grid_size
        n, th, tw = x.shape
        if n != g * g:
            raise ValueError(f"expected {g * g} tiles, got {n}")
        # [i*g + j, a, b] -> [i, j, a, b] -> [i, a, j, b] -> pixel (i*th + a, j*tw + b)
        return x.reshape(g, g, th, tw).permute(0, 2, 1, 3).reshape(1, g * th, g * tw)

    def __getitem__(self, index):
        tiles = self.__get_tiles(index)

        # torch's RNG (not np.random): DataLoader seeds torch per worker, so workers
        # do not all draw the same permutation sequence.
        order = int(torch.randint(len(self.permutation_set), (1,)).item())
        shuffled = torch.from_numpy(tiles[self.permutation_set[order]]).float()

        puzzle = self.make_grid(shuffled)
        if self.img_transformer is not None:
            puzzle = self.img_transformer(puzzle)
        return puzzle, order, int(self.labels[index])

    def __len__(self):
        return self.N

In [19]:
# Smoke-test loader: wrap the raw pixel rows in JigsawDataset so each batch carries
# (puzzle, permutation index, digit label) rather than bare pixel rows without labels.
# NOTE(review): on platforms whose default multiprocessing start method is spawn
# (macOS/Windows), a Dataset class defined inside a notebook may fail to unpickle in
# worker processes; set num_workers=0 if the workers crash.
train_loader = torch.utils.data.DataLoader(
        dataset     = JigsawDataset(train_data, train_labels),
        num_workers = 2,
        batch_size  = 64,
        pin_memory  = torch.cuda.is_available(),  # pinning only helps host->GPU copies
        shuffle     = True,
)


In [None]:
# `config` is not defined in any cell above this one; fall back to a default so the cell
# runs on a fresh kernel. TODO: move hyper-parameters into a config cell near the top.
config = globals().get("config", {"batch_size": 64})

# Held-out test split. Assumes mnist_test.csv has the same layout as mnist_train.csv
# (label in column 0, pixels after) — TODO confirm.
test_data = test_images.iloc[:, 1:].to_numpy()
test_labels = test_images.iloc[:, 0].to_numpy()

train_dataset = JigsawDataset(train_data, train_labels)
val_dataset = JigsawDataset(val_data, val_labels)
test_dataset = JigsawDataset(test_data, test_labels)

# Pinned host memory only speeds up host->GPU copies; skip it when there is no GPU.
PIN_MEMORY = torch.cuda.is_available()

# NOTE(review): with the spawn start method (macOS/Windows default), notebook-defined
# Dataset classes may fail to unpickle in workers; drop num_workers to 0 if so.
train_loader = torch.utils.data.DataLoader(
        dataset     = train_dataset,
        num_workers = 2,
        batch_size  = config['batch_size'],
        pin_memory  = PIN_MEMORY,
        shuffle     = True,
)

val_loader = torch.utils.data.DataLoader(
        dataset     = val_dataset,
        num_workers = 1,
        batch_size  = config['batch_size'],
        pin_memory  = PIN_MEMORY,
        shuffle     = False,
    )

test_loader = torch.utils.data.DataLoader(
        dataset     = test_dataset,
        num_workers = 1,
        batch_size  = config['batch_size'],
        pin_memory  = PIN_MEMORY,
        shuffle     = False,
    )

print("Batch size: ", config['batch_size'])
print("Train dataset samples = {}, batches = {}".format(len(train_dataset), len(train_loader)))
print("Val dataset samples = {}, batches = {}".format(len(val_dataset), len(val_loader)))
print("Test dataset samples = {}, batches = {}".format(len(test_dataset), len(test_loader)))