In [12]:
from variational_autoencoder import VAE
import torch.nn.functional as F
import torch
from tqdm import tqdm
import os

In [13]:
def get_data(path):
    """Recursively collect the paths of every ``.jpg`` file under ``path``.

    Walks ``path`` and all of its subdirectories with ``os.walk`` and keeps
    files whose name ends with the exact, case-sensitive suffix ``".jpg"``
    (so ``.JPG`` / ``.jpeg`` files are not included).

    Args:
        path: Root directory to search. A missing directory yields ``[]``
            because ``os.walk`` ignores errors by default.

    Returns:
        list[str]: ``os.path.join(dirpath, filename)`` for each match, in the
        order ``os.walk`` produces them (filesystem-dependent, not sorted).
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(path)
        for filename in filenames
        if filename.endswith(".jpg")
    ]

In [14]:
def reconstruction_loss(x, x_hat):
    """Summed squared error between an input batch and its reconstruction.

    Args:
        x: Original input tensor (the target).
        x_hat: Reconstruction; expected to have the same shape as ``x``.

    Returns:
        Scalar tensor equal to ``sum((x_hat - x) ** 2)`` over every element.
        Nothing is averaged, so the value grows with batch size and image size.
    """
    summed_squared_error = F.mse_loss(input=x_hat, target=x, reduction="sum")
    return summed_squared_error

def kl_divergence_loss(mu, stddev, beta=0.5):
    """Weighted KL term of the VAE objective, summed over all elements.

    Computes ``beta * sum(exp(stddev) + mu**2 - 1 - stddev)``.

    If ``stddev`` actually holds the log-variance ``log(sigma^2)``, this is
    ``2 * beta * KL(N(mu, sigma^2) || N(0, 1))``. With the default
    ``beta=0.5`` it is then exactly the standard closed-form KL.
    NOTE(review): the parameter name suggests sigma itself, but the formula
    only makes sense for log-variance. Confirm against what ``VAE.forward``
    returns.

    Args:
        mu: Tensor of latent means.
        stddev: Tensor the same shape as ``mu``; used by the formula as a
            log-variance (see note above).
        beta: Multiplier applied to the summed term (default 0.5).

    Returns:
        Scalar tensor.
    """
    log_var = stddev  # the formula treats this value as log(sigma^2)
    per_element = torch.exp(log_var) + mu ** 2 - 1.0 - log_var
    return beta * torch.sum(per_element)

def loss_function(x, x_hat, mu, stddev, beta=0.5):
    """Total VAE loss: summed reconstruction error plus the weighted KL term.

    Args:
        x: Original input batch.
        x_hat: Model reconstruction of ``x``.
        mu: Latent means produced by the encoder.
        stddev: Second encoder output, forwarded unchanged to
            ``kl_divergence_loss`` (which treats it as a log-variance).
        beta: Weight forwarded to ``kl_divergence_loss`` (default 0.5).

    Returns:
        Scalar tensor ``reconstruction_loss + kl_divergence_loss``.
    """
    recon_term = reconstruction_loss(x, x_hat)
    kl_term = kl_divergence_loss(mu, stddev, beta)
    return recon_term + kl_term

In [15]:
def train(model, device, optimizer, train_dataloader, epochs, loss_output_interval):
    """Train ``model`` as a VAE for ``epochs`` passes over ``train_dataloader``.

    Each batch is unpacked as ``(X, y)``. Only ``X`` is used: the loss compares
    the model's reconstruction against its own input, so labels are ignored
    and never copied to ``device``. ``model(X)`` must return a 3-tuple
    ``(X_hat, mu, stddev)``, which is scored with ``loss_function`` at its
    default ``beta``.

    Args:
        model: Module to train; moved to ``device`` and put in train mode.
        device: Target device for the model and each input batch.
        optimizer: Optimizer stepped once per batch (presumably over
            ``model.parameters()``).
        train_dataloader: Iterable of ``(X, y)`` batches. It does not need to
            support ``len()``.
        epochs: Number of full passes over the data.
        loss_output_interval: Report the mean per-batch loss every this many
            epochs. Pass ``0``/``None`` to disable reporting.

    Side effects:
        Updates parameters via ``optimizer.step()``. Leaves the model in eval
        mode when training finishes or is interrupted.
    """
    model.to(device)
    model.train()
    try:
        for epoch in tqdm(range(epochs), desc="Epoch"):
            epoch_loss = 0.0
            num_batches = 0

            for X, _ in tqdm(train_dataloader, desc="Batch", leave=False):
                X = X.to(device)
                X_hat, mu, stddev = model(X)
                loss = loss_function(X, X_hat, mu, stddev)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # .item() forces a host/device sync, so call it once per batch.
                epoch_loss += loss.item()
                num_batches += 1

            if loss_output_interval and (epoch + 1) % loss_output_interval == 0:
                # Count batches ourselves: len() fails on iterable-style loaders
                # and an empty loader would otherwise divide by zero.
                avg_loss = epoch_loss / num_batches if num_batches else float("nan")
                # tqdm.write prints without corrupting the active progress bars.
                tqdm.write(f' Epoch {epoch+1} Average Batch Loss: {avg_loss}')
    finally:
        # Restore eval mode even if training is interrupted (e.g. kernel interrupt).
        model.eval()

['../data/simpsons_dataset/edna_krabappel/pic_0413.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0174.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0184.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0041.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0221.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0327.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0323.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0285.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0228.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0144.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0051.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0121.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0157.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0037.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0058.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0118.jpg',
 '../data/simpsons_dataset/edna_krabappel/pic_0152.jpg',
 '../data/simpsons_dataset/edna