In [None]:
import os
import random
from datetime import datetime

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

In [None]:
from vad.architectures import STAD
from vad.datasets import TrajectoryDataset, ExactBatchSampler

In [3]:
# Reproducibility: fix RNG seeds so weight initialisation and dropout repeat across runs.
# Batch shuffling is covered only if ExactBatchSampler draws from `random` or torch (project code — TODO confirm).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA generators

# Cap intra-op CPU parallelism.
torch.set_num_threads(8)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Experiment Parameters

In [13]:
# ---- Experiment parameters -------------------------------------------------
exp_type = 'components'            # prefix of every run / checkpoint name
include_weather = True             # forwarded to TrajectoryDataset
n_weather_vars = 5                 # forwarded to STAD
n_components_gmm = [1, 5, 10, 20, 25, 30, 40, 50, 100]

# GMM hidden width scales linearly with the component count, anchored at
# 20 components -> 32 units (about 1.6 units per component, truncated to int).
hidden_dims_gmm = [int((n - 20) / 20 * 32 + 32) for n in n_components_gmm]

print(f'Numbers of Gaussian Mixture Model Components:\t{n_components_gmm}')
print(f'Hidden Dimensions of GMM:\t\t\t{hidden_dims_gmm}')

# ---- Training parameters ---------------------------------------------------
epochs = 100
patience = epochs        # equal to epochs, so early stopping effectively never fires
learning_rate = 1e-5     # peak LR (used as OneCycleLR max_lr)
embed_dim = 32
latent_dim_ae = 32
weight_decay = 0.1
dropout = 0.1
n_head_te = 8            # passed to STAD as nhead_te
n_layers_te = 4          # passed to STAD as n_layers_te
eps_gmm = 1e-7           # passed to STAD as eps_gmm
eps_loss = 1             # default epsilon of calculate_gmm_penalty (added to each covariance diagonal entry)

Numbers of Gaussian Mixture Model Components:	[1, 5, 10, 20, 25, 30, 40, 50, 100]
Hidden Dimensions of GMM:			[1, 8, 16, 32, 40, 48, 64, 80, 160]


# STAD Loss Function

In [5]:
def calculate_gmm_penalty(sigma, epsilon=eps_loss):
    """
    Sum of reciprocal covariance diagonals over all GMM components.

    Any diagonal entry that shrinks toward zero makes its reciprocal, and
    therefore the penalty, grow, which discourages near-singular covariances.

    sigma:   component covariances, shape [num_components, input_dim, input_dim].
             Any extra leading dims are also folded into the single sum.
    epsilon: added to every diagonal entry before inverting. The default is the
             config value `eps_loss` (currently 1, not a tiny constant), captured
             when this cell runs; re-run the cell after changing `eps_loss`.

    Returns a 0-dim tensor.
    """
    variances = sigma.diagonal(dim1=-2, dim2=-1)   # [num_components, input_dim]
    inverse_variances = 1.0 / (variances + epsilon)
    return inverse_variances.sum()

def compute_full_loss(penalty, transformer_loss, energy, d,
                      lambda_1=1, lambda_2=1, lambda_3=5e-3):
    """
    Combine the STAD loss terms into one objective:

        loss = transformer_loss + lambda_1 * energy + lambda_2 * d + lambda_3 * penalty

    The defaults (lambda_1=1, lambda_2=1, lambda_3=0.005) follow the STAD publication.
    Works with plain numbers as well as tensors.
    """
    weighted_energy = lambda_1 * energy
    weighted_d = lambda_2 * d
    weighted_penalty = lambda_3 * penalty
    return transformer_loss + weighted_energy + weighted_d + weighted_penalty

# STAD Training-Set Loss in Eval Mode (no dropout, no gradients)

