In [None]:
import torch
from torch import nn
import torch.nn.functional as f
import torchvision

from tqdm import trange
import random

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Hyper-parameters shared by the AE and VAE experiments below.
params = {
    "LR": 3e-4,        # Adam learning rate
    "N_BATCHS": 32,    # mini-batch size (samples per batch, despite the key name)
    "N_EPOCHS": 10,    # passes over the training set for each model
    "SEED": 42,        # seed for python, numpy and torch RNGs (reproducible runs)
}

# Seed every RNG the notebook touches: `random` (picking samples to plot), numpy,
# and torch (weight init, DataLoader shuffling, VAE reparameterisation noise).
random.seed(params["SEED"])
np.random.seed(params["SEED"])
torch.manual_seed(params["SEED"])  # seeds the RNG on all devices, CUDA included

device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
print("Device type:", device)

In [None]:
# MNIST as float tensors in [0, 1]: ToTensor converts the uint8 PIL images and
# rescales them by 1/255.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="./datasets", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="./datasets", train=False, transform=transform, download=True
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=params["N_BATCHS"], shuffle=True,
)

# Evaluation gains nothing from shuffling; a fixed order keeps the eval pass
# deterministic and comparable across epochs and runs.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=params["N_BATCHS"], shuffle=False,
)

In [None]:
class AE(nn.Module):
    """Convolutional auto-encoder for 28x28 single-channel images.

    Encoder: two stride-2 5x5 convolutions (1 -> 64 -> 128 channels,
    28 -> 13 -> 6 spatially), flattened and projected by a linear layer to a
    ``latent_space_size``-dimensional code.
    Decoder: a linear layer back to a 128x6x6 feature map, then two transposed
    convolutions (6 -> 13 -> 28 spatially) ending in a Sigmoid, so every
    reconstructed pixel lies in [0, 1].

    Args:
        latent_space_size: dimensionality of the bottleneck code.
    """

    def __init__(self, latent_space_size=4) -> None:
        super().__init__()
        feat_ch, feat_hw = 128, 6                  # encoder feature map: 128 x 6 x 6
        flat_size = feat_ch * feat_hw * feat_hw

        # Attribute names and layer order are kept fixed: they define the
        # state_dict keys and the order in which parameters are initialised.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(5, 5), padding=1, stride=2),        # 28 -> 13
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, feat_ch, kernel_size=(5, 5), padding=1, stride=2),  # 13 -> 6
            nn.BatchNorm2d(feat_ch),
            nn.ELU(),
            nn.Flatten(),
            nn.Linear(flat_size, latent_space_size),
        )
        self.deconv_0 = nn.Sequential(
            nn.Linear(latent_space_size, flat_size),
            nn.ELU(),
        )
        self.deconv_1 = nn.Sequential(
            nn.ConvTranspose2d(feat_ch, 46, kernel_size=(5, 5), padding=1, stride=2),  # 6 -> 13
            nn.BatchNorm2d(46),
            nn.ELU(),
            nn.ConvTranspose2d(46, 1, kernel_size=(6, 6), padding=1, stride=2),        # 13 -> 28
            nn.Sigmoid(),  # element-wise: each pixel intensity squashed into [0, 1]
        )

    def forward(self, x):
        """Encode then decode ``x``; returns a (batch, 1, 28, 28) reconstruction.

        ``x`` may have any shape whose leading dim is the batch and whose
        remaining elements number 28*28 (e.g. (B, 1, 28, 28) or (B, 784)).
        """
        batch = x.size(0)
        images = x.view(batch, 1, 28, 28)
        code = self.conv(images)
        feature_map = self.deconv_0(code).view(batch, 128, 6, 6)
        reconstruction = self.deconv_1(feature_map)
        return reconstruction.view(batch, 1, 28, 28)

In [None]:
# Train the convolutional AE with a pixel-wise MSE reconstruction loss.
model_ae = AE()
model_ae = nn.DataParallel(model_ae)
model_ae.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(model_ae.parameters(), lr=params["LR"])

train_history, test_history = [], []

