In [1]:
import argparse
import os

import imageio
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import *
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import make_grid
from torchvision.utils import save_image
from tqdm import tqdm

# Apply matplotlib's 'ggplot' style sheet to the figures drawn by this notebook.
matplotlib.style.use('ggplot')

In [None]:
# Model hyperparameters.
# NOTE(review): verify that ConvVAE's layer sizes agree with these values;
# they are not all wired through to the model automatically.
kernel_size = 4 # side length of the square (4, 4) convolution kernel
image_channels = 1 # MNIST images are grayscale, i.e. a single channel
latent_dim = 16 # size of the latent vector that is sampled
        

In [None]:
def image_to_vid(images):
    """Write a sequence of image tensors to ``outputs/generated_images.gif``.

    Args:
        images: iterable of image tensors; each one is converted with the
            module-level ``to_pil_image`` and becomes one frame of the GIF.
    """
    # Nothing else in the notebook creates the folder, so make sure it exists
    # instead of failing with FileNotFoundError after a full training run.
    os.makedirs('outputs', exist_ok=True)
    frames = [np.array(to_pil_image(img)) for img in images]
    imageio.mimsave('outputs/generated_images.gif', frames)
def save_reconstructed_images(recon_images, epoch):
    """Save a batch of reconstructions to ``outputs/output{epoch}.jpg``.

    Args:
        recon_images: tensor of reconstructed images; moved to the CPU before
            being handed to ``save_image``.
        epoch: value interpolated into the file name.
    """
    # This is the first writer called by the training loop; create the folder
    # here so the first epoch's work is not lost to a missing directory.
    os.makedirs('outputs', exist_ok=True)
    save_image(recon_images.cpu(), f'outputs/output{epoch}.jpg')
def save_loss_plot(train_loss, valid_loss):
    """Plot both loss curves, save the figure to ``outputs/loss.jpg`` and show it.

    Args:
        train_loss: sequence of training-loss values, one point per entry.
        valid_loss: sequence of validation-loss values, one point per entry.
    """
    # Make sure the target folder exists before saving.
    os.makedirs('outputs', exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.plot(train_loss, color='red', label='train loss')
    ax.plot(valid_loss, color='blue', label='validation loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    # Save before show(): some backends discard the figure once it is shown.
    fig.savefig('outputs/loss.jpg')
    plt.show()
# Shared tensor -> PIL image converter; image_to_vid calls it for every frame.
to_pil_image = transforms.ToPILImage()

In [None]:
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder that emits 32x32 reconstructions.

    The encoder stacks four stride-2 convolutions, global-average-pools the
    result to a 64-dimensional vector and maps it to the mean and
    log-variance of the latent distribution. A latent sample is projected
    back to 64 channels at 1x1 resolution and upsampled by four transposed
    convolutions (1 -> 4 -> 8 -> 16 -> 32 pixels per side).

    Args:
        image_channels: number of channels of the input and of the
            reconstruction. Defaults to 1.
        latent_dim: size of the latent vector. Defaults to 16.
    """

    def __init__(self, image_channels=1, latent_dim=16):
        super().__init__()
        self.image_channels = image_channels
        self.latent_dim = latent_dim

        # Encoder: with kernel 4 / stride 2 / padding 1 each layer halves the
        # spatial size; the last layer (padding 0) collapses a 4x4 map to 1x1.
        self.enc1 = nn.Conv2d(in_channels=image_channels, out_channels=8, kernel_size=(4, 4), stride=2, padding=1)
        self.enc2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(4, 4), stride=2, padding=1)
        self.enc3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(4, 4), stride=2, padding=1)
        self.enc4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(4, 4), stride=2, padding=0)

        # Fully connected layers. Every layer that touches the latent vector
        # takes its size from the same latent_dim so they cannot drift apart.
        self.fc1 = nn.Linear(64, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_var = nn.Linear(128, latent_dim)
        self.fc2 = nn.Linear(latent_dim, 64)

        # Decoder: mirrors the encoder, 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32.
        self.dec1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=(4, 4), stride=2, padding=0)
        self.dec2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(4, 4), stride=2, padding=1)
        self.dec3 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(4, 4), stride=2, padding=1)
        self.dec4 = nn.ConvTranspose2d(in_channels=16, out_channels=image_channels, kernel_size=(4, 4), stride=2, padding=1)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps`` sampled from a standard normal.

        Args:
            mu: mean of the latent distribution.
            log_var: log-variance of the latent distribution.

        Returns:
            A sample with the same shape as ``mu``; the noise is the only
            random part, so gradients flow through ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
        eps = torch.randn_like(std)
        return mu + (eps * std)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: batch of images with ``image_channels`` channels.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``. ``reconstruction`` has
            values in [0, 1] (sigmoid output); ``mu`` and ``log_var`` have
            ``latent_dim`` columns.
        """
        # Encoding
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))

        # Global average pooling yields one 64-dim vector per sample.
        x = F.adaptive_avg_pool2d(x, 1).reshape(x.size(0), -1)
        hidden = self.fc1(x)

        mu = self.fc_mu(hidden)
        log_var = self.fc_var(hidden)

        z = self.reparameterize(mu, log_var)
        z = self.fc2(z)

        # Treat the 64 features as 64 channels of a 1x1 feature map.
        z = z.view(-1, 64, 1, 1)

        # Decoding
        x = F.relu(self.dec1(z))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        reconstruction = torch.sigmoid(self.dec4(x))
        return reconstruction, mu, log_var

