In [10]:
import torch
from torch import nn
import torch.nn.functional as F

In [11]:
# Pipeline: image (input_dim) -> hidden layer (h_dim) -> mean and std-deviation vectors (z_dim) -> reparameterization -> latent vector z (z_dim) -> hidden layer (h_dim) -> reconstructed image (input_dim)
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder with one hidden layer on each side.

    Encoder q_phi(z|x): input_dim -> h_dim (ReLU) -> two linear heads giving mu and sigma.
    Decoder p_theta(x|z): z_dim -> h_dim (ReLU) -> input_dim (sigmoid).

    Args:
        input_dim: length of the flattened input vector (e.g. 28*28 = 784 for MNIST).
        h_dim: width of the hidden layer used by both encoder and decoder.
        z_dim: dimension of the latent vector z (and of the mu and sigma vectors).
    """

    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()

        # Encoder layers.
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # Decoder layers (mirror image of the encoder).
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        """q_phi(z|x): map a batch of flattened inputs to (mu, sigma), each with last dim z_dim."""
        h = F.relu(self.img_2hid(x))
        # The heads are deliberately linear: mu must be free to take negative values,
        # so a ReLU here would clip it. sigma is not forced positive either; its sign is
        # irrelevant for sigma * eps (eps is symmetric) and for sigma**2.
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        """p_theta(x|z): map latent vectors back to input space with values in [0, 1]."""
        h = F.relu(self.z_2hid(z))
        # Sigmoid keeps outputs in [0, 1], matching MNIST pixel intensities.
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x, z=None):
        """Encode x, sample z with the reparameterization trick, and decode it.

        Args:
            x: batch of flattened inputs, last dimension input_dim.
            z: unused; accepted only so older two-argument calls keep working.
                The latent vector is always sampled from the encoder output.

        Returns:
            (x_reconstructed, mu, sigma); mu and sigma are returned because the
            KL-divergence part of the loss needs them.
        """
        mu, sigma = self.encode(x)
        # Reparameterization: z = mu + sigma * eps with eps ~ N(0, I). The noise must be
        # standard normal (randn), not uniform (rand), for z ~ N(mu, sigma^2).
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma * epsilon
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, sigma

In [12]:
#if __name__ == "__main__":
#  x = torch.randn(4,28*28) #batch_size = 4 and 28*28 images
#  vae = VariationalAutoEncoder(input_dim = 784) #input_dim is the 28*28 image flattened into one vector; h_dim and z_dim keep their default values

In [13]:
# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_DIM = 784  # 28 * 28 MNIST image, flattened into one vector
Z_DIM = 20       # latent dimension (size of mu, sigma and z)
H_DIM = 200      # hidden-layer width of encoder and decoder
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4
SEED = 42

# Training is stochastic (weight init, data shuffling, reparameterization noise);
# seed torch's global RNG so Restart & Run All reproduces the same results.
torch.manual_seed(SEED)

In [14]:
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision.datasets as datasets  
from torch.utils.data import DataLoader  

In [15]:
# Dataset loading: the MNIST training split converted to tensors, served in shuffled mini-batches.
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
train_loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

In [16]:
# Define train function
def train(num_epochs, model, optimizer, loss_fn, loader=None):
    """Train a VAE by minimising reconstruction loss plus KL divergence.

    Args:
        num_epochs: number of full passes over the data.
        model: module whose forward(x) returns (x_reconst, mu, sigma).
        optimizer: optimizer over model.parameters(); stepped once per batch.
        loss_fn: reconstruction loss, called as loss_fn(x_reconst, x).
        loader: iterable of (images, labels) batches. Defaults to the
            notebook-level ``train_loader`` when omitted.
    """
    if loader is None:
        loader = train_loader
    model.train()
    for epoch in range(num_epochs):
        # Iterate the loader directly so tqdm can read its length and show a real progress bar.
        loop = tqdm(loader, desc=f"epoch {epoch + 1}/{num_epochs}")
        for x, _ in loop:
            # Forward pass on flattened images.
            x = x.to(device).view(-1, INPUT_DIM)
            x_reconst, mu, sigma = model(x)

            # loss, formulas from https://www.youtube.com/watch?v=igP03FXZqgo&t=2182s
            reconst_loss = loss_fn(x_reconst, x)
            # Closed-form KL(N(mu, sigma^2) || N(0, I)) = -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2).
            # NOTE(review): log(sigma^2) is -inf if sigma is exactly 0; a log-variance
            # encoder head would be numerically safer.
            kl_div = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

            # Backprop and optimize.
            loss = reconst_loss + kl_div
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())

In [None]:
# Initialize model, optimizer, loss.
# Pass h_dim / z_dim by keyword: the constructor is (input_dim, h_dim, z_dim), so a
# positional Z_DIM would silently become the hidden-layer width and H_DIM would be ignored.
model = VariationalAutoEncoder(INPUT_DIM, h_dim=H_DIM, z_dim=Z_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE)
# Summed (not averaged) BCE, so the reconstruction term is on the same per-batch
# scale as the summed KL term in train().
loss_fn = nn.BCELoss(reduction="sum")

# Run training
train(NUM_EPOCHS, model, optimizer, loss_fn)

In [None]:
def inference(digit, num_examples=1):
    """Generate ``num_examples`` new images of one MNIST digit and save them as PNGs.

    The first training example labelled ``digit`` is encoded to (mu, sigma).
    ``num_examples`` latent vectors z = mu + sigma * eps (eps ~ N(0, I)) are then
    sampled and decoded, each saved as ``generated_{digit}_ex{example}.png``.

    Args:
        digit: label of the digit to generate (0-9 for MNIST).
        num_examples: how many samples to draw and save.

    Raises:
        ValueError: if no example in ``dataset`` has label ``digit`` (e.g. a
            negative or out-of-range digit).
    """
    # Only the requested digit is needed, so stop at its first occurrence
    # instead of collecting and encoding one example of every digit.
    image = next((x for x, y in dataset if y == digit), None)
    if image is None:
        raise ValueError(f"no example with label {digit!r} in dataset")

    model.eval()
    with torch.no_grad():  # pure inference: no autograd graph for encode or decode
        # Dataset tensors live on the CPU; move the image to the model's device.
        mu, sigma = model.encode(image.to(device).view(1, INPUT_DIM))
        for example in range(num_examples):
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            save_image(out, f"generated_{digit}_ex{example}.png")
