<a href="https://colab.research.google.com/github/clashgamer123/SOC_Pytorch/blob/main/mnist_cnn_autoencode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Let us create an autoencoder that encodes and decodes the standard MNIST data. <br>
First, import all the required libraries.

In [29]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

Let us create our dataset and the data loader that serves it in batches.

In [30]:
# Pre-processing: ToTensor converts each PIL image to a float tensor scaled
# to [0, 1]. No mean/std normalisation is applied on purpose: the decoder of
# the autoencoder ends in a Sigmoid, so it can only emit values in [0, 1],
# and the reconstruction target (the input itself) must live in that range.
transform = transforms.Compose([
    transforms.ToTensor(),
])

batch_size = 64

# MNIST training split, downloaded to ./data on first use.
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Shuffled mini-batches; the final batch may hold fewer than batch_size samples.
data_loader = DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)

Now comes the main part. Let us implement the Autoencoder.

In [31]:

class ConvAutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    The layer shapes noted below assume a 1 x 28 x 28 input. The encoder
    compresses the image to a 128 x 1 x 1 code and the decoder expands that
    code back to 1 x 28 x 28.

    NOTE(review): the final Sigmoid confines every output value to [0, 1],
    so reconstruction targets should lie in that range as well.
    """

    def __init__(self):
        super().__init__()

        # Encoder: two stride-2 convolutions halve the spatial size twice,
        # then a 7 x 7 convolution collapses the remaining 7 x 7 map.
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),    # -> 32 x 14 x 14
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),   # -> 64 x 7 x 7
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=7, stride=1, padding=0),  # -> 128 x 1 x 1
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: mirrors the encoder with transposed convolutions.
        decoder_layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=7, stride=1, padding=0),  # -> 64 x 7 x 7
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # -> 32 x 14 x 14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),    # -> 1 x 28 x 28
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction from the decoder."""
        latent = self.encoder(x)
        return self.decoder(latent)

Now define the loss criterion and the optimizer.

In [32]:
# Network to train.
model = ConvAutoEncoder()
# Mean-squared-error loss used as the training criterion.
criterion = nn.MSELoss()
# Adam over all model parameters, with a small L2 penalty via weight_decay.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

Train the model.

In [None]:
epochs = 10

# One (epoch, images, reconstructions) snapshot per epoch, so that
# outputs[k] refers to epoch k when the results are plotted afterwards.
# It is created once, outside the epoch loop, so earlier epochs are kept.
outputs = []

for epoch in range(epochs):
    cum_loss = 0.0

    for image, label in data_loader:  # labels are not needed for reconstruction
        reformed_img = model(image)
        loss = criterion(reformed_img, image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the real number of
        # samples in this batch; the last batch can be smaller than batch_size.
        cum_loss += loss.item() * image.size(0)

    # Keep only the last batch of the epoch, detached from the autograd graph
    # so the stored tensors do not keep the graph alive in memory.
    outputs.append((epoch, image.detach(), reformed_img.detach()))

    print(f'Epoch: {epoch+1}, Loss: {cum_loss/len(data_loader.dataset):.4f}')


Let us plot the original images and their reconstructions with matplotlib and compare them.

In [None]:
# For every fourth entry of `outputs`, draw one figure with two rows:
# up to nine input images on top and their reconstructions underneath.
for k in range(0, epochs, 4):
    plt.figure(figsize=(9, 2))
    plt.gray()

    _epoch, originals, reconstructions = outputs[k]
    image_rows = (
        originals.detach().numpy(),
        reconstructions.detach().numpy(),
    )

    for row, batch in enumerate(image_rows):
        # Only the first nine images of the batch fit into the 2 x 9 grid.
        for col, picture in enumerate(batch[:9]):
            plt.subplot(2, 9, row * 9 + col + 1)
            plt.imshow(picture.reshape(28, 28))