In [22]:
from variational_autoencoder import VAE
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from skimage import io, transform
import os
import sys
import PIL

In [28]:
def get_files_in_sub_dirs(path):
    """Recursively collect the path of every file below ``path``.

    Args:
        path: Root directory to walk.

    Returns:
        Sorted list of paths, each built as ``os.path.join(root, name)``.
        ``os.walk`` enumerates entries in filesystem order, which is
        arbitrary. Sorting makes the result deterministic, so a dataset
        index refers to the same file on every run and machine.
        A non-existent ``path`` yields an empty list, because ``os.walk``
        ignores errors by default.
    """
    return sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(path)
        for name in files
    )

def get_data(path):
    """Return the paths of all ``.jpg`` files found recursively under ``path``.

    The match is a case-sensitive suffix test, so ``.JPG`` and ``.jpeg`` files
    are not included. Paths keep the order produced by
    ``get_files_in_sub_dirs``.
    """
    return [name for name in get_files_in_sub_dirs(path) if name.endswith('.jpg')]

In [25]:
def reconstruction_loss(x, x_hat):
    """Summed squared error between the reconstruction ``x_hat`` and the input ``x``.

    ``reduction='sum'`` means the value grows with the number of elements in
    the batch; it is not a per-element mean.
    """
    squared_error_total = F.mse_loss(input=x_hat, target=x, reduction='sum')
    return squared_error_total

def kl_divergence_loss(mu, stddev, beta=0.5):
    """KL regulariser of the VAE objective, scaled by ``beta``.

    Computes ``beta * sum(exp(stddev) + mu**2 - 1 - stddev)``.

    With the default ``beta=0.5`` this equals the closed form
    ``KL(N(mu, sigma^2) || N(0, 1))`` only if ``stddev`` holds the
    log-variance ``log(sigma^2)``. The closed form's 0.5 factor is folded
    into ``beta``, so ``beta=1.0`` gives twice the KL.

    NOTE(review): despite its name, ``stddev`` must be a log-variance for this
    to be a KL divergence. Confirm against the VAE encoder's output.
    """
    elementwise = torch.exp(stddev) + mu.pow(2) - 1.0 - stddev
    return beta * elementwise.sum()

def loss_function(x, x_hat, mu, stddev, beta=0.5):
    """Total VAE loss: summed reconstruction error plus the beta-scaled KL term.

    ``x``/``x_hat`` go to ``reconstruction_loss``. ``mu``/``stddev``/``beta``
    go unchanged to ``kl_divergence_loss``.
    """
    recon_term = reconstruction_loss(x, x_hat)
    kl_term = kl_divergence_loss(mu, stddev, beta)
    return recon_term + kl_term

In [None]:
class SimpsonImageDataset(Dataset):
    """Map-style dataset over every ``.jpg`` file found recursively under ``root_dir``.

    Each item is an ``(image, label)`` pair. The label is a constant ``0``
    placeholder: ``train`` unpacks ``(X, y)`` batches but never uses ``y``,
    so the unsupervised VAE only needs the images.

    Args:
        root_dir: Directory searched (recursively) for ``.jpg`` files.
        transform: Optional callable applied to each loaded image. Without
            one, the loader's raw output is returned. With the default loader
            that is a NumPy array. Images of differing sizes will not batch
            through the default collate, so a resizing transform is
            presumably required.
        loader: Optional callable mapping a file path to an image. Defaults to
            ``skimage.io.imread``. Pass a PIL-based loader if ``transform``
            expects PIL images.
    """

    def __init__(self, root_dir, transform=None, loader=None) -> None:
        self.root_dir = root_dir
        self.image_names = get_data(self.root_dir)
        self.transform = transform
        # `io` here is skimage.io from the imports cell; the `transform`
        # parameter above only shadows skimage.transform inside this method.
        self.loader = io.imread if loader is None else loader

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, index) -> tuple:
        """Load, optionally transform, and return ``(image, 0)`` for ``index``."""
        if torch.is_tensor(index):
            # Samplers may hand over a tensor index; list indexing needs a Python int.
            index = index.tolist()

        image = self.loader(self.image_names[index])
        if self.transform is not None:
            image = self.transform(image)

        # Constant placeholder label; see the class docstring.
        return image, 0

In [26]:
def train(model, device, optimizer, train_dataloader, epochs, loss_output_interval, beta=0.5):
    """Optimise ``model`` on the VAE objective for ``epochs`` passes over the data.

    Args:
        model: Module whose forward pass on a batch ``X`` returns
            ``(X_hat, mu, stddev)``; these are fed to ``loss_function``.
        device: Device the model and every input batch are moved to.
        optimizer: Optimizer already bound to ``model``'s parameters.
        train_dataloader: Sized iterable of ``(X, y)`` batches. ``y`` is
            ignored and never moved to ``device``.
        epochs: Number of passes over ``train_dataloader``.
        loss_output_interval: Print the mean batch loss every this many
            epochs. ``0``/``None`` disables printing.
        beta: KL weight forwarded to ``loss_function``. The default matches
            ``loss_function``'s own default.

    Returns:
        List containing the mean per-batch loss of each epoch. ``loss_function``
        sums over the batch, so each entry is an average of per-batch totals,
        not a per-sample mean.

    The model is left in eval mode on return.
    """
    model.to(device)
    model.train()

    n_batches = len(train_dataloader)
    epoch_losses = []

    for epoch in tqdm(range(epochs), desc="Epoch"):
        total_loss = 0.0

        for X, _ in tqdm(train_dataloader, desc="Batch", leave=False):
            X = X.to(device)
            X_hat, mu, stddev = model(X)
            loss = loss_function(X, X_hat, mu, stddev, beta=beta)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Guard against an empty dataloader instead of raising ZeroDivisionError.
        mean_loss = total_loss / max(n_batches, 1)
        epoch_losses.append(mean_loss)

        if loss_output_interval and (epoch + 1) % loss_output_interval == 0:
            print(f' Epoch {epoch+1} Average Batch Loss: {mean_loss}')

    model.eval()
    return epoch_losses