In [1]:
import torch
import torchvision.datasets as datasets
from tqdm import tqdm
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import wandb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [12]:
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder maps a flat input vector to two ``z_dim``-sized vectors
    (``mu`` and ``sigma``); the decoder maps a latent vector back to
    ``input_dim`` values passed through a sigmoid.

    Args:
        input_dim: Number of features in each flattened input sample.
        h_dim: Width of the hidden layer used by both encoder and decoder.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # Encoder layers: input -> hidden -> (mu, sigma).
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # Decoder layers: latent -> hidden -> reconstruction.
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        # Non-linearity shared by both hidden layers.
        self.relu = nn.ReLU()

    def encode(self, x):
        """Return ``(mu, sigma)`` for a batch of flattened inputs.

        Both outputs are raw linear projections of the hidden activation,
        so ``sigma`` is not constrained to be positive.
        """
        hidden = self.img_2hid(x)
        hidden = self.relu(hidden)
        return self.hid_2mu(hidden), self.hid_2sigma(hidden)

    def decode(self, z):
        """Map latent vectors ``z`` to sigmoid-activated reconstructions."""
        hidden = self.relu(self.z_2hid(z))
        logits = self.hid_2img(hidden)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode it.

        Returns:
            Tuple ``(x_reconstructed, mu, sigma)``.
        """
        mu, sigma = self.encode(x)
        # Reparameterization: z = mu + sigma * noise, with noise ~ N(0, I).
        noise = torch.randn_like(sigma)
        latent = mu + sigma * noise
        return self.decode(latent), mu, sigma


if __name__ == "__main__":
    # Smoke test: push a random batch of 4 flattened 28x28 inputs through a
    # default-sized model and check the output shapes, so a wiring mistake in
    # the layers fails here rather than deep inside training.
    x = torch.randn(4, 28*28)
    vae = VariationalAutoEncoder(input_dim=784)
    x_reconstructed, mu, sigma = vae(x)
    assert x_reconstructed.shape == x.shape, "reconstruction must match input shape"
    # 20 is the default z_dim of VariationalAutoEncoder.
    assert mu.shape == sigma.shape == (x.shape[0], 20), "unexpected latent shape"

In [None]:
# Authenticate with Weights & Biases without embedding a credential in the
# notebook. Called with no explicit key, wandb.login() resolves the API key
# from the WANDB_API_KEY environment variable or previously stored
# credentials, and otherwise prompts for it interactively.
wandb.login()
wandb.init(project="VAE(RAID)")

In [14]:
# Reproducibility: seed PyTorch's global RNG before any model or DataLoader
# is created, so that Restart & Run All repeats the same weight
# initialisation, shuffling order and sampling noise.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model dimensions.
INPUT_DIM = 784  # one flattened 28 x 28 image
H_DIM = 200      # hidden layer width
Z_DIM = 20       # latent space size

# Optimisation settings.
number_epochs = 20
Batch_size = 32
Learning_rate = 3e-4

In [None]:

# Record the hyperparameters with the active W&B run so each logged curve can
# be traced back to the settings that produced it.
wandb.config.update({
    "input_dim": INPUT_DIM,
    "hidden_dim": H_DIM,
    "z_dim": Z_DIM,
    "num_epochs": number_epochs,
    "batch_size": Batch_size,
    "learning_rate": Learning_rate
})

# MNIST training split, downloaded into dataset/ on first use. ToTensor
# converts each image to a float tensor, which is what the BCE
# reconstruction loss below compares against.
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=Batch_size, shuffle=True)

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=Learning_rate)

# Summed (not averaged) BCE: the training loop divides the accumulated loss
# by the dataset size itself.
loss_fn = nn.BCELoss(reduction="sum")

# Halve the learning rate after the monitored loss has failed to improve for
# more than `patience` epochs. `verbose` is deliberately not passed: it is
# deprecated in recent PyTorch releases, and the training loop already logs
# the current learning rate from the optimizer every epoch.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

In [16]:
def inference(num_examples=10):
    """Log original/reconstruction image pairs for each digit to W&B.

    Scans the module-level ``dataset`` for up to ``num_examples`` images of
    each digit 0-9, encodes each one with the module-level ``model``, samples
    a latent vector, decodes it, and logs every side-by-side comparison in a
    single ``wandb.log`` call under the ``"generated_images"`` key.

    Args:
        num_examples: Maximum number of images to reconstruct per digit. If
            the dataset holds fewer images of some digit, only the ones that
            were found are logged.
    """
    # Collect the first `num_examples` images of every digit, ending the scan
    # as soon as all ten buckets are full.
    images = {digit: [] for digit in range(10)}
    for x, y in dataset:
        if len(images[y]) < num_examples:
            images[y].append(x)
        if all(len(bucket) == num_examples for bucket in images.values()):
            break

    wandb_images = []
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for visualisation; keeping both the encode
        # and decode passes under no_grad avoids building an autograd graph
        # that would otherwise stay attached to the logged tensors.
        with torch.no_grad():
            for digit in range(10):
                # Iterate over what was actually collected rather than
                # range(num_examples), so a short bucket cannot raise
                # IndexError.
                for i, img in enumerate(images[digit]):
                    mu, sigma = model.encode(img.view(1, -1).to(DEVICE))
                    # randn_like already allocates on sigma's device.
                    epsilon = torch.randn_like(sigma)
                    z = mu + sigma * epsilon
                    out = model.decode(z).view(1, 28, 28)

                    # Original on the left, reconstruction on the right.
                    original_img = img.view(1, 28, 28)
                    comparison = torch.cat([original_img, out.cpu()], dim=-1)
                    wandb_images.append(wandb.Image(comparison, caption=f"Digit {digit} Example {i}"))
    finally:
        # Leave the model in whichever mode the caller had it in.
        model.train(was_training)

    wandb.log({"generated_images": wandb_images})

In [None]:
# Small constant added inside the log of the KL term so that a latent
# dimension whose sigma reaches exactly 0 still yields a finite loss.
KL_EPS = 1e-8

for epoch in range(number_epochs):
    model.train()
    overall_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{number_epochs}",
                        bar_format='{l_bar}{bar:30}{r_bar}{bar:-30b}',
                        colour='green', leave=False)

    for x, _ in progress_bar:
        # Flatten using the batch's real size: the final batch is smaller
        # than Batch_size whenever the dataset size is not a multiple of it.
        x = x.view(x.size(0), INPUT_DIM).to(DEVICE)
        x_reconstructed, mu, sigma = model(x)

        # Negative ELBO = summed reconstruction error + KL(q(z|x) || N(0, I)),
        # using the closed form of the KL for a diagonal Gaussian.
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_divergence = -0.5 * torch.sum(1 + torch.log(sigma**2 + KL_EPS) - mu**2 - sigma**2)
        loss = reconstruction_loss + kl_divergence

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        overall_loss += loss.item()

        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

    progress_bar.close()

    # Both loss terms are sums over the batch, so dividing by the number of
    # samples gives the mean per-sample loss for the epoch.
    average_loss = overall_loss / len(train_loader.dataset)
    wandb.log({
        "epoch": epoch,
        "loss": average_loss,
        "learning_rate": optimizer.param_groups[0]['lr']
    })
    print(f"Epoch [{epoch+1}/{number_epochs}], Loss: {average_loss:.4f}")

    # ReduceLROnPlateau monitors the epoch loss to decide when to decay.
    scheduler.step(average_loss)

# Log original-vs-reconstruction samples once training has finished.
inference(num_examples=10)