for epoch in range(params["N_EPOCHS"]):
    train_loss = 0.
    test_loss = 0.

    # Training mode: BatchNorm normalises with batch statistics and updates its
    # running estimates.
    model_ae.train()
    for X, _ in train_loader:
        X = X.to(device)
        X_hat = model_ae(X)
        loss = loss_fn(X_hat, X)
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_loss += loss.item()

    # Eval mode: BatchNorm uses the frozen running statistics, so test batches
    # neither influence the reported loss's normalisation nor leak into the stats.
    model_ae.eval()
    with torch.no_grad():
        for X, _ in test_loader:
            X = X.to(device)
            X_hat = model_ae(X)
            loss = loss_fn(X_hat, X)
            test_loss += loss.item()

    # Average of per-batch mean losses over the epoch.
    train_loss /= len(train_loader)
    test_loss /= len(test_loader)
    print(f"{epoch+1}/{params['N_EPOCHS']} | train_loss = {train_loss:.6f}; eval_loss = {test_loss:.6f}")
    train_history.append(train_loss)
    test_history.append(test_loss)

In [None]:
# Feed 20 images of standard-normal noise through the trained AE to see what it
# makes of out-of-distribution input.
# NOTE(review): to inspect reconstructions of real digits instead, use
# `train_dataset.data[i].float() / 255.` so inputs match the [0, 1] training scale.
RANDOM_SAMPLE = list(torch.randn(size=(20, 28, 28)))
plt.figure(figsize=(16,16))
plt.title("Input")
plt.imshow(np.concatenate(RANDOM_SAMPLE, axis=1))

# eval(): BatchNorm uses its running statistics and does NOT overwrite them with
# these noise inputs; no_grad(): no autograd graph is needed just for plotting.
model_ae.eval()
with torch.no_grad():
    noise_batch = torch.stack(RANDOM_SAMPLE).view(-1, 1, 28, 28).to(device)
    ae_reconstructions = model_ae(noise_batch).cpu().numpy().reshape(-1, 28, 28)

plt.figure(figsize=(16,16))
plt.title("Output : AE")
plt.imshow(np.concatenate(list(ae_reconstructions), axis=1))

In [None]:
# Per-epoch learning curves of the AE; x-axis is 1-based to match the training log.
epochs = range(1, len(train_history) + 1)
fig, ax = plt.subplots()
ax.set_title("MSE loss")
ax.grid()
ax.plot(epochs, train_history, label="train loss")
ax.plot(epochs, test_history, label="test loss")
ax.set_xlabel("epoch")
ax.set_ylabel("MSE (mean over batches)")
ax.legend()
plt.show()

