In [6]:
import os
import torch
from torch import nn 

Set working directory, so we don't download the same dataset twice

In [None]:
working_dir = os.getcwd()
working_dir

In [None]:
os.chdir(os.path.dirname(working_dir))
os.getcwd()

Model

In [3]:
# Input img -> Hidden dim -> mean, std -> Paramertication Trick -> Decoder -> Output img
class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20) -> None:
        super().__init__()
        # Encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # Decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        # Activation
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        # q_phi(z|x)
        h = self.relu(self.img_2hid(x))
        mu = self.hid_2mu(h)
        sigma = self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        # p_theta(x|z)
        h = self.relu(self.z_2hid(z))
        img = self.sigmoid(self.hid_2img(h))
        return img

    def forward(self, x):
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma*epsilon
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, sigma

Check the shape

In [4]:
x = torch.randn(4, 28*28)
vae = VariationalAutoEncoder(input_dim=784)
x_reconstructed, mu, sigma = vae(x)
print(x_reconstructed.shape)
print(mu.shape)
print(sigma.shape)

torch.Size([4, 784])
torch.Size([4, 20])
torch.Size([4, 20])


Training

In [35]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision import datasets, transforms
from tqdm import tqdm
from pathlib import Path

In [31]:
# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Current working device: {device}")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20
NUM_EPOCH = 10
BATCH_SIZE = 128
LR_RATE = 3e-4 # Karpathy constant

[INFO] Current working device: cuda


In [27]:
# Dataset Loading
transform = transforms.ToTensor()
dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [28]:
# Model params
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

Traning loop

In [32]:
for epoch in range(NUM_EPOCH):
    loop = tqdm(enumerate(train_loader))
    for i, (x, _) in loop:
        # Forward pass
        x = x.to(device).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)

        # Compute loss
        reconstructed_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        # Backprop
        loss = reconstructed_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())


1875it [00:14, 131.20it/s, loss=3.99e+3]
1875it [00:14, 133.33it/s, loss=4.25e+3]
1875it [00:14, 132.31it/s, loss=4.21e+3]
1875it [00:13, 134.38it/s, loss=3.99e+3]
1875it [00:13, 134.12it/s, loss=4.09e+3]
1875it [00:14, 133.86it/s, loss=3.83e+3]
1875it [00:14, 129.94it/s, loss=4.07e+3]
1875it [00:14, 131.37it/s, loss=4.1e+3] 
1875it [00:14, 131.39it/s, loss=4.25e+3]
1875it [00:15, 121.11it/s, loss=3.93e+3]


In [36]:
model = model.to("cpu")
def inference(digit, num_examples=1, out_path:Path=Path("./Autoencoders/VAE_gen_examples")):
    """
    Generates (num_exmaples) of a particular digit.
    Specifically we extract an exmaple of eaxh digit, 
    then after we have mu, sigma representation for 
    each digit we can sample from that.

    After we sample we can run the decoder part of the VAE and generate examples.
    """
    out_path.mkdir(parents=True, exist_ok=True)
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    encoding_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d]. view(1, 784))
        encoding_digit.append((mu, sigma))
    
    mu, sigma = encoding_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, out_path / f"generated_{digit}_ex{example}.png")

In [37]:
for idx in range(10):
    inference(idx, num_examples=3)