In [6]:
def evaluate_training_set(model, train_dataloader, device):
    """
    Re-score the training set with the model in eval mode.

    The running loss gathered during training is measured with dropout active and
    with weights that change after every batch. This pass instead reports the
    end-of-epoch training loss under the same conditions as validation: eval mode
    and no gradient tracking.

    Args:
        model: STAD model, called as model(inputs, targets, weather_stats) and
            unpacked into (transformer_loss, energy, d_h, sigma).
        train_dataloader: yields dicts with 'src_window' and 'tgt_window' (dicts of
            tensors) and optionally 'weather_stats'.
        device: device the batch tensors are moved to.

    Returns:
        Mean of the per-batch STAD loss, as a Python float. The model is left in eval mode.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    n_batches = len(train_dataloader)
    if n_batches == 0:
        raise ValueError("evaluate_training_set: train_dataloader is empty")

    model.eval()
    total_train_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(train_dataloader, desc="Evaluating Training Set"):
            inputs = {k: v.to(device) for k, v in batch.get('src_window').items()}
            targets = {k: v.to(device) for k, v in batch.get('tgt_window').items()}
            # weather_stats is optional (presumably absent when include_weather=False — TODO confirm);
            # only move it when present and let the model receive None otherwise.
            weather_stats = batch.get('weather_stats')
            if weather_stats is not None:
                weather_stats = weather_stats.to(device)

            l, energy, d_h, sigma = model(inputs, targets, weather_stats)
            l, energy, d_h = l.mean(), energy.mean(), d_h.mean()

            penalty = calculate_gmm_penalty(sigma).mean()
            batch_loss = compute_full_loss(penalty, l, energy, d_h).mean()

            total_train_loss += batch_loss.item()

    return total_train_loss / n_batches

# STAD Validation Loop

In [7]:
def validate(model, dataloader, device):
    """
    Compute average STAD validation metrics without tracking gradients.

    Args:
        model: STAD model, called as model(inputs, targets, weather_stats) and
            unpacked into (transformer_loss, energy, d_h, sigma).
        dataloader: yields dicts with 'src_window' and 'tgt_window' (dicts of
            tensors) and optionally 'weather_stats'.
        device: device the batch tensors are moved to.

    Returns:
        (avg_val_loss, avg_energy, avg_te_loss): per-batch means of the full STAD
        loss, of the energy term and of the transformer loss `l`, as Python floats.
        The model is left in eval mode.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError("validate: dataloader is empty")

    total_val_loss = 0.0
    total_energy = 0.0
    total_te_loss = 0.0

    model.eval()

    # No backward pass happens here, so skip building the autograd graph (saves memory and time).
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            inputs = {k: v.to(device) for k, v in batch.get('src_window').items()}
            targets = {k: v.to(device) for k, v in batch.get('tgt_window').items()}
            # weather_stats is optional (presumably absent when include_weather=False — TODO confirm).
            weather_stats = batch.get('weather_stats')
            if weather_stats is not None:
                weather_stats = weather_stats.to(device)

            l, energy, d_h, sigma = model(inputs, targets, weather_stats)
            l, energy, d_h = l.mean(), energy.mean(), d_h.mean()

            penalty = calculate_gmm_penalty(sigma).mean()
            stad_loss = compute_full_loss(penalty, l, energy, d_h).mean()

            total_val_loss += stad_loss.item()
            total_energy += energy.item()
            total_te_loss += l.item()

    avg_val_loss = total_val_loss / n_batches
    avg_energy = total_energy / n_batches
    avg_te_loss = total_te_loss / n_batches
    return avg_val_loss, avg_energy, avg_te_loss

# STAD Training Loop

