In [1]:
#|default_exp training
#|export

from LendingClubAutoencoder import autoencoders

import torch
import torch.nn.functional as F

import glob
import os

In [2]:
#|test
from LendingClubAutoencoder import preprocessing
from datetime import datetime

In [4]:
#|export

def masked_vae_loss(reconstruction:torch.Tensor, x:torch.Tensor, mean:torch.Tensor, log_variance:torch.Tensor, mask:torch.Tensor, beta:float=1.0)->tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Custom VAE loss function with masking for handling replaced NaN values

    The reconstruction term is the root-mean-squared error over the unmasked
    elements only; the KL term is summed over all elements and divided by the
    batch size (x.size(0)).

    Args:
        reconstruction (torch.Tensor): Reconstructed input from the decoder
        x (torch.Tensor): Original input
        mean (torch.Tensor): Mean of the latent distribution
        log_variance (torch.Tensor): Log variance of the latent distribution
        mask (torch.Tensor): Binary mask where 1 indicates valid values and 0 indicates replaced NaN values
        beta (float): Weight applied to the KL term in the total loss (beta-VAE).
            The default of 1.0 gives the plain sum of both terms.

    Returns:
        tuple: (total_loss, reconstruction_loss, kl_divergence_loss), where
        reconstruction_loss is the masked RMSE and kl_divergence_loss is the
        unweighted KL term (beta is only applied inside total_loss).
    '''
    # Reconstruction loss (masked RMSE): masked positions are zeroed on both
    # sides, so they contribute nothing to the summed squared error.
    squared_error_sum = F.mse_loss(reconstruction * mask, x * mask, reduction='sum')
    valid_count = mask.sum()

    if valid_count > 0:
        # Normalise by the number of unmasked elements, not by the tensor size
        rmse_loss = torch.sqrt(squared_error_sum / valid_count)
    else:
        # Nothing to reconstruct: dividing by zero would give NaN, and sqrt at
        # zero would give NaN gradients. The summed squared error is already
        # exactly zero with zero gradient, so use it directly.
        rmse_loss = squared_error_sum

    # KL divergence between N(mean, exp(log_variance)) and N(0, 1)
    kl_loss = -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())
    kl_loss = kl_loss / x.size(0)  # normalise by batch size

    # Total loss (beta-VAE formulation)
    total_loss = rmse_loss + beta * kl_loss

    return total_loss, rmse_loss, kl_loss


In [5]:
#|export

def train_variational_autoencoder(model, optimiser, train_loader, validation_loader, n_epochs:int=100, patience:int=10, device:str='cuda'):
    '''
    Training loop for the VAE with early stopping

    Each epoch runs one optimisation pass over train_loader and one
    gradient-free pass over validation_loader, prints the mean per-batch total
    loss of both, and saves model.state_dict() under 'trained_models/' whenever
    the validation loss improves on the best value seen so far.

    Args:
        model (VAE): The VAE model; called as model(data) and expected to return
            (reconstruction, mean, log_variance). Must expose `input_size`,
            which is written into the checkpoint file name. The model is moved
            to `device` in place.
        optimiser (torch.optim.Optimizer): The optimiser
        train_loader (DataLoader): Training data, yielding (data, mask) batches
        validation_loader (DataLoader): Validation data, yielding (data, mask) batches
        n_epochs (int): Maximum number of epochs
        patience (int): Number of epochs to wait for improvement before stopping
        device (str): Device to train on (a torch.device is accepted as well)

    Raises:
        ValueError: If either loader has no batches.
    '''
    # Fail early with a clear message rather than a ZeroDivisionError when the
    # per-epoch averages are computed.
    if len(train_loader) == 0 or len(validation_loader) == 0:
        raise ValueError('train_loader and validation_loader must each contain at least one batch')

    # Batches are moved to `device` below, so the model has to live there too;
    # otherwise the forward pass fails with a device mismatch.
    model.to(device)

    best_validation_loss = float('inf')
    patience_counter = 0

    for epoch in range(n_epochs):
        # Training phase
        model.train()
        train_total_loss = 0.0

        for data, mask in train_loader:
            data, mask = data.to(device), mask.to(device)

            reconstruction, mean, log_variance = model(data)

            # Only the total loss drives optimisation and reporting here
            loss, _, _ = masked_vae_loss(
                reconstruction, data, mean, log_variance, mask
            )

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            train_total_loss = train_total_loss + loss.item()

        # Validation phase
        model.eval()
        validation_total_loss = 0.0

        with torch.no_grad():  # No gradients needed for validation
            for data, mask in validation_loader:
                data, mask = data.to(device), mask.to(device)

                reconstruction, mean, log_variance = model(data)
                loss, _, _ = masked_vae_loss(
                    reconstruction, data, mean, log_variance, mask
                )

                validation_total_loss = validation_total_loss + loss.item()

        # Calculate average losses (mean over batches)
        average_train_loss = train_total_loss / len(train_loader)
        average_validation_loss = validation_total_loss / len(validation_loader)

        print(f'Epoch {epoch+1}/{n_epochs}:')
        print(f'  Training Loss: {average_train_loss:.4f}')
        print(f'  Validation Loss: {average_validation_loss:.4f}')

        # Early stopping check
        if average_validation_loss < best_validation_loss:
            best_validation_loss = average_validation_loss
            patience_counter = 0

            # exist_ok makes this idempotent, so no "first run" flag is needed
            os.makedirs('trained_models', exist_ok=True)

            # The '-name:value' suffix is the format get_best_model parses
            file_name = f'trained_models/vae_best-input_size:{model.input_size}.pt'
            torch.save(model.state_dict(), file_name)

        else:
            patience_counter = patience_counter + 1
            if patience_counter >= patience:
                print(f'Early stopping triggered after {epoch+1} epochs')
                break

In [6]:
#|export

def get_best_model(model_class, file_name:str = None):
    '''
    Rebuild a model from a saved checkpoint whose file name encodes its constructor arguments

    The file's base name (without extension) is split on '-'; every part after
    the first must look like 'name:int_value' and is passed to model_class as a
    keyword argument, e.g. 'vae_best-input_size:42.pt' -> model_class(input_size=42).
    Only arguments encoded in the file name are passed, so any other
    constructor parameter takes model_class's default.

    Args:
        model_class: Callable returning a model that supports load_state_dict
        file_name (str): Path to the checkpoint. If None, the most recently
            modified file matching 'trained_models/vae_best*.pt' is used.

    Returns:
        The constructed model with the saved state dict loaded.

    Raises:
        FileNotFoundError: If file_name is None and no checkpoint matches.
        ValueError: If a file-name part is not of the form 'name:int_value'.
    '''
    if file_name is None:
        matching_files = glob.glob('trained_models/vae_best*.pt')

        if not matching_files:
            raise FileNotFoundError('No best model found in trained_models/')

        # glob returns matches in arbitrary order; pick the newest checkpoint so
        # the choice is deterministic when several are present.
        file_name = max(matching_files, key=os.path.getmtime)

    # Parse only the base name so '-' or '.p' inside directory names cannot
    # corrupt the argument list.
    stem = os.path.splitext(os.path.basename(file_name))[0]

    model_args = {}
    for parameter in stem.split('-')[1:]:
        name, separator, value = parameter.partition(':')
        if not separator:
            raise ValueError(f"Cannot parse model argument '{parameter}' from file name '{file_name}'")
        model_args[name] = int(value)

    model = model_class(**model_args)

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts;
    # load_state_dict then copies the values into the model's own parameters.
    model.load_state_dict(torch.load(file_name, map_location='cpu'))

    return model


In [None]:
#|test

# Fix the torch RNG seed so repeated runs of this cell are comparable
torch.manual_seed(0)

# Get and preprocess data
lending_club_data_handler = preprocessing.DataHandler(csv_path='../local_data/all_lending_club_loan_data_2007-2018.csv')

# Training window: 2007-2008
start = datetime(2007,1,1)
end = datetime(2008,12,31)
train_data, train_mask = lending_club_data_handler.get_train_data(start, end)

# Validation window: 2009
start = datetime(2009,1,1)
end = datetime(2009,12,31)
validation_data, validation_mask = lending_club_data_handler.get_test_data(start, end)

train_loader = preprocessing.to_torch(train_data, train_mask)
validation_loader = preprocessing.to_torch(validation_data, validation_mask)

# Train model
# Pick the device first and move the model onto it before building the
# optimiser: the training loop moves every batch to `device`, so the model's
# parameters must be on the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = autoencoders.VariationalAutoencoder(input_size=len(train_data[0]), hidden_size=64, latent_size=32).to(device)
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)#original is 1e-3

train_variational_autoencoder(model, optimiser, train_loader, validation_loader, n_epochs=100, patience=10, device=device)