### Batch Augmentation

In [None]:
from torchvision import datasets
import torch
import matplotlib.pyplot as plt
from imgaug import augmenters as iaa
import numpy as np

In [None]:
# Load the full FashionMNIST training split, downloading it into data_folder
# on first use.
data_folder = '~/data/FMNIST'
fmnist = datasets.FashionMNIST(data_folder, train=True, download=True)

# The augmenter needs the images as numpy arrays, so convert images and labels here.
tr_images = np.array(fmnist.data)
tr_targets = np.array(fmnist.targets)
assert len(tr_images) == len(tr_targets), "every training image needs a label"

In [None]:
# Validation split (train=False), converted to numpy arrays so it can go
# through the same Dataset / collate_fn path as the training data.
val_fmnist = datasets.FashionMNIST(data_folder, train=False, download=True)
val_images, val_targets = np.array(val_fmnist.data), np.array(val_fmnist.targets)

In [None]:
# Augmentation pipeline: a random horizontal translation of -10..10 pixels;
# pixels exposed at the border are filled using imgaug's 'constant' mode.
aug = iaa.Sequential([
    iaa.Affine(mode='constant', translate_px={"x": (-10, 10)}),
])

There are two ways to apply augmentation to a batch of images:

1. Augmenting the batch one image at a time  
2. Augmenting the whole batch at once. 

Timing both approaches shows that augmenting the whole batch at once is faster than augmenting one image at a time.

