In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.optim import Adam
from torch.utils.data import DataLoader

In [2]:
# Deterministic preprocessing shared by both splits: resize to 64x64, then
# convert the image to a tensor.
base_steps = [
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
]

# Training pipeline: random horizontal flip as augmentation, followed by the
# shared steps (same order as before: flip -> resize -> tensor).
transform = transforms.Compose([transforms.RandomHorizontalFlip(), *base_steps])

# Evaluation pipeline: no random augmentation, so the reported test loss is
# measured on the same images every epoch.
eval_transform = transforms.Compose(base_steps)

# download=False: the data is expected to already exist under the 'datasets' root.
train_ds = datasets.FashionMNIST('datasets', train=True, transform=transform, download=False)
test_ds = datasets.FashionMNIST('datasets', train=False, transform=eval_transform, download=False)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4)
# A fixed order keeps evaluation deterministic; shuffling only matters for training.
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=4)


In [3]:
from VAE.models.vanilla_vae import VanillaVAE

# Build the model and the optimizer that updates its parameters.
# NOTE(review): the positional arguments (1, 50) are presumably the number of
# input channels and the latent size -- confirm against VanillaVAE's signature.
vae = VanillaVAE(1, 50)
optimizer = Adam(params=vae.parameters(), lr=0.001)

In [4]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches that exposes a
            sized ``dataset`` attribute; the labels are ignored.
        model: module whose forward pass returns ``(recons, mu, log_var)``.
        loss_fn: callable invoked as ``loss_fn(recons, x, mu, log_var)`` that
            returns a scalar tensor.
        optimizer: optimizer holding ``model``'s parameters.

    Prints the current batch loss every 100 batches. Returns ``None``.
    """
    size = len(dataloader.dataset)

    # Ensure training-mode behaviour is active even if an earlier evaluation
    # pass left the model in eval mode.
    model.train()

    for batch, (x, _) in enumerate(dataloader):
        recons, mu, log_var = model(x)
        loss = loss_fn(recons, x, mu, log_var)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # Use separate names so the loss tensor is not shadowed by a float.
            loss_value, seen = loss.item(), batch * len(x)
            print(f"loss: {loss_value:>7f}  [{seen:>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for x in dataloader:
            recons, mu, log_var = model(X)
            test_loss += loss_fn(recons, x).item()

    test_loss /= num_batches
    print(f"Test Avg loss: {test_loss:>8f} \n")

# Alternate one training pass and one evaluation pass per epoch.
epochs = 50
vae_loss = VanillaVAE.loss_function
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_loader, vae, vae_loss, optimizer)
    test_loop(test_loader, vae, vae_loss)

Epoch 1
-------------------------------
loss: 0.263990  [    0/60000]
loss: 0.028654  [ 6400/60000]


KeyboardInterrupt: 