In [None]:
# learning parameters
epochs = 100
batch_size = 64
lr = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's global RNG before the model is built so weight initialisation,
# data shuffling and latent sampling repeat across "Restart & Run All".
# (Some CUDA kernels may still be non-deterministic.)
seed = 42
torch.manual_seed(seed)

model = ConvVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Summed (not averaged) squared error: one scalar per batch.
criterion = nn.MSELoss(reduction='sum')

# One image grid per epoch is collected here and later turned into a GIF.
grid_images = []

In [None]:
# Preprocessing shared by both splits: resize every image to 32x32, then
# convert it to a tensor.
transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor()])

# Training split (downloaded on first use) and its loader, reshuffled each epoch.
trainset = torchvision.datasets.MNIST(
    root='../input',
    train=True,
    download=True,
    transform=transform,
)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

# Test split, used as the validation set; iterated in a fixed order.
testset = torchvision.datasets.MNIST(
    root='../input',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [None]:
def final_loss(mse_loss, mu, logvar):
    """Add the KL-divergence term to a reconstruction loss.

    Args:
        mse_loss: reconstruction loss already computed by the caller.
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        ``mse_loss + KLD`` where
        ``KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return mse_loss + kl_divergence

In [None]:
def train(model, dataloader, dataset, device, optimizer, criterion):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: module returning ``(reconstruction, mu, logvar)`` for a batch.
        dataloader: iterable of batches; element 0 of each batch is the input.
        dataset: no longer read; kept so existing calls keep working.
        device: device the input batches are moved to.
        optimizer: optimizer stepped once per batch.
        criterion: reconstruction loss, called as ``criterion(reconstruction, data)``.

    Returns:
        Sum of the per-batch ``final_loss`` values divided by the number of
        batches.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    counter = 0
    # len(dataloader) is the real number of batches; len(dataset) / batch_size
    # rounded down misses a trailing partial batch and fails when batch_size
    # is None.
    for data in tqdm(dataloader, total=len(dataloader)):
        counter += 1
        data = data[0].to(device)  # keep the inputs, drop the rest of the batch
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)
        mse_loss = criterion(reconstruction, data)
        loss = final_loss(mse_loss, mu, logvar)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss = running_loss / counter
    return train_loss

In [None]:
def validate(model, dataloader, dataset, device, criterion):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: module returning ``(reconstruction, mu, logvar)`` for a batch.
        dataloader: iterable of batches; element 0 of each batch is the input.
        dataset: no longer read; kept so existing calls keep working.
        device: device the input batches are moved to.
        criterion: reconstruction loss, called as ``criterion(reconstruction, data)``.

    Returns:
        Tuple ``(val_loss, recon_images)``. ``val_loss`` is the sum of the
        per-batch ``final_loss`` values divided by the number of batches.
        ``recon_images`` is the reconstruction of the last largest batch seen:
        the last full-size batch when there is one, otherwise the final batch.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    counter = 0
    recon_images = None
    with torch.no_grad():
        # len(dataloader) is the real number of batches; len(dataset) /
        # batch_size rounded down misses a trailing partial batch.
        for data in tqdm(dataloader, total=len(dataloader)):
            counter += 1
            data = data[0].to(device)  # keep the inputs, drop the rest of the batch
            reconstruction, mu, logvar = model(data)
            mse_loss = criterion(reconstruction, data)
            loss = final_loss(mse_loss, mu, logvar)
            running_loss += loss.item()

            # Keep the most recent batch that is at least as large as the one
            # already held. A smaller trailing batch therefore does not replace
            # the last full one, and a dataset smaller than one batch still
            # produces a result instead of leaving recon_images unassigned.
            if recon_images is None or reconstruction.size(0) >= recon_images.size(0):
                recon_images = reconstruction
    val_loss = running_loss / counter
    return val_loss, recon_images

In [None]:
train_loss = []
valid_loss = []
# Reset the frame list together with the loss histories: otherwise re-running
# this cell keeps appending to the frames of the previous run and the GIF
# ends up mixing both runs.
grid_images = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = train(
        model, trainloader, trainset, device, optimizer, criterion
    )
    valid_epoch_loss, recon_images = validate(
        model, testloader, testset, device, criterion
    )
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    # save the reconstructed images from the validation loop
    save_reconstructed_images(recon_images, epoch+1)
    # arrange the reconstructed batch into one image grid (one GIF frame per epoch)
    image_grid = make_grid(recon_images.detach().cpu())
    grid_images.append(image_grid)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {valid_epoch_loss:.4f}")

In [None]:
# Combine the per-epoch reconstruction grids into a single .gif file.
image_to_vid(grid_images)
# Plot the training and validation loss curves and save the figure to disk.
save_loss_plot(train_loss, valid_loss)
print('TRAINING COMPLETE')