For this reason, batch-level augmentation (done here inside the Dataset's `collate_fn`) is considered best practice.

In [None]:
# Dataset class to take input images and their augmenter
from torch.utils.data import DataLoader, Dataset

class MyDataste(Dataset):
    """Map-style dataset over in-memory images/labels with batch-level augmentation.

    ``__getitem__`` returns raw items; conversion to tensors (and augmentation,
    when an augmenter is given) happens once per batch in ``collate_fn``, which
    should be passed to ``DataLoader(collate_fn=...)``.

    Args:
        x: indexable collection of images. Presumably each item is a 2-D (H, W)
            uint8 array, as with the FashionMNIST arrays used in this notebook;
            TODO confirm for other data.
        y: indexable collection of class labels aligned with ``x``.
        aug: optional object exposing ``augment_images(images=...)`` that
            returns the augmented images (e.g. an imgaug augmenter).
            ``None`` disables augmentation.
    """

    def __init__(self, x, y, aug=None):
        super().__init__()
        self.x = x
        self.y = y
        self.aug = aug

    def __getitem__(self, index):
        """Return the raw ``(image, label)`` pair at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)

    def collate_fn(self, batch):
        """Convert a list of ``(image, label)`` pairs into model-ready tensors.

        Returns:
            ims: float tensor of the stacked images with a channel axis inserted
                at dim 1 (i.e. (B, 1, H, W) for 2-D images), scaled by 1/255.
            classes: tensor of the B labels.
        """
        ims, classes = zip(*batch)
        ims = list(ims)
        # `is not None` rather than truthiness: imgaug's Sequential is a list
        # subclass, so its truth value depends on how many children it has.
        if self.aug is not None:
            # augment_images returns new images; the result must be assigned
            # back or the batch is left unaugmented.
            ims = self.aug.augment_images(images=ims)
        # Stack into a single ndarray first: building a tensor from a list of
        # ndarrays is slow and triggers a PyTorch warning.
        ims = torch.tensor(np.stack(ims))[:, None, :, :] / 255
        classes = torch.tensor(classes)
        return ims, classes

In [None]:
# Create the training Dataset: images paired with their class labels
# (tr_targets), plus the augmenter to apply per batch.
train = MyDataste(tr_images, tr_targets, aug=aug)

In [None]:
# Define the model architecture
from torch.optim  import SGD, Adam
import torch.nn as nn

def get_model():
    """Build the CNN classifier together with its loss function and optimizer.

    Layout: two [Conv3x3 -> MaxPool2 -> ReLU] stages (1 -> 64 -> 128 channels),
    then Flatten -> Linear(3200, 256) -> ReLU -> Linear(256, 10) logits.
    3200 = 128 * 5 * 5, which is what a 1x28x28 input leaves after both stages.

    Returns:
        (model, loss_fn, optimizer): the ``nn.Sequential`` model, a
        ``CrossEntropyLoss``, and an ``Adam`` optimizer (lr=1e-3) over the
        model's parameters.
    """
    def conv_stage(c_in, c_out):
        # Convolution, then pooling, then activation (same order as before).
        return [nn.Conv2d(c_in, c_out, kernel_size=3), nn.MaxPool2d(2), nn.ReLU()]

    head = [
        nn.Flatten(),
        nn.Linear(3200, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    ]
    layers = conv_stage(1, 64) + conv_stage(64, 128) + head
    net = nn.Sequential(*layers)

    criterion = nn.CrossEntropyLoss()
    opt = Adam(net.parameters(), lr=1e-3)
    return net, criterion, opt

In [None]:
# Train on a single batch: forward, loss, backward, optimizer step.
def get_batch(x, y, model, opt, loss_fn):
    """Run one optimization step on one batch and return its loss.

    Args:
        x: input batch, passed directly to ``model``.
        y: targets, passed to ``loss_fn`` together with the predictions.
        model: the ``nn.Module`` being trained; put into train mode.
        opt: optimizer over ``model``'s parameters; stepped once, then its
            gradients are cleared.
        loss_fn: callable ``loss_fn(prediction, y)`` returning a scalar tensor.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    prediction = model(x)
    batch_loss = loss_fn(prediction, y)
    batch_loss.backward()
    # Step the optimizer that was passed in, not a notebook-level global.
    opt.step()
    opt.zero_grad()
    return batch_loss.item()

In [None]:
# Build the training and validation DataLoaders.
def get_data(batch_size=64):
    """Return ``(trn_dl, val_dl)`` built from the notebook-level arrays.

    The training loader uses the global ``aug`` augmenter and shuffles every
    epoch. The validation loader has no augmenter and yields the entire
    validation set as a single, unshuffled batch.

    Args:
        batch_size: training mini-batch size (default 64).
    """
    trn = MyDataste(tr_images, tr_targets, aug)
    # The Dataset's collate_fn does the per-batch augmentation and tensor conversion.
    trn_dl = DataLoader(trn, batch_size=batch_size, collate_fn=trn.collate_fn, shuffle=True)
    val = MyDataste(val_images, val_targets)
    # Evaluation does not benefit from shuffling; keep its order reproducible.
    val_dl = DataLoader(val, batch_size=len(val_images), collate_fn=val.collate_fn, shuffle=False)
    return trn_dl, val_dl
    

In [None]:
# Build the training loader (constructed with the `aug` augmenter, shuffled)
# and the validation loader, then a freshly initialised model with its loss
# function and optimizer.
trn_dl, val_dl = get_data()
model, loss_fn, optimizer = get_model()

In [None]:
# Train for 5 epochs, reporting the mean training loss of each epoch.
n_epochs = 5
for epoch in range(n_epochs):
    epoch_losses = []
    for x, y in trn_dl:
        batch_loss = get_batch(x=x, y=y, model=model, opt=optimizer, loss_fn=loss_fn)
        epoch_losses.append(batch_loss)
    print(f'epoch {epoch}: mean train loss {np.mean(epoch_losses):.4f}')

In [None]:
# Shift one training image horizontally by -5..5 pixels and record the model's
# class probabilities for each shifted copy. np.roll is a circular shift, so
# pixels pushed off one edge re-enter on the other.
pred = []
ix = 24300
model.eval()
for px in range(-5, 6):
    # tr_images is a numpy array, so use reshape: ndarray.view expects a
    # dtype, not a shape.
    img = (tr_images[ix] / 255).reshape(28, 28)
    img2 = np.roll(img, px, axis=1)
    plt.imshow(img2)
    plt.title(f'shift {px} px')
    plt.show()
    img3 = torch.Tensor(img2).view(-1, 1, 28, 28)
    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        probs = torch.softmax(model(img3), dim=1)
    # Each entry has shape (1, 10), one probability per class.
    pred.append(probs.numpy())

In [None]:
import seaborn as sns
# Heatmap of predicted class probabilities: one row per horizontal shift,
# one column per FashionMNIST class.
shifts = range(-5, 6)  # must match the shift range used when building `pred`
fig, ax = plt.subplots(1, 1, figsize=(12, 10))
ax.set_title("Probability of each class")
sns.heatmap(
    np.array(pred).reshape(len(pred), -1),
    annot=True,
    ax=ax,
    fmt='.2f',
    xticklabels=fmnist.classes,
    yticklabels=[f"{i}pixel" for i in shifts],
    cmap='gray'
);
