In [1]:
#|default_exp training
#|export

from LendingClubAutoencoder import autoencoders

import torch
import torch.nn.functional as F

import glob
import os

In [2]:
#|test
from LendingClubAutoencoder import preprocessing
from datetime import datetime

In [4]:
#|export

def masked_vae_loss(reconstruction:torch.Tensor, x:torch.Tensor, not_null_mask:torch.Tensor, binary_mask:torch.Tensor, bce_weight:float=2.0)->tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Masked reconstruction loss for the VAE that ignores entries which were originally NaN.

    Numeric features are scored with root mean squared error (RMSE) and binary features with
    binary cross entropy (BCE). Only entries flagged as present in `not_null_mask` contribute.
    If a batch contains no present entries of one feature type, that component is 0 rather than NaN.

    Args:
        reconstruction (torch.Tensor): Decoder output, same shape as `x`. Values at binary-feature
            positions are fed to `F.binary_cross_entropy`, so they must lie in [0, 1].
        x (torch.Tensor): Original input; assumed (batch, n_features) since the 1-D feature mask
            is broadcast over a leading batch axis.
        not_null_mask (torch.Tensor): Per-entry mask, truthy where the value is real and falsy where
            a NaN was replaced. Bool, 0/1 int or 0./1. float tensors are accepted.
        binary_mask (torch.Tensor): 1-D per-feature mask, truthy for binary (categorical) features
            and falsy for numeric ones. Bool, int or float 0/1 tensors are accepted.
        bce_weight (float, optional): Multiplier on the BCE term in the total loss. Default 2.0.

    Returns:
        tuple: (total_loss, rmse_loss, bce_loss) where total_loss = rmse_loss + bce_weight * bce_loss
    '''
    # Normalise to bool: `~` on an int tensor is bitwise-not (1 -> -2) and `&` rejects float tensors,
    # and a non-bool index tensor would be treated as integer indices instead of a mask.
    binary_features = binary_mask.bool()[None, :]
    present = not_null_mask.bool()

    binary_select = present & binary_features
    numeric_select = present & ~binary_features

    # Sum then divide by a count clamped to >= 1: identical to a mean when entries exist,
    # but yields 0 (instead of NaN from averaging an empty tensor) when none do.
    numeric_count = numeric_select.sum().clamp_min(1)
    binary_count = binary_select.sum().clamp_min(1)

    # RMSE over present numeric features only
    squared_error_sum = F.mse_loss(reconstruction[numeric_select], x[numeric_select], reduction='sum')
    rmse_loss = torch.sqrt(squared_error_sum / numeric_count)

    # BCE over present binary features only
    bce_sum = F.binary_cross_entropy(reconstruction[binary_select], x[binary_select], reduction='sum')
    bce_loss = bce_sum / binary_count

    # Weight BCE to create parity in the prioritisation of both feature types
    total_loss = rmse_loss + bce_loss * bce_weight

    return total_loss, rmse_loss, bce_loss


In [5]:
#|export

def _run_vae_epoch(model, loader, binary_mask, device, optimiser=None):
    '''
    Runs one full pass of `model` over `loader` and returns batch-averaged losses.

    When `optimiser` is given, the pass trains: train mode, gradients enabled, one optimiser step per batch.
    Otherwise it evaluates: eval mode, no gradients.

    Returns:
        tuple: (average total loss, average RMSE loss, average BCE loss, `mean` output of the last batch)
    '''
    is_training = optimiser is not None
    model.train(is_training)# train(False) is equivalent to eval()

    total_sum = rmse_sum = bce_sum = 0.0
    mean = None
    with torch.set_grad_enabled(is_training):# No gradients needed for validation
        for data, not_null_mask in loader:
            data, not_null_mask = data.to(device), not_null_mask.to(device)

            reconstruction, mean = model(data)
            loss, rmse_loss, bce_loss = masked_vae_loss(
                reconstruction, data, not_null_mask, binary_mask
            )

            if is_training:
                optimiser.zero_grad()
                loss.backward()
                optimiser.step()

            total_sum += loss.item()
            rmse_sum += rmse_loss.item()
            bce_sum += bce_loss.item()

    n_batches = len(loader)
    return total_sum / n_batches, rmse_sum / n_batches, bce_sum / n_batches, mean


def train_variational_autoencoder(model, optimiser, train_loader, validation_loader, binary_mask, n_epochs:int=100, patience:int=10, device:str='cuda'):
    '''
    Trains a Variational Autoencoder (VAE) model with early stopping based on validation loss.

    After every epoch whose average validation loss beats the best seen so far, the model's
    state_dict is saved to 'trained_models/vae_best-input_size:<model.input_size>.pt'.

    Args:
        model: The VAE model to be trained. Its forward pass must return (reconstruction, mean),
            and it must expose an `input_size` attribute (used in the checkpoint file name).
        optimiser: The optimiser for updating model parameters.
        train_loader: DataLoader yielding (data, not_null_mask) training batches.
        validation_loader: DataLoader yielding (data, not_null_mask) validation batches.
        binary_mask: 1-D mask indicating which features are binary.
        n_epochs (int, optional): Maximum number of training epochs. Default is 100.
        patience (int, optional): Number of epochs to wait for improvement before early stopping. Default is 10.
        device (str, optional): Device to use for training ('cuda' or 'cpu'). Default is 'cuda'.

    Returns:
        None. The best model is saved to disk during training.

    Raises:
        ValueError: If either loader yields no batches (averages and the printed stats would be undefined).
    '''
    if len(train_loader) == 0 or len(validation_loader) == 0:
        raise ValueError('train_loader and validation_loader must each yield at least one batch')

    # Move model and mask to the specified device
    model = model.to(device)
    binary_mask = binary_mask.to(device)

    # Initialise early stopping variables
    best_validation_loss = float('inf')
    patience_counter = 0

    for epoch in range(n_epochs):
        train_loss, train_rmse, train_bce, _ = _run_vae_epoch(
            model, train_loader, binary_mask, device, optimiser
        )
        validation_loss, validation_rmse, validation_bce, mean = _run_vae_epoch(
            model, validation_loader, binary_mask, device
        )

        # Latent mean statistics from the last validation batch
        print("Mean stats:", mean.mean().item(), mean.std().item())
        print(f'Epoch {epoch+1}/{n_epochs}:')
        print(f'  Training Loss: {train_loss:.4f}, RMSE Loss: {train_rmse:.4f}, BCE Loss: {train_bce:.4f}')
        print(f'  Total Validation Loss: {validation_loss:.4f}, Total RMSE Loss: {validation_rmse:.4f}, Total BCE Loss: {validation_bce:.4f}')

        # Early stopping check
        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            patience_counter = 0

            os.makedirs('trained_models', exist_ok=True)# idempotent, so safe to call on every save
            file_name = f'trained_models/vae_best-input_size:{model.input_size}.pt'
            torch.save(model.state_dict(), file_name)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping triggered after {epoch+1} epochs')
                break

In [6]:
#|export

def get_best_model(model_class, sigmoid_mask, file_name:str = None):
    '''
    Loads the best saved VAE model from disk, reconstructing its arguments from the filename.

    The file name is expected to look like 'vae_best-<name>:<int>[-<name>:<int>...].pt'. Every
    '<name>:<int>' pair becomes an integer keyword argument of `model_class`, and `sigmoid_mask` is
    passed as an additional keyword argument. Only the file's own name is parsed, so directory names
    may contain '-' or '.' freely.

    Args:
        model_class: The class of the VAE model to instantiate.
        sigmoid_mask: Mask indicating which features require a sigmoid activation.
        file_name (str, optional): Path to the saved model file. If None, the most recently modified
            'trained_models/vae_best*.pt' file (relative to the working directory) is used.

    Returns:
        model: The loaded VAE model with weights restored (on CPU).

    Raises:
        FileNotFoundError: If `file_name` is None and no matching checkpoint exists.
        ValueError: If a parameter in the file name is not of the form '<name>:<int>'.
    '''
    if file_name is None:
        matching_files = glob.glob('trained_models/vae_best*.pt')
        if not matching_files:
            raise FileNotFoundError('No best model found in trained_models/')
        # glob order is arbitrary; choose the newest checkpoint (ties broken by name) so the pick is deterministic
        file_name = max(matching_files, key=lambda path: (os.path.getmtime(path), path))

    # e.g. '.../vae_best-input_size:64.pt' -> stem 'vae_best-input_size:64' -> {'input_size': 64}
    stem = os.path.splitext(os.path.basename(file_name))[0]
    model_args = {}
    for parameter in stem.split('-')[1:]:
        key, separator, value = parameter.partition(':')
        if not separator:
            raise ValueError(f'Cannot parse model parameter {parameter!r} from file name {file_name!r}')
        model_args[key] = int(value)
    model_args['sigmoid_mask'] = sigmoid_mask

    model = model_class(**model_args)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(file_name, map_location=torch.device('cpu')))# Load model weights from file

    return model


In [None]:
#|test

# Get and preprocess data: train on loans dated 2007-01-01..2008-12-31, validate on 2009
lending_club_data_handler = preprocessing.DataHandler(csv_path='../local_data/all_lending_club_loan_data_2007-2018.csv')

start = datetime(2007,1,1)
end = datetime(2008,12,31)
train_data, train_mask = lending_club_data_handler.get_train_data(start, end)

start = datetime(2009,1,1)
end = datetime(2009,12,31)
# NOTE(review): validation data comes from get_test_data, presumably using statistics fitted by get_train_data — confirm
validation_data, validation_mask = lending_club_data_handler.get_test_data(start, end)

train_loader = preprocessing.to_torch_dataloader(train_data, train_mask)
validation_loader = preprocessing.to_torch_dataloader(validation_data, validation_mask)

# Train model
# NOTE(review): train_variational_autoencoder names its checkpoint file with input_size only, and get_best_model
# rebuilds the model from that file name plus sigmoid_mask, so hidden_size=64 / latent_size=32 here must match
# the class defaults for reloading to work — TODO confirm.
model = autoencoders.VariationalAutoencoder(input_size=len(train_data[0]), hidden_size=64, latent_size=32)
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)#original is 1e-3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# FIXME(review): train_variational_autoencoder requires a positional `binary_mask` argument (1-D mask of binary
# features) that is not passed here, so this call raises TypeError as written. TODO: supply the mask (its source
# is not visible in this notebook — possibly the data handler).
train_variational_autoencoder(model, optimiser, train_loader, validation_loader, n_epochs=100, patience=10, device=device)