In [None]:
def _save_checkpoint(path, epoch, model, optimizer, train_loss, val_loss):
    """Save the epoch index, model/optimizer state dicts and the epoch's losses to `path`."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss
    }, path)


def train(model,
          train_dataloader,
          valid_dataloader,
          optimizer,
          scheduler,
          num_epochs,
          device,
          patience,
          save_dir='./models',
          experiment_name=None):
    """
    Train STAD with per-batch LR scheduling, per-epoch checkpointing and early stopping.

    The scheduler is stepped once per batch. Every epoch overwrites `STAD_latest.pth`
    in `save_dir`. `STAD_best.pth` is written whenever the validation loss improves.
    TensorBoard logs go to ./runs/<timestamp>_<experiment_name>.

    Args:
        model: STAD model, called as model(inputs, targets, weather_stats) and
            unpacked into (transformer_loss, energy, d_h, sigma).
        train_dataloader, valid_dataloader: yield dicts with 'src_window',
            'tgt_window' and optionally 'weather_stats'.
        optimizer: torch optimizer over the model parameters.
        scheduler: LR scheduler stepped after every optimizer step.
        num_epochs: maximum number of epochs.
        device: device the batch tensors are moved to.
        patience: consecutive epochs without validation improvement before stopping.
        save_dir: checkpoint directory (created if missing).
        experiment_name: suffix of the TensorBoard run directory. Defaults to the
            last path component of `save_dir`.

    Returns:
        The model with the best checkpoint's weights loaded. If no epoch ever
        improved the validation loss (e.g. NaN losses, or num_epochs == 0), the
        model is returned as it is after the last epoch.

    Raises:
        ValueError: if train_dataloader yields no batches.
    """
    n_batches = len(train_dataloader)
    if n_batches == 0:
        raise ValueError("train: train_dataloader is empty")

    os.makedirs(save_dir, exist_ok=True)
    if experiment_name is None:
        # The notebook names save_dir after the experiment, so reuse its last component.
        experiment_name = os.path.basename(os.path.normpath(save_dir))

    latest_model_path = os.path.join(save_dir, 'STAD_latest.pth')
    best_model_path = os.path.join(save_dir, 'STAD_best.pth')

    timestamp = datetime.now().strftime('%b%d_%H-%M-%S')
    writer = SummaryWriter(log_dir=f'./runs/{timestamp}_{experiment_name}')

    best_val_loss = float('inf')
    best_saved = False
    patience_counter = 0

    try:
        for epoch in range(num_epochs):
            model.train()
            # Accumulate detached on-device values: no autograd graph is kept alive and
            # there is no host sync per batch (one .item() per epoch instead).
            running_loss = torch.zeros((), device=device)

            for batchidx, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")):
                inputs = {k: v.to(device) for k, v in batch.get('src_window').items()}
                targets = {k: v.to(device) for k, v in batch.get('tgt_window').items()}
                # weather_stats is optional (presumably absent when include_weather=False — TODO confirm).
                weather_stats = batch.get('weather_stats')
                if weather_stats is not None:
                    weather_stats = weather_stats.to(device)

                optimizer.zero_grad()

                l, energy, d_h, sigma = model(inputs, targets, weather_stats)
                l, energy, d_h = l.mean(), energy.mean(), d_h.mean()

                penalty = calculate_gmm_penalty(sigma).mean()
                stad_loss = compute_full_loss(penalty, l, energy, d_h).mean()

                running_loss += stad_loss.detach()

                if batchidx % 200 == 0:
                    step = epoch * n_batches + batchidx
                    writer.add_scalar('Batch/te_loss', l, step)
                    writer.add_scalar('Batch/Energy', energy, step)
                    writer.add_scalar('Batch/train_loss', stad_loss, step)
                    # 0.005 mirrors compute_full_loss's default lambda_3, i.e. the weighted penalty term.
                    writer.add_scalar('Batch/Penalty', penalty*0.005, step)
                    print(f'Batch {batchidx}/{n_batches} | Loss: {stad_loss:.6f}')

                stad_loss.backward()
                optimizer.step()
                scheduler.step()

            avg_train_loss = running_loss.item() / n_batches
            true_loss = evaluate_training_set(model, train_dataloader, device)  # already averaged

            val_loss, avg_energy, avg_te_loss = validate(model, valid_dataloader, device)

            writer.add_scalar('Epoch/train_loss', avg_train_loss, epoch)
            writer.add_scalar('Epoch/validation_loss', val_loss, epoch)
            writer.add_scalar('Epoch/avg_energy', avg_energy, epoch)
            writer.add_scalar('Epoch/avg_te_loss', avg_te_loss, epoch)
            writer.add_scalar('Epoch/learning_rate', scheduler.get_last_lr()[0], epoch)
            writer.add_scalar('Epoch/true_loss', true_loss, epoch)

            print(f'Epoch {epoch+1}/{num_epochs} | Average Train Loss: {avg_train_loss:.6f} | Average Validation Loss: {val_loss:.6f}')

            _save_checkpoint(latest_model_path, epoch, model, optimizer, avg_train_loss, val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                _save_checkpoint(best_model_path, epoch, model, optimizer, avg_train_loss, val_loss)
                best_saved = True
                print(f"Saved new best model with validation loss: {val_loss:.6f}")
            else:
                patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{patience}")

            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs!")
                break

            # Free cached GPU memory between epochs (no-op if CUDA was never initialised).
            torch.cuda.empty_cache()
    finally:
        # Flush and close TensorBoard logs even if training raises.
        writer.close()

    if not best_saved:
        print("No epoch improved the validation loss; returning the model from the last epoch.")
        return model

    # map_location lets the checkpoint load on whichever device this run uses.
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1} with validation loss: {checkpoint['val_loss']:.6f}")

    return model

# Datasets and DataLoaders

In [None]:
# Both splits share the same discretisation grid, data directory and weather flag;
# only the split name and the pickle file differ.
# NOTE(review): the inputs are .pkl files — only load pickles from trusted sources.
_shared_dataset_kwargs = dict(lat_bins=400,
                              lon_bins=400,
                              sog_bins=30,
                              cog_bins=72,
                              file_directory='../../data',
                              include_weather=include_weather)

traj_dataset_train = TrajectoryDataset(ds_type='train',
                                       filename='joined-train-stad-weather.pkl',
                                       **_shared_dataset_kwargs)

traj_dataset_valid = TrajectoryDataset(ds_type='valid',
                                       filename='joined-valid-stad-weather.pkl',
                                       **_shared_dataset_kwargs)

In [10]:
# One sampler per split, built from each dataset's precomputed batch boundaries.
# shuffle_batches=True presumably randomises batch order (ExactBatchSampler is project code — confirm);
# for validation the order only changes the float summation order of the averaged loss.
train_batch_sampler, valid_batch_sampler = (
    ExactBatchSampler(split.batch_boundaries, shuffle_batches=True)
    for split in (traj_dataset_train, traj_dataset_valid)
)

In [11]:
# Shared loader settings: 4 worker processes kept alive across epochs and
# page-locked (pinned) host memory for faster host-to-GPU copies.
_loader_kwargs = dict(num_workers=4, pin_memory=True, persistent_workers=True)

data_loader_train = data.DataLoader(traj_dataset_train, batch_sampler=train_batch_sampler, **_loader_kwargs)
data_loader_valid = data.DataLoader(traj_dataset_valid, batch_sampler=valid_batch_sampler, **_loader_kwargs)

# Training Runs (one per GMM component count)

In [None]:
# One full training run per GMM component count; each run gets a fresh model, optimizer and scheduler.
for n_components, hidden_dim_gmm in tqdm(zip(n_components_gmm, hidden_dims_gmm),
                                         total=len(n_components_gmm),  # zip has no len(); give tqdm the run count
                                         desc='Running Experiment'):

    # Build the timestamp outside the f-string: reusing the enclosing quote character inside an
    # f-string replacement field is a SyntaxError before Python 3.12 (PEP 701).
    run_timestamp = datetime.now().strftime('%b%d_%H-%M-%S')
    # Kept as a module-level name: train() may read `experiment_name` as a global.
    experiment_name = (f'{exp_type}_epochs_{epochs}_embed_{embed_dim}_wd_{weight_decay}_lr_{learning_rate}'
                       f'_hgmm_{hidden_dim_gmm}_lae_{latent_dim_ae}_comp_{n_components}_{run_timestamp}')
    print(experiment_name)

    # Bin counts must match the TrajectoryDataset discretisation (400/400/30/72).
    stad = STAD(
        n_lat_bins=400,
        n_lon_bins=400,
        n_sog_bins=30,
        n_cog_bins=72,
        max_seq_len=10,
        embed_dim=embed_dim,
        dropout=dropout,
        nhead_te=n_head_te,
        n_layers_te=n_layers_te,
        latent_dim_ae=latent_dim_ae,
        n_weather_vars=n_weather_vars,
        hidden_dim_gmm=hidden_dim_gmm,
        eps_gmm=eps_gmm,
        n_components_gmm=n_components).to(device)

    print(stad)

    # OneCycleLR drives the learning rate itself (starting at max_lr / div_factor), so AdamW's
    # default lr is not used. An alternative betas=(0.5, 0.999) (lower beta1 because of
    # variation in batch/trajectory length) is currently disabled.
    optimizer = AdamW(stad.parameters(), weight_decay=weight_decay)

    # train() steps the scheduler once per batch, so the cycle spans epochs * len(data_loader_train) steps.
    scheduler = OneCycleLR(optimizer,
                           max_lr=learning_rate,            # peak learning rate
                           epochs=epochs,
                           steps_per_epoch=len(data_loader_train),
                           anneal_strategy='cos')

    # Train for the same number of epochs in every run; checkpoints go to a per-run directory.
    final_model = train(stad,
                        data_loader_train,
                        data_loader_valid,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        num_epochs=epochs,
                        device=device,
                        patience=patience,
                        save_dir=f'./models/{experiment_name}')

    del stad, final_model, optimizer, scheduler
    # Release cached, now-unused GPU memory before the next (possibly larger) model is built;
    # no-op if CUDA was never initialised.
    torch.cuda.empty_cache()