In [1]:
from vae import VAE, vae_loss

import torch
import torch.optim as optim
from torchvision import datasets, transforms

from tqdm.auto import tqdm

import random
import os
import numpy as np

In [2]:
# Side length used for MNIST images; it is squared further down to obtain the
# VAE input dimension (x_dim).
IMAGE_SIZE_MNIST = 28

In [3]:
def seed_everything(seed=2023):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for deterministic
    behaviour.

    Args:
        seed: Integer seed applied to all generators. Defaults to 2023.
    """
    random.seed(seed)
    # Exported for child processes; setting it here does not change hash
    # randomisation of the already-running interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms and may select a
    # different one on each run, so it must be off for reproducible results.
    torch.backends.cudnn.benchmark = False

seed_everything()

In [4]:
# MNIST Dataset
# Both splits share the same storage root, tensor transform and download flag;
# only the `train` switch differs.
mnist_kwargs = dict(
    root='../data/',
    transform=transforms.ToTensor(),
    download=True,
)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Batches of 128 samples; only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=128,
    shuffle=False,
)

In [5]:
# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device_name = 'cuda:0'
elif torch.backends.mps.is_available():
    device_name = 'mps:0'
else:
    device_name = 'cpu'

device = torch.device(device_name)

# Last expression of the cell: shows the selected device.
device

device(type='mps', index=0)

In [6]:
# Hyperparameters handed to the project-local VAE class; x_dim is the number
# of pixels in one image (side length squared).
vae_kwargs = dict(
    x_dim=IMAGE_SIZE_MNIST * IMAGE_SIZE_MNIST,
    h_dim1=512,
    h_dim2=256,
    z_dim=20,
    drop_prob=0.1,
)
vae = VAE(**vae_kwargs)

# Move the model to the selected device; as the cell's last expression this
# also displays the module structure.
vae.to(device)

VAE(
  (fc1): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
  )
  (fc2): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
  )
  (mu): Linear(in_features=256, out_features=20, bias=True)
  (log_var): Linear(in_features=256, out_features=20, bias=True)
  (fc4): Sequential(
    (0): Linear(in_features=20, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
  )
  (fc5): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
  )
  (fc6): Linear(in_features=512, out_features=784, bias=True)
)

In [7]:
# Adam over all VAE parameters; no hyperparameters are overridden, so the
# library defaults apply.
optimizer = optim.Adam(params=vae.parameters())

In [8]:
def train(epoch):
    """Run one training epoch of the global ``vae`` over ``train_loader``.

    Uses the notebook-level ``vae``, ``optimizer``, ``train_loader``,
    ``vae_loss`` and ``device``. Model parameters are updated in place and a
    one-line summary is printed at the end of the epoch.

    Args:
        epoch: Epoch number; only used to label the printed summary.
    """
    vae.train()
    train_loss = 0

    for data, _ in tqdm(train_loader):
        data = data.to(device)
        # zero_grad must be *called*: a bare attribute reference is a no-op and
        # would let gradients from every previous batch accumulate.
        optimizer.zero_grad()

        reconst_batch, mu, log_var = vae(data)

        loss = vae_loss(
            reconst_batch,
            data,
            mu,
            log_var
        )
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    # Normalised by the number of samples, not batches; this presumably relies
    # on vae_loss summing over the batch -- confirm against vae.py.
    _train_loss = train_loss / len(train_loader.dataset)
    print(f'Epoch: {epoch} Train Loss: {_train_loss:.4f}')

In [9]:
def test():
    vae.eval()
    test_loss = 0

    with torch.no_grad():
        for data, _ in tqdm(test_loader):
            data = data.to(device)
            reconst, mu, log_var = vae(data)

            test_loss += vae_loss(reconst, data, mu, log_var).item()

    test_loss /= len(test_loader.dataset)
    print(f'Test Loss: {test_loss:.4f}')

In [10]:
# Alternate one training pass and one evaluation pass per epoch, with epochs
# numbered from 1 through 50.
n_epochs = 50
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

  0%|          | 0/469 [00:00<?, ?it/s]

: 

: 