# Description

This section is for training the VAE implemented in the VAE.py file

# Imports

In [1]:
import torch
import math
import torch.nn as nn
from loss_function import vae_loss
from torchvision import transforms
import numpy as np
from dataset import PatternDB
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from VAE import VAE
import gc
from utils import display_images

# Start

In [2]:
DATA_PATH = "../data/train"
BATCH_SIZE = 32
INPUT_SHAPE = 512

In [3]:
transform = transforms.Compose([
    transforms.Resize((INPUT_SHAPE, INPUT_SHAPE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [4]:
dataset = PatternDB(data_path=DATA_PATH, transform=transform)

In [5]:
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [13]:
vae = VAE(input_size=INPUT_SHAPE, in_channel=3, latent_space_dim=[1024], device=device, kernel_size=4, features=[16, 32, 64, 128, 256, 512])
vae.load_state_dict(torch.load(f"vae_model_epoch_latest.pth", map_location=device))

Bottleneck size: 32768


  vae.load_state_dict(torch.load(f"vae_model_epoch_latest.pth", map_location=device))


<All keys matched successfully>

In [8]:
def anneal(current_epoch, total_epochs):
   
    return min(1,current_epoch/total_epochs)

In [9]:
def train_one(model, loss_fn, optimizer, dataloader, device, total_epochs, current_epoch):

    running_loss = 0.0
    recon = 0.0
    edge = 0.0
    kl = 0.0
    beta = anneal(current_epoch, total_epochs)

    for i, data in enumerate(dataloader):

        optimizer.zero_grad()
        data = data.to(device)

        mu, log_var = model.encode(data)


        z = model._reparameterize(mu, log_var)


        reconstructed = model.decode(z)


        loss, recon_loss, edge_loss, kl_loss  = loss_fn(data, reconstructed, mu, log_var, beta)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()


        running_loss+=loss.item()
        recon+=recon_loss
        edge+=edge_loss
        kl += kl_loss
        print(f"Batch {i+1} loss (total): {loss.item()}\n")
      #  torch.cuda.empty_cache()
    return running_loss/len(dataloader), recon/len(dataloader), edge/len(dataloader), kl/len(dataloader)


In [10]:
def train(model, train_loader, loss_fn, optimizer, device, epochs):


    model.to(device)


    for epoch in range(epochs):
        

        print(f'EPOCH {epoch+1}:')

        model.train(True)
        avg_loss, recon_loss, edge_loss, kl_loss= train_one(model, loss_fn, optimizer, dataloader, device, epochs, epoch+1)
        print(f'Average loss: {avg_loss}, Average recon_loss: {recon_loss}, Average edge loss: {edge_loss} Average KL loss {kl_loss}\n\n')
        
        print(f"Model result after epoch {epoch+1}")
        display_images(vae, dataloader, device)

        if (epoch+1) % 10 == 0:
            torch.save(model.state_dict(), f"vae_model_epoch_latest.pth")


In [11]:
optimizer = torch.optim.Adam(params=vae.parameters(), lr=0.00005, eps=1e-8)
loss_fn = vae_loss(device=device)

In [None]:
train(vae, dataloader, loss_fn, optimizer, device, epochs=300)