In [None]:
import imageio
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from tqdm import tqdm


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
import matplotlib
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
# Switch matplotlib to the 'ggplot' style sheet for the figures drawn below.
matplotlib.style.use('ggplot')

In [None]:
from vae_model import *
from training import *
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = ConvVAE().to(device)

# Training hyperparameters.
lr = 0.001
epochs = 200
batch_size = 128

# Adam over every model parameter.
optimizer = optim.Adam(model.parameters(), lr=lr)
# Plateau scheduler: multiplies the learning rate by 0.1 once the monitored
# metric has stopped improving for more than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.1,
    patience=10,
)

# Binary cross-entropy, summed (not averaged) over all elements.
criterion = nn.BCELoss(reduction='sum')


In [None]:

# Training-time preprocessing: resize to 32 px, apply a random affine
# augmentation (rotation up to 30 degrees, 0.9-1.1 scaling, translation up to
# 10% horizontally / 20% vertically), then convert to a tensor.
transform1 = transforms.Compose([
    transforms.Resize(32),
    transforms.RandomAffine(degrees=30, scale=(.9, 1.1), translate=(0.1, 0.2)),
    transforms.ToTensor(),
])
# Evaluation-time preprocessing: same resize and tensor conversion but NO
# random augmentation, so the validation loss is deterministic and comparable
# from one epoch to the next.
transform_eval = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
])
# training set and train data loader (reshuffled every epoch)
trainset = torchvision.datasets.MNIST(
    root='./DATA_MNIST/', train=True, download=True, transform=transform1
)
trainloader = DataLoader(
    trainset, batch_size=batch_size, shuffle=True
)
# validation set and validation data loader (fixed order, no augmentation)
testset = torchvision.datasets.MNIST(
    root='./DATA_MNIST/', train=False, download=True, transform=transform_eval
)
testloader = DataLoader(
    testset, batch_size=batch_size, shuffle=False
)

In [None]:
import os
def show_images(images, nmax=64):
    """Draw up to ``nmax`` images from a batch as one tick-less grid.

    Args:
        images: batch tensor of images; only the first ``nmax`` entries are used.
        nmax: maximum number of images to include in the grid.
    """
    _, axis = plt.subplots(figsize=(8, 8))
    axis.set_xticks([])
    axis.set_yticks([])
    # Tile the selected images 8 per row into a single image tensor.
    grid = make_grid(images.detach()[:nmax], nrow=8)
    # Move the first dimension last before handing the grid to imshow.
    axis.imshow(grid.permute(1, 2, 0))

def show_batch(dl, nmax=64):
    """Display the first batch yielded by ``dl`` using ``show_images``.

    Args:
        dl: iterable yielding ``(images, labels)`` pairs; labels are ignored.
        nmax: maximum number of images to show, forwarded to ``show_images``.

    Does nothing when ``dl`` yields no batches.
    """
    first_batch = next(iter(dl), None)
    if first_batch is not None:
        images, _ = first_batch
        show_images(images, nmax)
# Visual sanity check: plot one batch drawn from the training loader.
show_batch(trainloader)

In [None]:
# Per-epoch loss history, one entry appended per epoch.
train_loss = []
valid_loss = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    # One pass over the training data using the imported `train` helper.
    train_epoch_loss = train(
        model, trainloader, trainset, device, optimizer, criterion
    )
    # One pass over the held-out data; `validate` also returns a second value
    # that is bound to `recon_images` and not used further in this loop.
    valid_epoch_loss, recon_images = validate(
        model, testloader, testset, device, criterion
    )
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {valid_epoch_loss:.4f}")
    # ReduceLROnPlateau decides whether to lower the learning rate from a
    # monitored metric, so step() must be given one; a bare step() raises
    # TypeError. The validation loss is the quantity being monitored.
    scheduler.step(valid_epoch_loss)

In [None]:
# Decode one random latent vector drawn from a standard normal distribution.
# The width 16 is presumably the model's latent size — TODO confirm against
# vae_model. The tensor is moved to `device` (not a hard-coded .cuda()) so the
# cell also runs on CPU-only machines, and no_grad() skips autograd
# bookkeeping since this is inference only.
with torch.no_grad():
    prediction = model.decoder(torch.randn(1, 16).to(device))[0]
# Show the shape of the decoded sample with its leading dimension squeezed.
prediction.squeeze(0).shape

In [None]:
# Write the model's state_dict (not the whole module object) to disk.
# NOTE(review): the filename says "300" while `epochs` is set to 200 in the
# config cell — confirm the intended checkpoint name.
torch.save(model.state_dict(), 'conv_vae300.pth')
