In [None]:
import torch
import torch.nn as nn

from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt

import math

from IPython.display import clear_output, display

%matplotlib inline

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# MNIST: training and test splits, downloaded into ./data on first use and
# converted to tensors by ToTensor().
dataset_mnist_train = MNIST(
    root="data", train=True, download=True, transform=ToTensor()
)
dataset_mnist_test = MNIST(
    root="data", train=False, download=True, transform=ToTensor()
)

# CIFAR-10: same layout (training split first, then the test split).
dataset_cifar_train = CIFAR10(
    root="data", train=True, download=True, transform=ToTensor()
)
dataset_cifar_test = CIFAR10(
    root="data", train=False, download=True, transform=ToTensor()
)

In [None]:
# Collect the dataset indices of every cat among the first 1000 CIFAR-10
# training samples (fewer if the dataset itself is shorter).
n_scanned = min(1000, len(dataset_cifar_train))
idxs = [
    i for i in range(n_scanned) if dataset_cifar_train[i][1] == 3  # 3 means cat
]

if idxs:
    # Aim for roughly six rows, but derive the row count from the number of
    # images actually found so the grid is always large enough (the previous
    # fixed 7-row grid failed for small or awkward counts).
    n_cols = max(1, len(idxs) // 6)
    n_rows = math.ceil(len(idxs) / n_cols)

    # squeeze=False keeps `axes` two-dimensional even for a single row/column,
    # so axes[row, col] indexing is always valid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 5), squeeze=False)

    for i, idx in enumerate(idxs):
        i_row, i_col = divmod(i, n_cols)
        # ToTensor() yields (C, H, W); imshow expects (H, W, C).
        cat_img = dataset_cifar_train[idx][0].permute(1, 2, 0)
        axes[i_row, i_col].imshow(cat_img)
        axes[i_row, i_col].set_title(str(idx))

    # Hide frames and ticks everywhere, including grid cells left unused.
    for ax in axes.flat:
        ax.axis("off")

    plt.tight_layout(pad=0)
    plt.show()

# Show sample 691 on its own; the same sample is passed to train() as the
# preview image further down.
plt.imshow(dataset_cifar_train[691][0].permute(1, 2, 0))
plt.show()


In [None]:
# Mini-batch sizes shared by the MNIST and CIFAR-10 loaders.
batch_size_train = 8
batch_size_test = 8

# All four loaders shuffle their dataset, the test splits included.
_train_loader_opts = dict(batch_size=batch_size_train, shuffle=True)
_test_loader_opts = dict(batch_size=batch_size_test, shuffle=True)

loader_mnist_train = torch.utils.data.DataLoader(
    dataset_mnist_train, **_train_loader_opts
)
loader_mnist_test = torch.utils.data.DataLoader(
    dataset_mnist_test, **_test_loader_opts
)

loader_cifar_train = torch.utils.data.DataLoader(
    dataset_cifar_train, **_train_loader_opts
)
loader_cifar_test = torch.utils.data.DataLoader(
    dataset_cifar_test, **_test_loader_opts
)


In [None]:
class BasicAutoEncoder(nn.Module):
    """Fully connected autoencoder operating on flat feature vectors.

    The encoder narrows ``input_size -> 512 -> 512 -> 256`` with ReLU after
    every layer, a single linear layer projects to ``latent_dim``, and the
    decoder widens ``latent_dim -> latent_dim -> 512 -> input_size`` with ReLU
    between layers and a plain linear output.

    Args:
        input_size: number of features in the last dimension of each input.
        latent_dim: width of the bottleneck representation.
    """

    def __init__(self, input_size: int, latent_dim: int = 128):
        super().__init__()
        self.input_size = input_size

        # Layers are created in encoder -> latent -> decoder order.
        encoder_layers = [
            nn.Linear(input_size, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Bottleneck projection; no activation is applied to the code itself.
        self.latent_space = nn.Linear(256, latent_dim)

        decoder_layers = [
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, input_size),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the latent space and return its reconstruction."""
        code = self.latent_space(self.encoder(x))
        return self.decoder(code)

In [None]:
class CNNAutoEncoder(nn.Module):
    """Convolutional autoencoder for 3-channel 32x32 images.

    Structure: two Conv2d + MaxPool2d stages, three linear layers
    (1600 -> 1024 -> 512 -> 1600), then four ConvTranspose2d layers that bring
    the feature map back to (3, 32, 32). No activation layers (ReLU etc.) are
    used; max pooling is the only non-linear operation in the network.

    Args:
        batch_size: stored on the instance as ``self.batch_size``; ``forward``
            reads the batch dimension from its input instead.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

        # Shapes below are per sample, for a (3, 32, 32) input.
        downsampling = [
            # (3, 32, 32) -> (32, 28, 28)
            nn.Conv2d(3, 32, kernel_size=5, stride=1),
            # (32, 28, 28) -> (32, 14, 14)
            nn.MaxPool2d(kernel_size=2, stride=2),
            # (32, 14, 14) -> (64, 10, 10)
            nn.Conv2d(32, 64, kernel_size=5, stride=1),
            # (64, 10, 10) -> (64, 5, 5), i.e. 64 * 5 * 5 = 1600 features
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.encoder = nn.Sequential(*downsampling)

        # Dense bottleneck: 1600 -> 1024 -> 512 (latent code) -> 1600.
        self.fc1 = nn.Linear(1600, 1024)
        self.latent_space = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 1600)

        # Transposed-convolution output size with no padding:
        #   N_out = (N_in - 1) * stride + (kernel - 1) + 1
        upsampling = [
            # (64, 5, 5) -> (64, 10, 10): (5 - 1) * 2 + (2 - 1) + 1 = 10
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, output_padding=0),
            # (64, 10, 10) -> (32, 14, 14): (10 - 1) * 1 + (5 - 1) + 1 = 14
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=1, output_padding=0),
            # (32, 14, 14) -> (32, 28, 28): (14 - 1) * 2 + (2 - 1) + 1 = 28
            nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2, output_padding=0),
            # (32, 28, 28) -> (3, 32, 32): (28 - 1) * 1 + (5 - 1) + 1 = 32
            nn.ConvTranspose2d(32, 3, kernel_size=5, stride=1, output_padding=0),
        ]
        self.decoder = nn.Sequential(*upsampling)

    def forward(self, img):
        """Reconstruct ``img``; the output has the same shape as the input."""
        n_samples = img.size()[0]

        features = self.encoder(img)  # (N, 64, 5, 5)
        flat = features.reshape(n_samples, 64 * 5 * 5)  # (N, 1600)

        code = self.latent_space(self.fc1(flat))  # (N, 512)

        expanded = self.fc2(code)  # (N, 1600)
        feature_map = expanded.reshape(n_samples, 64, 5, 5)
        return self.decoder(feature_map)

