### Libraries

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import os
from umap import UMAP

### Autoencoder class

In [3]:
class Autoencoder(nn.Module):
    """Fully-connected autoencoder whose encoder input is wider than its output.

    The encoder maps ``input_size + batch_size`` features to a ``d_z``-dim
    latent code. The decoder maps that code back to only ``input_size``
    features.

    ``batch_size`` is therefore a count of extra input columns that only the
    encoder sees, not the mini-batch size. Presumably these are one-hot
    batch/covariate labels appended to each sample (TODO confirm against the
    data pipeline).

    Args:
        d_z: Latent dimension.
        input_size: Number of features the decoder reconstructs.
        batch_size: Number of extra feature columns fed to the encoder only.
        hidden_sizes: Widths of the encoder's hidden layers. The decoder
            mirrors them in reverse order. The default ``(128, 64)`` matches
            the original architecture.
        activation: Optional ``nn.Module`` class (e.g. ``nn.ReLU``).
            An instance is inserted between consecutive ``nn.Linear`` layers,
            but never after the last one. The default ``None`` keeps the
            original purely linear stacks and identical ``state_dict`` keys,
            so previously saved weights still load.
            Note that linear layers with no nonlinearity between them compose
            to a single affine map.
    """

    def __init__(self, d_z = 10, input_size = 24, batch_size = 2,
                 hidden_sizes = (128, 64), activation = None):
        super().__init__()
        self.d_z = d_z
        self.input_size = input_size
        self.batch_size = batch_size
        self.hidden_sizes = tuple(hidden_sizes)
        # Encoder is built before decoder, as in the original, so parameter
        # initialisation consumes the RNG in the same order.
        self.encoder = self._mlp(
            [input_size + batch_size, *self.hidden_sizes, d_z], activation
        )
        self.decoder = self._mlp(
            [d_z, *reversed(self.hidden_sizes), input_size], activation
        )

    @staticmethod
    def _mlp(sizes, activation):
        """Chain ``nn.Linear`` layers through ``sizes``, optionally separated by ``activation()``."""
        layers = []
        for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if i > 0 and activation is not None:
                layers.append(activation())
            layers.append(nn.Linear(n_in, n_out))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode then decode ``x``.

        Returns a tensor of shape ``(N, input_size)``, where
        ``N = x.numel() // (input_size + batch_size)``.
        """
        decoded = self.decode(self.encode(x))
        return decoded.reshape(-1, self.input_size)

    def encode(self, x):
        """Flatten ``x`` to rows of ``input_size + batch_size`` features and return the latent codes ``(N, d_z)``."""
        # reshape (unlike view) also accepts non-contiguous inputs,
        # e.g. transposed or permuted tensors.
        x = x.reshape(-1, self.input_size + self.batch_size)
        return self.encoder(x)

    def decode(self, z):
        """Map latent codes ``z`` (last dim ``d_z``) back to ``input_size`` features."""
        return self.decoder(z)

### Train model function

In [2]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean per-batch training loss."""
    model.train()
    total = 0.0
    for x, _ in loader:
        x = x.to(device)
        # Reconstruction objective: the model output is compared to its own input.
        loss = criterion(model(x), x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


def _evaluate(model, loader, criterion, device):
    """Return the mean per-batch reconstruction loss over ``loader`` in eval mode, without gradients."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            total += criterion(model(x), x).item()
    return total / len(loader)


def train_model(
        model,
        train_loader,
        val_loader,
        test_loader,
        criterion,
        optimizer,
        num_epochs,
        device,
        model_name = "model",
        save_model = False
):
    """Train ``model`` as a reconstruction model and report per-epoch losses.

    Each loader must yield ``(x, label)`` pairs; the label is ignored. The loss
    for every batch is ``criterion(model(x), x)``.
    NOTE(review): for ``Autoencoder`` instances with ``batch_size > 0`` the
    output is narrower than ``x``, so this comparison would not line up
    shape-wise. Confirm what the loaders actually yield.

    Every loss is the mean of per-batch loss values. Each loader must be
    non-empty, otherwise dividing by ``len(loader)`` raises
    ``ZeroDivisionError``.

    The test loss is computed with the model as it is after the final epoch,
    not with the best-validation weights.

    Args:
        model: The ``nn.Module`` to train (updated in place).
        train_loader, val_loader, test_loader: Iterables of ``(x, label)`` batches.
        criterion: Loss callable ``criterion(output, target)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of training epochs.
        device: Device each batch ``x`` is moved to.
        model_name: File stem for the saved weights.
        save_model: If truthy, save the lowest-validation-loss ``state_dict``
            to ``<cwd>/model_weights/<model_name>_best.pth``.

    Returns:
        ``(train_losses, val_losses, test_losses, best_epoch)``.
        ``train_losses`` and ``val_losses`` hold one value per epoch.
        ``test_losses`` is a one-element list.
        ``best_epoch`` is the 0-based epoch with the lowest validation loss,
        or -1 if none improved on ``inf``.
    """
    import copy

    train_losses = []
    test_losses = []
    val_losses = []
    lowest_val_loss = float('inf')
    best_epoch = -1
    best_model_state = None

    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)

        val_loss = _evaluate(model, val_loader, criterion, device)
        val_losses.append(val_loss)

        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            best_epoch = epoch
            # state_dict() returns references to the live parameter tensors.
            # Without a deep copy, later optimizer steps would overwrite this
            # snapshot and the "best" file would hold the final weights.
            best_model_state = copy.deepcopy(model.state_dict())
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    test_loss = _evaluate(model, test_loader, criterion, device)
    test_losses.append(test_loss)
    print(f"Test Loss: {test_loss:.4f}")

    if save_model:
        dest = os.path.join(os.getcwd(), "model_weights", f"{model_name}_best.pth")
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        torch.save(best_model_state, dest)

    return train_losses, val_losses, test_losses, best_epoch