In [5]:
import os

import cv2
import numpy as np
import torch
from torch import nn
from torchvision import datasets,transforms
from tqdm import tqdm

In [2]:
# Default to the CPU and switch to the MPS backend only when PyTorch
# reports it as available on this machine.
device = torch.device("cpu")
if torch.backends.mps.is_available():
    device = torch.device("mps")
print(device)

mps


In [3]:
# Preprocessing pipeline: convert each image to a tensor, then apply
# (x - 0.5) / 0.5 so the pixel values land in the range -1 to 1.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Training split, downloaded into ./data when not already present.
mnist_data = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
# Held-out split, stored under its own root directory.
test_data = datasets.MNIST(
    root='./test_data', train=False, download=True, transform=transform
)

# Both loaders serve shuffled mini-batches of 64 samples.
data_loader = torch.utils.data.DataLoader(mnist_data, batch_size=64, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)

In [4]:
# Input Img -> Hiddden dim -> mean, std -> Parameterization Trick -> Decoder -> Output Img
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.common_fc = nn.Sequential(
            nn.Linear(28*28,196),
            nn.Tanh(),
            nn.Linear(196,48),
            nn.Tanh()
        )

        self.mean_fc = nn.Sequential(
            nn.Linear(48,16),
            nn.Tanh(),
            nn.Linear(16,2)
        )
        # Here we are calculating the log variance not the actual variance in the distribution
        self.log_var_fc = nn.Sequential(
            nn.Linear(48,16),
            nn.Tanh(),
            nn.Linear(16,2)
        )

        self.decoder_fcs = nn.Sequential(
            nn.Linear(2,16),
            nn.Tanh(),
            nn.Linear(16,48),
            nn.Tanh(),
            nn.Linear(48,196),
            nn.Tanh(),
            nn.Linear(196,28*28)
        )
    def encode(self,x):
        out = self.common_fc(torch.flatten(x,start_dim=1))
        mean = self.mean_fc(out)
        log_var = self.log_var_fc(out)
        return mean,log_var
    

    # Here we are applying the reparametrization trick
    def sample(self,mean,log_var):
        std = torch.exp(0.5*log_var)
        z = torch.randn_like(std)
        z = z*std + mean
        return z
    
    def decode(self,z):
        out = self.decoder_fcs(z)
        out = out.reshape((z.size(0),1,28,28))
        return out
    
    def forward(self,x):
        # Batch,Channel,Height,Width
        ## Encoder
        mean,log_var = self.encode(x)
        ## Sampling
        z = self.sample(mean,log_var)
        ## Decoder
        output = self.decode(z)
        return mean, log_var, output

In [None]:
def train():
    """Train a freshly constructed VAE on ``data_loader`` and pick test samples.

    Runs 10 epochs of Adam (lr=1e-3, weight_decay=1e-5) minimising the MSE
    reconstruction loss plus the KL divergence to a standard normal prior,
    with the KL term weighted by 1e-5. After every batch the first input
    image and its reconstruction are written to ``./vae_outputs/`` as a
    preview. Once training finishes, 100 random test images are stacked
    into the batch ``ims``.

    Relies on the notebook-level names ``device``, ``data_loader`` and
    ``test_data``.
    """
    # Instantiate the model
    model = VAE().to(device=device)
    # Training Parameters
    num_epochs = 10
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3,weight_decay=1e-5)
    criterion = torch.nn.MSELoss()

    # cv2.imwrite does not create missing directories and only signals the
    # failure through its return value, so make sure the target exists.
    os.makedirs('./vae_outputs', exist_ok=True)

    # Per-batch loss histories; they are never reset between epochs.
    recon_losses = []
    kl_losses = []
    losses = []

    for epoch_idx in range(num_epochs):
        for im,_ in tqdm(data_loader):
            im = im.to(device)
            optimizer.zero_grad()
            mean, log_var, out = model(im)
            # Map the first sample of the batch from [-1, 1] back to [0, 255]
            # and overwrite the preview files.
            cv2.imwrite('./vae_outputs/input.jpeg',255*((im+1)/2).detach().cpu().numpy()[0,0])
            cv2.imwrite('./vae_outputs/output.jpeg',255*((out+1)/2).detach().cpu().numpy()[0,0])

            # KL(N(mean, exp(log_var)) || N(0, I)), summed over the latent
            # dimensions and averaged over the batch.
            kl_loss = torch.mean(0.5*(torch.sum(torch.exp(log_var) + mean **2 -1 - log_var,dim = -1)))
            recon_loss = criterion(out,im)
            loss = recon_loss+0.00001*kl_loss
            recon_losses.append(recon_loss.item())
            losses.append(loss.item())
            kl_losses.append(kl_loss.item())
            loss.backward()
            optimizer.step()
        # NOTE: the means below cover every batch since training started,
        # not only the epoch that just finished.
        # The KL loss uses `.4f` (the original `4f` set a width, not a precision).
        print(f'Finished Epoch: {epoch_idx+1}|Reconstruction Loss: {np.mean(recon_losses):.4f}|KL Loss:{np.mean(kl_losses):.4f}|')

    # Run a reconstruction for some sample test images
    # torch.randint excludes its upper bound, so pass len(test_data) itself
    # to keep the last test image selectable.
    idxs = torch.randint(0,len(test_data),(100,))
    ims = torch.cat([test_data[idx][0][None,:] for idx in idxs]).float()

