In [10]:
from torchvision import models
from torchsummary import summary
# vgg = models.resnet50()
# summary(vgg, (3, 542, 1042))

"""
The following is an import of PyTorch libraries.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import random
from tqdm import tqdm_notebook
from tqdm import trange


In [14]:
import time
from tqdm.notebook import tqdm

l = [None] * 10000

for i, e in tqdm(enumerate(l), total = len(l)): 
    time.sleep(0.01



  0%|          | 0/10000 [00:00<?, ?it/s]

In [11]:
"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
"""
A Convolutional Variational Autoencoder
"""
class VAE(nn.Module):
    def __init__(self, imgChannels=1, featureDim=32*20*20, zDim=256):
        super(VAE, self).__init__()

        # Initializing the 2 convolutional layers and 2 full-connected layers for the encoder
        self.encConv1 = nn.Conv2d(imgChannels, 16, 5)
        self.encConv2 = nn.Conv2d(16, 32, 5)
        self.encFC1 = nn.Linear(featureDim, zDim)
        self.encFC2 = nn.Linear(featureDim, zDim)

        # Initializing the fully-connected layer and 2 convolutional layers for decoder
        self.decFC1 = nn.Linear(zDim, featureDim)
        self.decConv1 = nn.ConvTranspose2d(32, 16, 5)
        self.decConv2 = nn.ConvTranspose2d(16, imgChannels, 5)

    def encoder(self, x):

        # Input is fed into 2 convolutional layers sequentially
        # The output feature map are fed into 2 fully-connected layers to predict mean (mu) and variance (logVar)
        # Mu and logVar are used for generating middle representation z and KL divergence loss
        x = F.relu(self.encConv1(x))
        x = F.relu(self.encConv2(x))
        x = x.view(-1, 32*20*20)
        mu = self.encFC1(x)
        logVar = self.encFC2(x)
        return mu, logVar

    def reparameterize(self, mu, logVar):

        #Reparameterization takes in the input mu and logVar and sample the mu + std * eps
        std = torch.exp(logVar/2)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decoder(self, z):

        # z is fed back into a fully-connected layers and then into two transpose convolutional layers
        # The generated output is the same size of the original input
        x = F.relu(self.decFC1(z))
        x = x.view(-1, 32, 20, 20)
        x = F.relu(self.decConv1(x))
        x = torch.sigmoid(self.decConv2(x))
        return x

    def forward(self, x):

        # The entire pipeline of the VAE: encoder -> reparameterization -> decoder
        # output, mu, and logVar are returned for loss computation
        mu, logVar = self.encoder(x)
        z = self.reparameterize(mu, logVar)
        out = self.decoder(z)
        return out, mu, logVar

In [13]:

"""
Initialize Hyperparameters
"""
batch_size = 128
learning_rate = 1e-3
num_epochs = 10


"""
Create dataloaders to feed data into the neural network
Default MNIST dataset is used and standard train/test split is performed
"""
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('/Users/cabe0006/Projects/monash/data', train=True, download=True,
                    transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('/Users/cabe0006/Projects/monash/data', train=False, transform=transforms.ToTensor()),
    batch_size=1)


"""
Initialize the network and the Adam optimizer
"""
net = VAE().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)


"""
Training the network for a given number of epochs
The loss after every epoch is printed
"""
for epoch in range(num_epochs):
    print(epoch)
    with tqdm(train_loader, unit="batch") as tepoch:
        for idx, data in enumerate(train_loader, 0):
            tepoch.set_description(f"Epoch {epoch}")
            imgs, _ = data
            imgs = imgs.to(device)

            # Feeding a batch of images into the network to obtain the output image, mu, and logVar
            out, mu, logVar = net(imgs)

            # The loss is the BCE loss combined with the KL divergence to ensure the distribution is learnt
            kl_divergence = 0.5 * torch.sum(-1 - logVar + mu.pow(2) + logVar.exp())
            loss = F.binary_cross_entropy(out, imgs, size_average=False) + kl_divergence

            # Backpropagation based on the loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            tepoch.set_postfix(loss=loss.item())
    print('Epoch {}: Loss {}'.format(epoch, loss))





  0%|          | 0/469 [00:00<?, ?batch/s][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:00<?, ?batch/s][A[A[A[A

0






Epoch 0:   0%|          | 0/469 [00:00<?, ?batch/s, loss=6.84e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:00<?, ?batch/s, loss=6.84e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:00<?, ?batch/s, loss=6.44e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:00<?, ?batch/s, loss=6.44e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:01<?, ?batch/s, loss=5.86e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:01<?, ?batch/s, loss=5.86e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:01<?, ?batch/s, loss=5.66e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:01<?, ?batch/s, loss=5.66e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:02<?, ?batch/s, loss=5.71e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:02<?, ?batch/s, loss=5.71e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:02<?, ?batch/s, loss=5.62e+4][A[A[A[A



Epoch 0:   0%|          | 0/469 [00:02<?, ?batch/s, loss=5.62e+4][A[A[A[A



Epoch 0:   0%|          

KeyboardInterrupt: 

In [13]:
# summary(vae, (1, 28, 28))

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random

net.eval()
with torch.no_grad():
    for data in random.sample(list(test_loader), 1):
        imgs, _ = data
        imgs = imgs.to(device)
        img = np.transpose(imgs[0].cpu().numpy(), [1,2,0])
        plt.subplot(121)
        plt.imshow(np.squeeze(img))
        out, mu, logVAR = net(imgs)
        outimg = np.transpose(out[0].cpu().numpy(), [1,2,0])
        plt.subplot(122)
        plt.imshow(np.squeeze(outimg))
        break