## Create Multi-Digit Dataset
##### Only do this if really necessary, because creating a large enough dataset takes a very long time. Otherwise, just load the pre-made dataset.

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import datasets, transforms

In [None]:
# Number of MNIST digits placed side by side to form each generated number.
num_of_digits = 3
# How many multi-digit samples to generate for the training and test sets.
# The trailing numbers look like alternative (larger) sizes tried earlier -- TODO confirm.
train_imgs_to_gen = 120000 # 960000 1920000
test_img_to_gen = 30000 # 240000 480000
# Seed handed to torch.manual_seed before the datasets are generated.
random_seed = 1

# Root folder: used as the MNIST download/load location and as the prefix of
# the paths the generated datasets are saved under.
dataset_path = "../../data"

In [None]:
class NumberDataset(torchvision.datasets.MNIST):
    """Multi-digit number images built by concatenating random MNIST digits.

    All samples are generated eagerly in ``__init__`` and stored in
    ``self.res`` as ``(image, target)`` tuples, where ``image`` has shape
    ``(1, im_height, num_of_digits * im_width)`` and ``target`` is a 0-d
    integer tensor holding the number spelled out by the digits, read left
    to right.

    NOTE: the parent ``MNIST`` initialiser is deliberately not called; the
    raw digits are read through a separate ``datasets.MNIST`` instance that
    is wrapped in the DataLoader stored in ``self.data``.
    """

    def __init__(
        self,
        num_to_generate=120000,
        num_of_digits=1,
        im_width=28,
        im_height=28,
        train=True,
        download=True,
        dataset_path="",
    ):
        """
        Args :
          num_to_generate (int) : how many multi-digit samples to create
          num_of_digits (int) : the number of digits in each number
          im_width (int) : the width of a single digit image
          im_height (int) : the height of a single digit image
          train (bool) : if True create the images from the training set
          download (bool) : if True downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
          dataset_path (str) : root directory the MNIST files are stored in

        Raises :
          ValueError : if num_of_digits is smaller than 1
        """
        if num_of_digits < 1:
            raise ValueError("num_of_digits must be at least 1")

        self.transform = transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        # Each batch delivers exactly the digits needed for one number.
        # drop_last guarantees a batch never holds fewer than num_of_digits
        # images, which would break the reshape below.
        self.data = torch.utils.data.DataLoader(
            datasets.MNIST(
                dataset_path, train=train, download=download, transform=self.transform
            ),
            batch_size=num_of_digits,
            shuffle=True,
            drop_last=True,
        )

        self.res = []

        # Decimal weight of every digit position, most significant first.
        place_values = torch.tensor(
            [10 ** (num_of_digits - 1 - j) for j in range(num_of_digits)]
        )

        # Keep ONE iterator alive across samples. Calling iter() on the loader
        # for every sample re-shuffles the whole MNIST split each time, which
        # dominates the generation time.
        batches = iter(self.data)

        for i in range(num_to_generate):
            if (i % 1000) == 0 and i != 0:
                print("Done {} numbers".format(i))
            try:
                digits, vals = next(batches)
            except StopIteration:
                # One pass over MNIST is used up: reshuffle and continue.
                batches = iter(self.data)
                digits, vals = next(batches)
            # Lay the digits out side by side: swap H and W, stack the digits
            # along the width axis, then swap back to (1, H, digits * W).
            image = torch.transpose(
                torch.reshape(
                    torch.transpose(digits, 2, 3),
                    (1, num_of_digits * im_width, im_height),
                ),
                1,
                2,
            )
            target = (vals * place_values).sum()
            self.res.append((image, target))

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.res)

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``."""
        # The inherited MNIST implementation expects raw data/targets tensors
        # that this class never sets up, so serve the generated samples.
        return self.res[index]


In [None]:
# cuDNN is switched off and the RNG seeded before generating anything
# (presumably to make the generated datasets reproducible -- TODO confirm).
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

# Build the training split and keep only the list of (image, target) samples.
train_data = NumberDataset(
    num_to_generate=train_imgs_to_gen,
    num_of_digits=num_of_digits,
    dataset_path=dataset_path,
).res
print(f"Done proccessing training set, got {len(train_data)} numbers")

# Same for the test split, drawn from the MNIST test images.
test_data = NumberDataset(
    num_to_generate=test_img_to_gen,
    num_of_digits=num_of_digits,
    dataset_path=dataset_path,
    train=False,
).res
print(f"Done proccessing test set, got {len(test_data)} numbers")

# Preview the first nine training samples on a 3x3 grid, row by row.
fig2, axes = plt.subplots(3, 3)
fig2.tight_layout()
for i, sub in enumerate(axes.flat):
    image, label = train_data[i]
    sub.imshow(image[0], cmap='gray', interpolation='none')
    sub.set_title("Ground Truth: {}".format(label))
    sub.set_xticks([])
    sub.set_yticks([])

In [None]:
import os

# torch.save does not create missing parent folders, so make sure the
# per-model output directory exists before writing the two datasets.
output_dir = f'{dataset_path}/{num_of_digits}_digit_model'
os.makedirs(output_dir, exist_ok=True)

torch.save(train_data, f'{output_dir}/mnist_{num_of_digits}_digit_train_data')
torch.save(test_data, f'{output_dir}/mnist_{num_of_digits}_digit_test_data')