In [2]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models


import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

%matplotlib inline
%config InlineBackend.figure_format = "retina"

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cpu


In [3]:
# Train Parameters
batch_size = 64
num_epochs = 40
lr = 1e-4
num_grid_rows = 8
num_samples = 64

# Reproducibility: seed every RNG this notebook draws from (Python, NumPy, PyTorch)
# so that Restart Kernel -> Run All gives the same shuffles, initialisations and noise.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Diffusion Parameters
# NOTE(review): presumably the endpoints of the beta (noise) schedule over T steps — confirm where used.
beta_start = 1e-4
beta_end = 0.02
T = 1000

# Model Parameters
nc = 3            # number of image channels the transforms normalise
image_size = 32   # images are resized and center-cropped to image_size x image_size

In [4]:
# Resize + center-crop to a square image_size x image_size, convert to a [0, 1] tensor,
# then map each of the nc channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * nc, [0.5] * nc),
    ]
)

In [7]:
# CelebA images are RGB, so the shared `transform` (normalising nc channels) applies directly.
# ImageFolder expects ../CelebA to contain one sub-directory per class.
celebADataset = datasets.ImageFolder(root = "../CelebA", transform = transform)
celebALoader = DataLoader(dataset = celebADataset, batch_size = batch_size, shuffle = True)

# MNIST digits are single-channel, but `transform` ends with a Normalize over nc channels;
# an nc-length mean/std cannot be applied in place to a 1-channel tensor. Replicate the
# gray channel nc times first so the channel counts line up (a no-op conversion when nc == 1).
mnist_transform = transforms.Compose([transforms.Grayscale(num_output_channels = nc), transform])

mnist_train = datasets.MNIST(root = "../MNIST", train = True, transform = mnist_transform, download = True)
mnist_test = datasets.MNIST(root = "../MNIST", train = False, transform = mnist_transform, download = True)

# Adding two Datasets yields a ConcatDataset, so train and test are shuffled as one pool.
mnist_combined_loader = DataLoader(dataset = mnist_train + mnist_test, batch_size = batch_size, shuffle = True)

## **Latent Diffusion Models**

### The idea is to train the diffusion model on a low-dimensional latent representation rather than on the full pixel space. Alongside it, we train an encoder–decoder (autoencoder): the encoder maps the original image to its latent representation, and the decoder maps that latent back to a reconstructed image.

<p align="center">
<img src="./Media/Latent2.png" style="width:60%;border:0;" alt = "image">
</p>

### The downside is that even when the $L_1/L_2$ reconstruction loss is low, the reconstructed image can still look perceptually **fuzzy**.

## Discretizing the Latent Space using the $\text{CodeBooks}$ from $\text{VQVAEs}$
### $\text{VQVAE}$ as the $\text{AutoEncoder}$

A codebook of $k$ vectors, each of dimension $d$ (i.e. a $k \times d$ matrix), is used to encode the data.

<p align="center">
<img src="./Media/VQVAE1.png" style="width:70%;border:0;" alt = "image">
</p>

The encoder produces a feature map of $H \times W$ feature vectors, each of dimension $d$.
<p align="center">
<img src="./Media/VQVAE2.png" style="width:70%;border:0;" alt = "image">
</p>

For each feature vector, we find the nearest $d$-dimensional codebook vector and replace the feature with it.

$$ z_q(x) = e_k $$
$$ k = \arg\min_j || z_e(x) - e_j ||_2 $$

<p align="center">
<img src="./Media/VQVAE3.png" style="width:70%;border:0;" alt = "image">
</p>

The decoder then discards the encoder's feature map and reconstructs the output image using only the quantized feature map built from the nearest codebook vectors.
<p align="center">
<img src="./Media/VQVAE4.png" style="width:70%;border:0;" alt = "image">
</p>

The issue is that the $\arg\min$ step is not differentiable, so we have to define its gradient separately for gradients to flow back to the encoder. We approximate it with the straight-through estimator: the gradients are simply copied from the decoder input $z_q(x)$ to the encoder output $z_e(x)$.
<p align="center">
<img src="./Media/VQVAE5.png" style="width:70%;border:0;" alt = "image">
</p>

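$$ L = -\log p(x | z_q(x)) + || \text{sg}[z_e(x)] - e ||_2^2 + \beta || z_e(x) - \text{sg}[e] ||_2^2 $$

Here $\text{sg}[\cdot]$ is the stop-gradient operator. The first term is the reconstruction loss. The second (codebook loss) pulls the codebook vectors $e$ toward the encoder outputs. The third (commitment loss, weighted by $\beta$) keeps the encoder outputs close to the codebook vectors they were assigned to.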

## Perceptual Retention & $\text{LPIPS}$ as the metric

<p align="center">
<img src="./Media/Percept.png" style="width:80%;border:0;" alt = "image">
</p>

As noted above, even when the $L_1$ or $L_2$ reconstruction loss is low, the image can still look **blurry** to a human. To understand how a model perceives an image, a natural place to look is a pretrained classification **CNN** such as **VGG**. The goal is to make the feature maps extracted at each VGG layer from the reconstructed image very similar to the original image's feature maps at the same layers. The distance between the feature maps extracted from the layers of a pretrained VGG is called the **perceptual loss**. LPIPS additionally weights each layer's channel-wise feature differences with small learned $1 \times 1$ linear layers before averaging them.

<p align="center">
<img src="./Media/LPIPS1.png" style="width:80%;border:0;" alt = "image">
</p>

Check out the original implementation too at [Perceptual Similarity](https://github.com/richzhang/PerceptualSimilarity).

In [3]:
# lpips.py implementation
from collections import namedtuple

def spatial_average(in_tens, keepdim = True):
    """Average a tensor over dimensions 2 and 3 (the spatial axes of an N x C x H x W tensor).

    Args:
        in_tens: Tensor with at least four dimensions.
        keepdim: If True (default), keep the reduced axes as size-1 dimensions.

    Returns:
        The mean over dims 2 and 3, shaped (N, C, 1, 1) when ``keepdim`` is True.
    """
    spatial_dims = (2, 3)
    return in_tens.mean(dim = spatial_dims, keepdim = keepdim)


class vgg16(nn.Module):
    """VGG16 feature extractor that returns five intermediate activations.

    torchvision's VGG16 ``features`` stack is cut into five consecutive slices
    (layer indices [0, 4), [4, 9), [9, 16), [16, 23), [23, 30)). ``forward`` runs the
    slices in order and returns every slice's output, named after the ReLU layer it
    is taken from.

    Args:
        requires_grad: If False (default), all parameters are frozen.
        pretrained: If True (default), load torchvision's ImageNet weights for VGG16.
    """

    # Built once at class creation (instead of on every forward call), so all outputs
    # share one type that can be compared and pickled.
    VggOutputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])

    # Layer-index boundaries of the five slices in ``models.vgg16().features``.
    _SLICE_BOUNDS = (0, 4, 9, 16, 23, 30)

    def __init__(self, requires_grad = False, pretrained = True):
        super().__init__()
        vgg_pretrained_features = self._load_features(pretrained)
        self.N_slices = 5
        for i in range(self.N_slices):
            start, stop = self._SLICE_BOUNDS[i], self._SLICE_BOUNDS[i + 1]
            block = nn.Sequential()
            for x in range(start, stop):
                # Keep the original layer index as the child name so state_dict keys
                # stay "sliceK.<index>.*", exactly as before.
                block.add_module(str(x), vgg_pretrained_features[x])
            # Registers the submodule as self.slice1 ... self.slice5.
            setattr(self, "slice%d" % (i + 1), block)

        # Freeze the model
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    @staticmethod
    def _load_features(pretrained):
        """Return torchvision's VGG16 ``features`` stack.

        Uses the ``weights=`` API when this torchvision provides it (``pretrained=`` is
        deprecated there) and falls back to ``pretrained=`` on older releases.
        """
        if hasattr(models, "VGG16_Weights"):
            weights = models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
            return models.vgg16(weights = weights).features
        return models.vgg16(pretrained = pretrained).features

    def forward(self, X):
        """Run ``X`` through the five slices and return all five outputs as ``VggOutputs``.

        NOTE(review): assumes ``X`` is an (N, 3, H, W) image batch, as VGG16 expects — confirm at call sites.
        """
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        h_relu5_3 = self.slice5(h_relu4_3)
        return self.VggOutputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
    
    
class ScalingLayer(nn.Module):
    """Re-centre and re-scale an image batch with fixed per-channel constants.

    The constants are the ImageNet statistics (mean = [0.485, 0.456, 0.406],
    std = [0.229, 0.224, 0.225]) rewritten for inputs in [-1, 1]:
    shift = 2 * mean - 1 and scale = 2 * std. Both are stored as (1, 3, 1, 1)
    buffers named "shift" and "scale" so they broadcast over N x 3 x H x W inputs
    and follow the module across devices.
    """

    def __init__(self):
        super().__init__()
        shift = torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)
        scale = torch.tensor([.458, .448, .450]).view(1, 3, 1, 1)
        self.register_buffer("shift", shift)
        self.register_buffer("scale", scale)

    def forward(self, inp):
        """Return ``(inp - shift) / scale``, applied channel-wise."""
        centred = inp - self.shift
        return centred / self.scale

class NetLinLayer(nn.Module):
    """A bias-free 1x1 convolution, optionally preceded by Dropout, held in ``self.model``.

    Args:
        chn_in: Number of input channels.
        chn_out: Number of output channels (default 1).
        use_dropout: If True, ``nn.Dropout()`` is placed before the convolution. The conv then
            sits at index 1 of ``self.model`` instead of index 0, which changes its state_dict
            key ("model.1.weight" vs "model.0.weight").
    """

    def __init__(self, chn_in, chn_out = 1, use_dropout = False):
        super().__init__()
        conv = nn.Conv2d(chn_in, chn_out, kernel_size = 1, stride = 1, padding = 0, bias = False)
        stages = (nn.Dropout(), conv) if use_dropout else (conv,)
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the (optional dropout and) 1x1 channel projection to ``x``."""
        return self.model(x)

class LPIPS(nn.Module):
    """LPIPS perceptual distance with a frozen VGG16 backbone and learned linear heads.

    Both inputs are rescaled by ``ScalingLayer`` and passed through ``vgg16``. At each
    of its five output layers, the features are unit-normalised along the channel axis,
    differenced and squared. The result is projected by a 1x1 ``NetLinLayer`` head and
    spatially averaged, and the five per-layer scores are summed.

    Args:
        net: Backbone name; only used to build the default weights filename.
        version: Weights version string; only used to build the default weights path.
        use_dropout: Put a Dropout in front of each linear head. With Dropout the conv
            weight key is "linK.model.1.weight", without it "linK.model.0.weight".
            NOTE(review): presumably the checkpoint uses the dropout layout — confirm.
        model_path: Optional explicit path to the linear-head weights. When None, uses
            ``weights/v<version>/<net>.pth`` next to the file that defines this class.
            NOTE(review): inside a Jupyter kernel that file is usually a temporary cell
            file, so passing ``model_path`` explicitly is likely needed there.

    Raises:
        RuntimeError: If the loaded checkpoint does not provide weights for every linear head.
    """

    def __init__(self, net = "vgg", version = "0.1", use_dropout = True, model_path = None):
        super().__init__()
        self.scaling_layer = ScalingLayer()

        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.net = vgg16(pretrained = True, requires_grad = False)

        # lin0..lin4 are registered as named attributes; their "linK.*" names are what
        # the checkpoint keys are matched against. `lins` shares the same modules for indexing.
        self.lin0 = NetLinLayer(self.chns[0], use_dropout = use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout = use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout = use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout = use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout = use_dropout)
        self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4])

        if model_path is None:
            import inspect

            model_path = os.path.abspath(
                os.path.join(inspect.getfile(self.__init__), "..", "weights/v%s/%s.pth" % (version, net))
            )
        print("Loading model from: %s" % model_path)

        # Load into the LPIPS module itself, not into self.net: only then can the "linK.*"
        # keys reach the linear heads. strict=False because the VGG weights and the
        # scaling buffers are expected to be absent from this file. The weights are read
        # onto the CPU and copied into the existing parameters.
        incompatible = self.load_state_dict(torch.load(model_path, map_location = "cpu"), strict = False)
        head_prefixes = tuple("lin%d." % kk for kk in range(self.L))
        missing_heads = [key for key in incompatible.missing_keys if key.startswith(head_prefixes)]
        if missing_heads:
            raise RuntimeError(
                "LPIPS weights at %s do not provide the linear heads: %s" % (model_path, missing_heads)
            )

        self.eval()

        # The metric is used as a fixed loss, so nothing in it is trained.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, in0, in1, normalize = False):
        """Return the LPIPS distance between two image batches.

        Args:
            in0, in1: Image batches of the same shape; expected in [-1, 1] unless ``normalize``.
            normalize: If True, inputs are taken to be in [0, 1] and are mapped to [-1, 1] first.

        Returns:
            The per-image distance, summed over the five layers, shaped (N, 1, 1, 1).
        """
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        outs0 = self.net(self.scaling_layer(in0))
        outs1 = self.net(self.scaling_layer(in1))

        val = 0
        for kk in range(self.L):
            # Unit-normalise each spatial feature vector along the channel axis (dim 1).
            feats0 = F.normalize(outs0[kk], dim = 1)
            feats1 = F.normalize(outs1[kk], dim = 1)
            diff = (feats0 - feats1) ** 2
            val = val + spatial_average(self.lins[kk](diff), keepdim = True)

        return val

<p align="center">
<img src="./Media/Latent3.png" style="width:70%;border:0;" alt = "image">
</p>

<p align="center">
<img src="./Media/Latent4.png" style="width:70%;border:0;" alt = "image">
</p>