In [None]:
class VAE(nn.Module):
    """Fully-connected variational auto-encoder for 28x28 images.

    Args:
        hid_size: width of every hidden layer in the encoder and decoder.
        dist_size: dimensionality of the diagonal-Gaussian latent.
        num_base_channels: unused; kept only so existing call sites still work.
        act_fn: activation *class*, instantiated between linear layers.
    """

    def __init__(self,
        hid_size=200,
        dist_size=10,
        num_base_channels=1,
        act_fn=nn.ReLU) -> None:
        super(VAE, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(28*28, hid_size),
            act_fn(),
            nn.Linear(hid_size, hid_size),
            act_fn(),
            nn.Linear(hid_size, hid_size),
            act_fn(),
        )

        # `self.sigma` predicts log(sigma^2) (the log-variance), not sigma itself:
        # a raw linear output can be <= 0, making log(sigma^2) in the KL term
        # -inf/NaN, whereas exp(0.5 * logvar) is always strictly positive.
        self.mu, self.sigma = nn.Linear(hid_size, dist_size), nn.Linear(hid_size, dist_size)
        self.decoder = nn.Sequential(
            nn.Linear(dist_size, hid_size),
            act_fn(),
            nn.Linear(hid_size, hid_size),
            act_fn(),
            nn.Linear(hid_size, 28*28),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> tuple:
        """Encode, sample a latent, and decode.

        Args:
            x: batch whose leading dim is the batch and whose remaining elements
               number 28*28 (e.g. (B, 1, 28, 28)).

        Returns:
            (x_hat, mu, sigma): x_hat is the (B, 1, 28, 28) reconstruction in
            [0, 1]; mu and sigma are the (B, dist_size) mean and strictly
            positive standard deviation of q(z | x).
        """
        num_batchs = x.size(0)
        h = self.encoder(x.view(num_batchs, -1))
        mu = self.mu(h)
        logvar = self.sigma(h)
        sigma = torch.exp(0.5 * logvar)
        # Reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, I), so z
        # stays differentiable w.r.t. mu and sigma.
        z = mu + sigma * torch.randn_like(sigma)
        x_hat = self.decoder(z).view(num_batchs, 1, 28, 28)
        return x_hat, mu, sigma

class VAELoss(nn.Module):
    """Negative ELBO for a Bernoulli-decoder VAE with an N(0, I) prior, per sample.

    loss = (BCE_sum(x_hat, x) + KL(N(mu, sigma^2) || N(0, I))) / batch_size
    with the closed-form KL = -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2)
    (Kingma & Welling, "Auto-Encoding Variational Bayes", Appendix B).

    Both terms are summed over pixels / latent dims of the same batch, so they
    live on a comparable scale. Mixing a per-pixel *mean* BCE with a batch-*summed*
    KL lets the KL term dominate and tends to collapse the posterior.

    Args:
        eps: added to sigma^2 before taking the log so sigma == 0 cannot give -inf.
    """

    def __init__(self, eps: float = 1e-8) -> None:
        super(VAELoss, self).__init__()
        self.eps = eps

    def forward(self, x, x_hat, mu, sigma):
        """Compute the mean negative ELBO of a batch.

        Args:
            x: target images with values in [0, 1].
            x_hat: reconstructions in [0, 1], same shape as ``x``.
            mu, sigma: posterior mean and standard deviation, one row per sample.

        Returns:
            Scalar tensor: negative ELBO averaged over the batch.
        """
        batch_size = x.size(0)
        reconstruction = f.binary_cross_entropy(x_hat, x, reduction="sum")
        var = sigma.pow(2)
        kl_divergence = -0.5 * torch.sum(1 + torch.log(var + self.eps) - mu.pow(2) - var)
        return (reconstruction + kl_divergence) / batch_size

In [None]:
# Train the VAE by minimising the per-sample negative ELBO computed by VAELoss.
model_vae = VAE()
model_vae = nn.DataParallel(model_vae)
model_vae.to(device)
optimizer = torch.optim.Adam(model_vae.parameters(), lr=params["LR"])
loss_fn = VAELoss()

train_history, test_history = [], []

for epoch in range(params['N_EPOCHS']):
    train_loss = 0.
    test_loss = 0.

    model_vae.train()
    for x, _ in train_loader:
        optimizer.zero_grad()
        x = x.view(x.size(0), 1, 28, 28).to(device)
        x_hat, mu, sigma = model_vae(x)
        loss = loss_fn(x, x_hat, mu, sigma)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    model_vae.eval()
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.view(x.size(0), 1, 28, 28).to(device)
            x_hat, mu, sigma = model_vae(x)
            # Same argument order as in training: (target, reconstruction, mu, sigma).
            loss = loss_fn(x, x_hat, mu, sigma)
            test_loss += loss.item()

    # Average of per-batch losses over the epoch.
    train_loss /= len(train_loader)
    test_loss /= len(test_loader)
    print(f"{epoch+1}/{params['N_EPOCHS']} | train_loss = {train_loss:.6f}; eval_loss = {test_loss:.6f}")
    train_history.append(train_loss)
    test_history.append(test_loss)

In [None]:
# 20 random training digits (drawn with replacement). `train_dataset.data` holds
# raw uint8 pixels in [0, 255]; divide by 255 so the inputs match the [0, 1]
# ToTensor scale the VAE was trained on.
SAMPLE = [random.choice(train_dataset.data).float() / 255. for _ in range(20)]
plt.figure(figsize=(16,16))
plt.title("Input")
plt.imshow(np.concatenate(SAMPLE, axis=1))

# Inference only: eval mode and no autograd graph.
model_vae.eval()
with torch.no_grad():
    digit_batch = torch.stack(SAMPLE).view(-1, 1, 28, 28).to(device)
    vae_out, _, _ = model_vae(digit_batch)
    vae_reconstructions = vae_out.cpu().numpy().reshape(-1, 28, 28)

plt.figure(figsize=(16,16))
plt.title("Output : VAE")
plt.imshow(np.concatenate(list(vae_reconstructions), axis=1))

In [None]:
# Per-epoch VAE learning curves. The plotted value is the full VAE objective
# (reconstruction BCE + KL term), not the BCE alone.
epochs = range(1, len(train_history) + 1)
fig, ax = plt.subplots()
ax.set_title("VAE loss (BCE + KL)")
ax.grid()
ax.plot(epochs, train_history, label="train loss")
ax.plot(epochs, test_history, label="test loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss (mean over batches)")
ax.legend()
plt.show()