In [None]:
# Dense autoencoder sized for flattened 28x28 inputs, with a 256-wide latent code.
basic_model = BasicAutoEncoder(input_size=28 * 28, latent_dim=256)
basic_model = basic_model.to(device)

# Convolutional autoencoder, moved to the same device.
cnn_autoencoder = CNNAutoEncoder(batch_size=batch_size_train)
cnn_autoencoder = cnn_autoencoder.to(device)


Train AutoEncoder


In [None]:
def train(
    model,
    loader_train,
    epoch: int = 50,
    img_to_pred=None,
    render_image: bool = False,
):
    """Train ``model`` to reconstruct its own input (MSE loss, Adam, lr=1e-3).

    Args:
        model: module mapping a batch of inputs to a same-shaped batch.
        loader_train: iterable yielding ``(img, label)`` batches; the labels
            are ignored because the input itself is the training target.
        epoch: number of full passes over ``loader_train``.
        img_to_pred: single un-batched image tensor, channels first, used for
            the live preview. Required when ``render_image`` is True.
        render_image: when True, every 100 steps the current loss is printed
            and a figure comparing ``img_to_pred`` with its reconstruction is
            redrawn in the cell output.

    Raises:
        ValueError: if ``render_image`` is True but ``img_to_pred`` is None.
    """
    if render_image and img_to_pred is None:
        raise ValueError("img_to_pred must be provided when render_image=True")

    # Follow the model's own device rather than a notebook-level global, so
    # the batches always land where the parameters live.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    fig = None
    axes = None
    learned_artist = None
    preview_input = None
    if render_image:
        # Only create a figure when it will actually be used.
        fig, axes = plt.subplots(1, 2)
        axes[0].set_title("Input Image")
        axes[1].set_title("Learned Image")
        # The reference image never changes, so draw it once.
        axes[0].imshow(img_to_pred.detach().cpu().permute(1, 2, 0))
        preview_input = img_to_pred.unsqueeze(0).to(run_device)

    for e in range(epoch):
        for i, (img, _) in enumerate(loader_train):
            img = img.to(run_device)
            y_pred = model(img)
            loss = criterion(y_pred, img)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if render_image and i % 100 == 0:
                # The preview is for display only; skip autograd bookkeeping.
                with torch.no_grad():
                    new_y_pred = model(preview_input)
                # squeeze(0) drops only the batch dimension; clamp keeps the
                # values inside the [0, 1] range imshow expects for floats.
                preview = (
                    new_y_pred.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
                )
                # Reuse one image artist instead of stacking a new one on the
                # axes at every refresh.
                if learned_artist is None:
                    learned_artist = axes[1].imshow(preview)
                else:
                    learned_artist.set_data(preview)

                # Update the figure first, then display it, so the output
                # always shows the reconstruction for the step just printed.
                clear_output(wait=True)
                display(fig)
                print("Epoch:", e, "Step:", i, "Loss:", loss.item())

    if fig is not None:
        # Already shown through display(); closing it avoids a duplicate
        # render when the cell finishes and releases the figure's memory.
        plt.close(fig)


# Train the convolutional autoencoder on CIFAR-10, using training sample 691
# as the live preview image.
preview_image = dataset_cifar_train[691][0]

train(
    cnn_autoencoder,
    loader_cifar_train,
    epoch=100,
    img_to_pred=preview_image,
    render_image=True,
)