# 👖 Autoencoders on Fashion MNIST

In [None]:
# Root directory used below to locate the `notebooks` package, the data folder
# and the saved models. The path is machine-specific; adjust it before running.
working_dir = "/home/mary/work/repos/generative_deep_Learning_2nd_edition_pytorch"

In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os

# Put `working_dir` on sys.path (once) so that `notebooks.utils` can be imported below.
notebooks_path = os.path.abspath(working_dir)
if notebooks_path not in sys.path:
    sys.path.append(notebooks_path)

In [None]:
import torchvision
from torchvision import transforms
import torch
from notebooks.utils import display

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary

import matplotlib.pyplot as plt
import math
import numpy as np

## 0. Parameters <a name="parameters"></a>


In [None]:
IMAGE_SIZE = 32  # spatial size (height = width) the networks are built for
CHANNELS = 1  # number of image channels; also the decoder's output channels
BATCH_SIZE = 100  # batch size of the train and test data loaders
BUFFER_SIZE = 1000  # NOTE(review): not referenced in the visible notebook
VALIDATION_SPLIT = 0.2  # NOTE(review): not referenced in the visible notebook
EMBEDDING_DIM = 2  # size of the embedding; the 2-D plots below require the value 2
EPOCHS = 3  # number of epochs passed to fit()

## 1. Prepare the data <a name="prepare"></a>

In [None]:
# Directory handed to the FashionMNIST dataset as its root (download target).
data_dir = working_dir + "/data"

In [None]:
import torch.utils
import torch.utils.data


# Pre-processing applied to every image: pad 2 pixels on each side, then
# convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Pad(padding=(2, 2, 2, 2)),  # (left, top, right, bottom)
        transforms.ToTensor(),
    ]
)

# Datasets (downloaded into `data_dir` when missing).
train_data = torchvision.datasets.FashionMNIST(
    root=data_dir,
    train=True,
    transform=transform,
    download=True,
)
test_data = torchvision.datasets.FashionMNIST(
    root=data_dir,
    train=False,
    transform=transform,
    download=True,
)

# Shuffled mini-batch loaders for both splits.
train_data_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)
test_data_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Report how many samples each dataset contains.
print(f"training data size= {len(train_data)}")
print(f"test data size= {len(test_data)}")

In [None]:
# Pull a single batch from the training loader to inspect its type and shape.
dataiter = iter(train_data_loader)
images, lables = next(dataiter)

print(type(images))
print(images.shape)

