In [1]:
# Mount Google Drive inside the Colab runtime so files stored there are reachable
from google.colab import drive

drive.mount("/content/drive")

# Project directory on the mounted Drive
root = "/content/drive/MyDrive/PhD/Integrated_Gradient"


Mounted at /content/drive


In [6]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torchsummary import summary
import numpy as np
from math import floor


In [3]:
# Select the compute device: CUDA when PyTorch reports it as available, otherwise the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Found:",torch.cuda.device_count(), device)


Found: 1 cuda


In [4]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent space; the decoder maps a latent vector
    back to a sigmoid-activated vector with as many entries as the input.

    Args:
        latent_dim: Size of the latent vector.
        hidden_size: Width of the hidden layer in both encoder and decoder.
        input_dim: Number of features of one flattened sample. Defaults to
            784 (= 28 * 28), the value that used to be hard-coded.
    """

    def __init__(self, latent_dim, hidden_size, input_dim=784):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_size = hidden_size
        self.input_dim = input_dim

        # Encoder: input -> hidden -> (mean, log-variance)
        self.fc1 = nn.Linear(input_dim, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_size, latent_dim)  # log-variance head
        # Decoder: latent -> hidden -> reconstruction
        self.fc3 = nn.Linear(latent_dim, hidden_size)
        self.fc4 = nn.Linear(hidden_size, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of already flattened inputs."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent vectors to reconstructions squashed into (0, 1) by a sigmoid."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)``, encode, sample and decode.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [7]:
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    Encoder: three stride-2 convolutions (28 -> 14 -> 7 -> 3 pixels per side)
    followed by two linear heads that output the mean and the log-variance of
    the latent Gaussian.
    Decoder: a linear layer whose output is viewed as a 1x1 feature map, then
    four transposed convolutions (1 -> 4 -> 8 -> 16 -> 28 pixels per side)
    ending in a sigmoid.

    Args:
        latent_dim: Size of the latent vector (default 20).
        init_channels: Number of filters of the first convolution; deeper
            layers use 2x and 4x this value (default 16).

    Layer attribute names are part of the checkpoint format (state-dict keys)
    and must not be renamed.
    """

    def __init__(self, latent_dim=20, init_channels=16):
        # nn.Module must be initialised before attributes are set on the instance.
        super(ConvVAE, self).__init__()
        self.name = "Conv_VAE"
        kernel_size = 4   # (4 x 4)
        stride = 2
        padding = 1
        image_size = 28   # height == width of the input images
        self.latent_dim = latent_dim  # latent dimension for sampling

        def conv_out(size):
            # Output side length of a Conv2d with the settings above.
            return (size - kernel_size + 2 * padding) // stride + 1

        def deconv_out(size):
            # Output side length of a ConvTranspose2d with the settings above.
            return (size - 1) * stride - 2 * padding + kernel_size

        # Encoder
        self.enc1 = nn.Conv2d(in_channels=1, out_channels=init_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding)
        self.enc2 = nn.Conv2d(in_channels=init_channels, out_channels=init_channels * 2,
                              kernel_size=kernel_size, stride=stride, padding=padding)
        self.enc3 = nn.Conv2d(in_channels=init_channels * 2, out_channels=init_channels * 4,
                              kernel_size=kernel_size, stride=stride, padding=padding)
        height_width = conv_out(conv_out(conv_out(image_size)))

        # Fully connected layers for learning representations
        hidden_size = height_width ** 2 * init_channels * 4
        self.fc_mu = nn.Linear(hidden_size, self.latent_dim)
        self.fc_log_var = nn.Linear(hidden_size, self.latent_dim)
        self.fc2 = nn.Linear(self.latent_dim, init_channels * 4)

        # Decoder
        # First transposed conv turns the 1x1 map into a kernel_size x kernel_size map
        self.dec1 = nn.ConvTranspose2d(in_channels=init_channels * 4, out_channels=init_channels * 4,
                                       kernel_size=kernel_size, stride=1, padding=0)
        self.dec2 = nn.ConvTranspose2d(in_channels=init_channels * 4, out_channels=init_channels * 2,
                                       kernel_size=kernel_size, stride=stride, padding=padding)
        self.dec3 = nn.ConvTranspose2d(in_channels=init_channels * 2, out_channels=init_channels,
                                       kernel_size=kernel_size, stride=stride, padding=padding)
        height_width = deconv_out(deconv_out(kernel_size))

        # Last transposed conv: adjust the padding so that we end up at 28x28
        required_padding = ((height_width - 1) * stride + kernel_size - image_size) // 2
        self.dec4 = nn.ConvTranspose2d(in_channels=init_channels, out_channels=1,
                                       kernel_size=kernel_size, stride=stride, padding=required_padding)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``."""
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)     # `randn_like` as we need the same size
        return mu + eps * std           # sampling

    def encode(self, x):
        """Return ``(mu, log_var)`` for a batch of images."""
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        hidden = x.view(x.shape[0], -1)
        return self.fc_mu(hidden), self.fc_log_var(hidden)

    def decode(self, z):
        """Map latent vectors to images with values in (0, 1)."""
        z = self.fc2(z)
        # View the linear output as a 1x1 feature map. The channel count is read
        # from fc2 instead of being hard-coded so it always matches dec1's input.
        z = z.view(-1, self.fc2.out_features, 1, 1)
        x = F.relu(self.dec1(z))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        return torch.sigmoid(self.dec4(x))

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [8]:
# Build the network on the selected device and restore its saved weights.
model = ConvVAE().to(device)
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# written on a GPU machine still loads when `device` is the CPU.
model.load_state_dict(torch.load(f"{root}/models/{model.name}.pt", map_location=device))
summary(model, (1, 28, 28))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 14, 14]             272
            Conv2d-2             [-1, 32, 7, 7]           8,224
            Conv2d-3             [-1, 64, 3, 3]          32,832
            Linear-4                   [-1, 20]          11,540
            Linear-5                   [-1, 20]          11,540
            Linear-6                   [-1, 64]           1,344
   ConvTranspose2d-7             [-1, 64, 4, 4]          65,600
   ConvTranspose2d-8             [-1, 32, 8, 8]          32,800
   ConvTranspose2d-9           [-1, 16, 16, 16]           8,208
  ConvTranspose2d-10            [-1, 1, 28, 28]             257
Total params: 172,617
Trainable params: 172,617
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.10
Params size (MB): 0.66
Estimated T

In [9]:
# Unshuffled loader over the MNIST test split (downloaded on first use), 64 images per batch
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root=f"{root}/dataset", train=False, download=True, transform=transforms.ToTensor()),
    batch_size=64,
    shuffle=False,
)

# Take the first batch only and move both the images and their labels to the selected device
image_batch, target_batch = (tensor.to(device) for tensor in next(iter(test_loader)))
image_batch.shape

torch.Size([64, 1, 28, 28])

In [10]:
# Regular interpolation
# Weights: n_samples values spread evenly over [0, 1], both endpoints included
n_samples = 15
alpha = torch.linspace(start=0, end=1, steps=n_samples)


In [11]:
# Pixel-space interpolation between consecutive images of the batch.
for i in range(5):
    # alpha = 0 yields image i and alpha = 1 yields image i + 1, so the frames are
    # stored in the same direction as the "<label i>_to_<label i + 1>" file name.
    linear_interp = torch.cat([(1 - a) * image_batch[[i]] + a * image_batch[[i + 1]] for a in alpha])
    save_image(linear_interp, f"{root}/results/Interpolation/linear_interp_{target_batch[i]}_to_{target_batch[i + 1]}.png", nrow=n_samples)

In [12]:
# spherical linear interpolation (slerp)
def slerp(z1, z2, alpha, eps=1e-6):
    """Spherically interpolate between ``z1`` and ``z2``.

    For every weight ``a`` in ``alpha`` one block

        sin((1 - a) * omega) / sin(omega) * z1 + sin(a * omega) / sin(omega) * z2

    is produced, where ``omega`` is the angle between ``z1`` and ``z2`` taken
    over the last dimension. ``a = 0`` yields ``z1`` and ``a = 1`` yields ``z2``.

    Args:
        z1: Start tensor; vectors live along the last dimension.
        z2: End tensor, broadcast-compatible with ``z1``.
        alpha: Iterable of interpolation weights (floats or 0-dim tensors).
        eps: When ``|sin(omega)|`` is below this threshold (parallel or
            antiparallel vectors) the slerp weights are undefined (0 / 0), so
            plain linear weights ``(1 - a, a)`` are used instead.

    Returns:
        Tuple ``(interpolated, omega)``: the per-weight blocks concatenated
        along dimension 0 in the order of ``alpha``, and the angle tensor
        (last dimension kept with size 1).

    Note:
        Zero-norm inputs are not handled and still produce NaN.
    """
    z1_unit = z1 / torch.norm(z1, dim=-1, keepdim=True)
    z2_unit = z2 / torch.norm(z2, dim=-1, keepdim=True)

    # Rounding can push the cosine marginally outside [-1, 1], where acos returns NaN.
    cos_omega = torch.sum(z1_unit * z2_unit, dim=-1, keepdim=True).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    degenerate = so.abs() < eps
    ones = torch.ones_like(so)
    # Replace near-zero denominators so the unused branch of `where` never divides by 0.
    safe_so = torch.where(degenerate, ones, so)

    steps = []
    for a in alpha:
        w1 = torch.where(degenerate, (1 - a) * ones, torch.sin((1 - a) * omega) / safe_so)
        w2 = torch.where(degenerate, a * ones, torch.sin(a * omega) / safe_so)
        steps.append(w1 * z1 + w2 * z2)
    return torch.cat(steps), omega

In [14]:
# Compute latent representations: the first output of the encoder (the mean);
# the second output (the log-variance) is discarded.
z, _ = model.encode(image_batch)
for i in range(5):
    # Both paths start at z[i] (alpha = 0) and end at z[i + 1] (alpha = 1), so the
    # frames are stored in the same direction as the "<label i>_to_<label i + 1>"
    # file names.
    linear_z_interp = torch.cat([(1 - a) * z[[i]] + a * z[[i + 1]] for a in alpha])
    semantic_interp = model.decode(linear_z_interp)
    save_image(semantic_interp, f"{root}/results/Interpolation/semantic_interp_{target_batch[i]}_to_{target_batch[i + 1]}.png", nrow=n_samples)

    # slerp returns its first argument at alpha = 0 and its second at alpha = 1.
    slerp_z_interp, _ = slerp(z[[i]], z[[i + 1]], alpha)
    spherical_interp = model.decode(slerp_z_interp)
    save_image(spherical_interp, f"{root}/results/Interpolation/spherical_interp_{target_batch[i]}_to_{target_batch[i + 1]}.png", nrow=n_samples)