In [None]:
# Pass the first ten images of the batch to the project's display helper and
# print the matching labels.
display(images[:10])
print(lables[:10])

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that compresses a square image into an embedding.

    Three 3x3 convolutions with stride 2 (32, 64 and 128 output channels), each
    followed by ReLU, halve the spatial size three times. The result is
    flattened and projected to the embedding by a linear layer.

    Args:
        image_size: Height (= width) of the input images; must be a positive
            multiple of 8 so that the three halvings yield whole numbers.
        channels: Number of channels of the input images.
        embedding_dim: Size of the embedding. ``None`` (the default) falls
            back to the notebook-level ``EMBEDDING_DIM``.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, image_size, channels, embedding_dim=None):
        super().__init__()

        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        image_size = int(image_size)
        if embedding_dim is None:
            embedding_dim = EMBEDDING_DIM

        p = self._get_padding_size(image_size, 2, 3)
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=32, kernel_size=3, stride=2, padding=p)

        p = self._get_padding_size(image_size // 2, 2, 3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=p)

        p = self._get_padding_size(image_size // 4, 2, 3)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=p)

        # Integer division keeps the shape usable for `view`/`reshape` without
        # callers having to cast floats back to ints.
        self.shape_before_flattening = (128, image_size // 8, image_size // 8)

        self.fc1 = nn.Linear(in_features=math.prod(self.shape_before_flattening), out_features=embedding_dim)

    def get_shape_before_flattening(self):
        """Return the (channels, height, width) of the last conv output as ints."""
        return self.shape_before_flattening

    @staticmethod
    def _get_padding_size(input_w, stride, kernal_size):
        """Return the padding that makes a strided conv output ``input_w / stride``.

        Solves ``out = (input_w + 2 * p - kernal_size) / stride + 1`` for ``p``
        with ``out = input_w / stride`` and rounds up to a whole number.
        """
        target_w = input_w / stride
        total_padding = (target_w - 1) * stride - input_w + kernal_size
        return max(math.ceil(total_padding / 2), 0)

    def forward(self, x):
        """Encode a batch of images of shape (N, channels, image_size, image_size)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)
        return self.fc1(x)


In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder that maps an embedding back to an image.

    A linear layer expands the embedding to ``shape_before_flatten``; three
    stride-2 transposed convolutions (128, 64 and 32 output channels), each
    followed by ReLU, double the spatial size three times; a final 3x3
    convolution with sigmoid produces the image.

    Args:
        shape_before_flatten: (channels, height, width) of the feature map the
            embedding is expanded to. Values are coerced to ``int`` so a shape
            containing whole-number floats is accepted.
        embedding_dim: Size of the input embedding. ``None`` (the default)
            falls back to the notebook-level ``EMBEDDING_DIM``.
        channels: Number of output image channels. ``None`` (the default)
            falls back to the notebook-level ``CHANNELS``.
    """

    def __init__(self, shape_before_flatten, embedding_dim=None, channels=None):
        super().__init__()

        # `view` in forward() needs ints, so normalise the shape once here.
        self.shape_before_flatten = tuple(int(d) for d in shape_before_flatten)
        if embedding_dim is None:
            embedding_dim = EMBEDDING_DIM
        if channels is None:
            channels = CHANNELS

        in_channels, width = self.shape_before_flatten[0], self.shape_before_flatten[1]

        self.fc1 = nn.Linear(in_features=embedding_dim, out_features=math.prod(self.shape_before_flatten))

        # Every layer uses the computed padding (the first layer previously
        # ignored it in favour of a literal 1).
        p = self._get_padding_size(width, stride=2, kernaal_size=3)
        self.conv_trans1 = nn.ConvTranspose2d(in_channels=in_channels, out_channels=128,
                                              kernel_size=3, stride=2, padding=p, output_padding=1)

        p = self._get_padding_size(width * 2, stride=2, kernaal_size=3)
        self.conv_trans2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3,
                                              stride=2, padding=p, output_padding=1)

        p = self._get_padding_size(width * 4, stride=2, kernaal_size=3)
        self.conv_trans3 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3,
                                              stride=2, padding=p, output_padding=1)

        self.conv1 = nn.Conv2d(in_channels=32, out_channels=channels, kernel_size=3,
                               stride=1, padding='same')

    @staticmethod
    def _get_padding_size(input_w, stride, kernaal_size):
        """Return the padding for a transposed conv that doubles ``input_w``.

        Solves ``2 * input_w = (input_w - 1) * stride - 2 * p + kernaal_size + 1``
        for ``p`` (the trailing ``+ 1`` is the ``output_padding=1`` used by all
        layers) and rounds up to a whole number.
        """
        p = ((input_w - 1) * stride) / 2
        p = p - input_w
        p = p + (kernaal_size / 2)
        p = p + 1/2
        return math.ceil(p)

    def forward(self, x):
        """Decode a batch of embeddings of shape (N, embedding_dim) into images."""
        x = self.fc1(x)
        c, w, h = self.shape_before_flatten
        x = x.view(-1, c, w, h)
        x = F.relu(self.conv_trans1(x))
        x = F.relu(self.conv_trans2(x))
        x = F.relu(self.conv_trans3(x))
        x = self.conv1(x)
        # Sigmoid keeps every output pixel within [0, 1].
        return F.sigmoid(x)
        

In [None]:
class AutoEncoder(nn.Module):
    """Encoder followed by a decoder built for the encoder's feature-map shape.

    Args:
        image_size: Forwarded to ``Encoder`` as ``image_size``.
        channels: Forwarded to ``Encoder`` as ``channels``.
    """

    def __init__(self, image_size, channels):
        super().__init__()

        self.encoder = Encoder(image_size=image_size, channels=channels)

        # The decoder starts from the shape the encoder flattens; cast each
        # entry to int before handing it over.
        raw_shape = self.encoder.get_shape_before_flattening()
        self.shape_before_flatten = tuple(int(dim) for dim in raw_shape)

        self.decoder = Decoder(self.shape_before_flatten)

    def forward(self, x):
        """Return the decoder's reconstruction of the encoded input ``x``."""
        return self.decoder(self.encoder(x))

In [None]:
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build from the notebook parameters rather than repeating the literals 32 and 1.
encoder = Encoder(IMAGE_SIZE, CHANNELS).to(device)
print(encoder)

In [None]:
# Layer-by-layer summary for one input image; the shape comes from the notebook parameters.
summary(encoder, (CHANNELS, IMAGE_SIZE, IMAGE_SIZE))

In [None]:
# The decoder mirrors the encoder, so it starts from the encoder's
# pre-flatten feature-map shape (cast to ints).
shape_before_flatten = tuple(
    int(dim) for dim in encoder.get_shape_before_flattening()
)

decoder = Decoder(shape_before_flatten)
decoder = decoder.to(device)
print(decoder)

In [None]:
# Layer-by-layer summary of the decoder for a single embedding vector.
summary(decoder, (EMBEDDING_DIM,))

In [None]:
# Build the full model from the notebook parameters rather than the literals 32 and 1.
auto_encoder = AutoEncoder(IMAGE_SIZE, CHANNELS).to(device)
print(auto_encoder)

In [None]:
# Layer-by-layer summary for one input image; the shape comes from the notebook parameters.
summary(auto_encoder, (CHANNELS, IMAGE_SIZE, IMAGE_SIZE))

## 3. Train the autoencoder <a name="train"></a>

In [None]:

# Step size handed to the Adam optimizer created below.
learning_rate = 0.0005

In [None]:
# The decoder already ends in a sigmoid, so the loss operates on probabilities
# (nn.BCELoss) rather than on logits (nn.BCEWithLogitsLoss).
loss_fn = nn.BCELoss()

In [None]:
# Adam over all parameters of the autoencoder.
# NOTE(review): the variable name is misspelled; it is kept because the fit() call below uses it.
optmizer = optim.Adam(auto_encoder.parameters(), lr=learning_rate)

In [None]:
def fit(model, train_dataloader, optimizer, loss_fn, epochs=10):
    """Train ``model`` to reconstruct its own inputs.

    The model is moved to the GPU when one is available, otherwise to the CPU.
    Each batch is used as both the input and the target of ``loss_fn``.

    Args:
        model: Module whose output has the same shape as its input.
        train_dataloader: Sized iterable yielding ``(images, labels)`` pairs;
            the labels are ignored.
        optimizer: Optimizer holding the parameters of ``model``.
        loss_fn: Callable ``loss_fn(predictions, images)`` returning a scalar
            tensor.
        epochs: Number of passes over ``train_dataloader``.

    Returns:
        list[float]: Mean batch loss of every epoch, in order. An epoch over an
        empty loader is recorded as ``nan`` instead of dividing by zero.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    num_batches = len(train_dataloader)
    history = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for images, _ in train_dataloader:
            # Inputs must live on the same device as the model.
            images = images.to(device)

            optimizer.zero_grad()
            predictions = model(images)
            loss = loss_fn(predictions, images)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        history.append(epoch_loss)
        # Epochs are reported 1-based so the last line reads "<epochs> / <epochs>".
        print(f"Epoch {epoch + 1} / {epochs}: loss= {epoch_loss:.4f}")

    return history

            

In [None]:
# Train the autoencoder on the training loader for EPOCHS epochs.
fit(auto_encoder, train_data_loader, optmizer, loss_fn, EPOCHS)

In [None]:
# Save the trained weights of the full model and of its two halves.
model_dir = working_dir + "/notebooks/03_vae/01_autoencoder/models"
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(model_dir, exist_ok=True)
# NOTE(review): "autoendcoder" looks like a typo; the file name is left unchanged
# because code that loads it may rely on this exact path.
torch.save(auto_encoder.state_dict(), model_dir + "/autoendcoder")
torch.save(auto_encoder.encoder.state_dict(), model_dir + "/encoder")
torch.save(auto_encoder.decoder.state_dict(), model_dir +  "/decoder")

## 4. Reconstruct using the autoencoder <a name="reconstruct"></a>

In [None]:
n_to_predict = 5000
print(len(test_data))

# Collect whole batches from the test loader until at least `n_to_predict`
# samples are gathered. Iterating with a `for` loop stops cleanly when the
# loader is exhausted instead of raising StopIteration.
example_images, example_lables = [], []
n_collected = 0
for test_images_batch, test_lables_batch in test_data_loader:
    if n_collected >= n_to_predict:
        break
    example_images.append(test_images_batch)
    example_lables.append(test_lables_batch)
    n_collected += test_images_batch.shape[0]

# Concatenate along the batch dimension: images become (N, C, H, W), labels (N,).
# Unlike stack + view, cat also copes with a smaller final batch.
example_images = torch.cat(example_images)
example_lables = torch.cat(example_lables)

print(example_images.shape)
print(example_lables.shape)

In [None]:
# Inference only: switch to eval mode and skip gradient tracking.
auto_encoder.eval()

with torch.no_grad():
    device_images = example_images.to(device)
    # Encode to embeddings, then decode them back into images.
    emb = auto_encoder.encoder(device_images)
    gen_images = auto_encoder.decoder(emb)

In [None]:
# Show the original test images followed by their reconstructions.
print("Example of real items")
display(example_images)
print("Example of reconstructed items")
display(gen_images)

## 5. Embed using the encoder <a name="encode"></a>

In [None]:
# Encode the example images without tracking gradients.
with torch.no_grad():
    device_examples = example_images.to(device)
    embeddings = auto_encoder.encoder(device_examples)

# NumPy copy on the CPU for plotting.
embeddings_np = embeddings.detach().cpu().numpy()
print(embeddings_np[:10])

In [None]:
# A scatter plot is only drawn for a two-dimensional embedding.
if EMBEDDING_DIM == 2:
    figure_size = 8
    latent_x, latent_y = embeddings_np[:, 0], embeddings_np[:, 1]

    plt.figure(figsize=(figure_size, figure_size))
    plt.scatter(latent_x, latent_y, s=3, alpha=0.5, c="black")
    plt.show()

In [None]:
# Same scatter plot, with every point coloured by its label.
if EMBEDDING_DIM == 2:
    figure_size = 8

    example_lables_np = example_lables.detach().cpu().numpy()

    plt.figure(figsize=(figure_size, figure_size))
    plt.scatter(
        embeddings_np[:, 0],
        embeddings_np[:, 1],
        c=example_lables_np,
        cmap="rainbow",
        s=3,
        alpha=0.8,
    )
    plt.colorbar()
    plt.show()

## 6. Generate using the decoder <a name="decode"></a>

In [None]:
# Per-dimension bounds of the observed embeddings.
mins = np.min(embeddings_np, axis=0)
maxs = np.max(embeddings_np, axis=0)
print(mins)
print(maxs)
grid_width = 6
grid_height = 6

# Draw one uniformly distributed point inside those bounds per grid cell.
n_samples = grid_width * grid_height
samples = np.random.uniform(low=mins, high=maxs, size=(n_samples, EMBEDDING_DIM))

print(samples[:3])

In [None]:
# Decode the sampled points; they are cast to the dtype of the encoder output first.
# NOTE(review): `type` shadows the Python builtin. The name is kept here because
# a later cell reads it; prefer a name such as `embedding_dtype` when both are updated.
type = embeddings.dtype
with torch.no_grad():
    reconstructions = auto_encoder.decoder(torch.as_tensor(samples).to(type).to(device))

print(reconstructions.shape)
# Reorder (N, C, H, W) -> (N, H, W, C) and move to NumPy for plotting.
reconstructions_np = reconstructions.permute(0, 2, 3,1).to("cpu").detach().numpy()

In [None]:
if EMBEDDING_DIM == 2:
    figure_size = 8

    # Observed embeddings in black, sampled points in blue.
    plt.figure(figsize=(figure_size, figure_size))
    plt.scatter(embeddings_np[:, 0], embeddings_np[:, 1], c="black", s=3, alpha=0.5)
    plt.scatter(samples[:, 0], samples[:, 1], c="blue", alpha=1, s=20)
    plt.show()

    # Add underneath a grid of the decoded images
    fig = plt.figure(figsize=(figure_size, grid_height * 2))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)

    for i in range(grid_width * grid_height):
        ax = fig.add_subplot(grid_height, grid_width, i + 1)
        ax.axis("off")
        ax.text(
            0.5,
            -0.35,
            str(np.round(samples[i, :], 1)),
            fontsize=10,
            ha="center",
            transform=ax.transAxes,
        )
        # imshow documents (M, N), (M, N, 3) and (M, N, 4) inputs; drop a
        # single trailing channel explicitly instead of relying on implicit
        # handling of (M, N, 1) arrays.
        image = reconstructions_np[i]
        if image.shape[-1] == 1:
            image = image[..., 0]
        ax.imshow(image, cmap="Greys")

In [None]:
# Colour the embeddings by their label (clothing type - see table)
if EMBEDDING_DIM == 2:

    figsize = 12
    grid_size = 15
    plt.figure(figsize=(figsize, figsize))
    plt.scatter(
        embeddings_np[:, 0],
        embeddings_np[:, 1],
        cmap="rainbow",
        c=example_lables_np,
        alpha=0.8,
        s=300,
    )
    plt.colorbar()

    # Regular grid spanning the observed embedding range. y runs from max to
    # min so the first row of decoded images matches the top of the scatter plot.
    x = np.linspace(np.min(embeddings_np[:, 0]), np.max(embeddings_np[:, 0]), grid_size)
    y = np.linspace(np.max(embeddings_np[:, 1]), np.min(embeddings_np[:, 1]), grid_size)
    xv, yv = np.meshgrid(x, y)
    xv = xv.flatten()
    yv = yv.flatten()
    grid = np.array(list(zip(xv, yv)))

    with torch.no_grad():
        # Cast to the dtype of the encoder output directly rather than going
        # through a global that shadows the builtin `type`.
        grid_tensor = torch.as_tensor(grid).to(embeddings.dtype).to(device)
        reconstructions_2 = auto_encoder.decoder(grid_tensor)

    # Reorder (N, C, H, W) -> (N, H, W, C) and move to NumPy for plotting.
    reconstructions_np_2 = reconstructions_2.permute(0, 2, 3, 1).to("cpu").detach().numpy()
    plt.show()

    fig = plt.figure(figsize=(figsize, figsize))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)
    for i in range(grid_size**2):
        ax = fig.add_subplot(grid_size, grid_size, i + 1)
        ax.axis("off")
        # Drop a single trailing channel explicitly; imshow documents
        # (M, N), (M, N, 3) and (M, N, 4) inputs.
        image = reconstructions_np_2[i]
        if image.shape[-1] == 1:
            image = image[..., 0]
        ax.imshow(image, cmap="Greys")