In [20]:
import pandas as pd
import numpy as np
import os
import random

import torch
import torch.nn as nn
import os
import yaml
import itertools
import json
import traceback

from datetime import datetime
from tqdm import tqdm
from pathlib import Path
from dotenv import load_dotenv
from torch.utils.data import DataLoader, TensorDataset

from trajectory_predictor import TrajectoryPredictor, trajectory_loss, DEVICE
from seed import set_seed
from config import flatten_config

from classification_rnn import ClassificationRNN, DEVICE
from seed import set_seed
from config import flatten_config

# Root directory (relative path) under which the processed data files are read.
BASE = "../../data/aisdk/processed"


# 0 - Fetch the data

In [21]:
# Open the windowed trajectory archive of every split.
train_traj = np.load(os.path.join(BASE, "windows/train_trajectories.npz"))
val_traj = np.load(os.path.join(BASE, "windows/val_trajectories.npz"))
test_traj = np.load(os.path.join(BASE, "windows/test_trajectories.npz"))

# Pull the "past" (model input), "future" (target) and "cluster" arrays
# out of the train and validation archives, one split at a time.
X_train, y_train, c_train = (train_traj[key] for key in ("past", "future", "cluster"))
X_val, y_val, c_val = (val_traj[key] for key in ("past", "future", "cluster"))



In [23]:
# %% Create data loaders
def make_loaders(batch_size, cid):
    """Build train/validation DataLoaders restricted to a single cluster.

    Reads the module-level arrays ``X_train``/``y_train``/``c_train`` and
    ``X_val``/``y_val``/``c_val``, keeps the rows whose cluster label equals
    ``cid``, converts them to float32 tensors and prints the tensor shapes.

    Args:
        batch_size: batch size used for both loaders.
        cid: cluster label to select.

    Returns:
        (train_loader, val_loader); only the training loader shuffles.
    """

    def _as_float_tensor(array):
        # Every array is converted with the same dtype.
        return torch.tensor(array, dtype=torch.float32)

    # Row selectors for the requested cluster
    in_train = c_train == cid
    in_val = c_val == cid

    # Slice the numpy arrays and convert them in one step
    X_train_t = _as_float_tensor(X_train[in_train])
    X_val_t = _as_float_tensor(X_val[in_val])

    y_train_t = _as_float_tensor(y_train[in_train])
    y_val_t = _as_float_tensor(y_val[in_val])

    # Report what this cluster's subset looks like
    print(f"\nTensor shapes:")
    print(f"  X_train_t: {X_train_t.shape}")
    print(f"  X_val_t:   {X_val_t.shape}")
    print(f"  y_train_t: {y_train_t.shape}")
    print(f"  y_val_t:   {y_val_t.shape}")

    train_set = TensorDataset(X_train_t, y_train_t)
    val_set = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

# 1 - Define training functions 
## 1.1 - One run

In [24]:
def _train_one_run(cfg, train_loader, val_loader, cid):
    """Train one TrajectoryPredictor with a fixed config and return its best validation loss.

    Args:
        cfg: dict providing "device", "input_dim", "hidden_dim", "output_dim",
            "num_layers_encoder", "num_layers_decoder", "attn_dim", "lr",
            "weight_decay", "epochs", "teacher_forcing" and "max_norm".
        train_loader: iterable of (past, future) batches used for optimisation.
        val_loader: iterable of (past, future) batches used for evaluation.
        cid: cluster id; only used to label progress output.

    Returns:
        float: the lowest per-sample mean validation loss seen over all epochs
        (``inf`` if ``cfg["epochs"]`` is 0).

    Raises:
        ValueError: if the training or the validation loader yields no samples.
    """
    device = cfg["device"]

    model = TrajectoryPredictor(
        input_dim=cfg["input_dim"],
        hidden_dim=cfg["hidden_dim"],
        output_dim=cfg["output_dim"],
        num_layers_encoder=cfg["num_layers_encoder"],
        num_layers_decoder=cfg["num_layers_decoder"],
        attn_dim=cfg["attn_dim"]
    ).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    best_val_mse = float("inf")

    for epoch in range(1, cfg["epochs"] + 1):

        # --- TRAINING ---
        model.train()
        train_total = 0.0
        train_samples = 0  # reset every epoch so the average covers this epoch only

        for xb, yb in tqdm(train_loader, desc=f"Cluster {cid} Epoch {epoch}/{cfg['epochs']}"):
            xb = xb.to(device)
            yb = yb.to(device)

            opt.zero_grad()
            yb_pred = model(xb, target_length=yb.size(1), targets=yb, teacher_forcing_ratio=cfg["teacher_forcing"])

            loss = trajectory_loss(yb_pred, yb)
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg["max_norm"])
            opt.step()

            # Weight by batch size so a smaller final batch does not skew the mean.
            # NOTE(review): assumes trajectory_loss returns a per-batch mean scalar - confirm.
            bs = xb.size(0)
            train_total += loss.item() * bs
            train_samples += bs

        if train_samples == 0:
            raise ValueError(f"No training samples available for cluster {cid}")

        # Divide by the number of samples, not by the number of batches.
        train_mse = train_total / train_samples

        # --- VALIDATION ---
        model.eval()
        val_total = 0.0
        val_samples = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)

                # No teacher forcing: the decoder is not given the targets.
                yb_pred = model(
                    xb,
                    target_length=yb.size(1),
                    targets=None,
                    teacher_forcing_ratio=0.0
                )

                loss = trajectory_loss(yb_pred, yb)
                bs = xb.size(0)

                # .item() keeps the running sum a plain float instead of a tensor.
                val_total += loss.item() * bs
                val_samples += bs

        if val_samples == 0:
            raise ValueError(f"No validation samples available for cluster {cid}")

        val_mse = val_total / val_samples
        best_val_mse = min(best_val_mse, val_mse)

        print(
            f"Cluster {cid} Epoch {epoch}/{cfg['epochs']}: "
            f"train_mse={train_mse:.6f}  val_mse={val_mse:.6f}"
        )

    # Return only after every epoch has run.
    return best_val_mse


## 1.2 - Function for hyperparameter tuning

In [25]:
# Define key metrics
# Feature count = size of the last axis of the training input windows.
input_dim = X_train.shape[-1]

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [27]:
def hyperparameter_search_trajectory(
    device=device,
    input_dim=input_dim,
    output_dim=input_dim,
    cid=1,
    search_type='grid',
    save_dir='../../checkpoints/hyperparameter_results_trajectory',
):
    """Exhaustively try every hyperparameter combination for one cluster.

    For each combination the loaders are rebuilt with ``make_loaders`` and a
    model is trained with ``_train_one_run``; a trial that raises is logged
    and recorded with an "error" field instead of aborting the search.
    All trial records are written to a timestamped CSV in ``save_dir`` and the
    best combination to a timestamped JSON file.

    Args:
        device: device string placed into each run config.
        input_dim: model input feature count.
        output_dim: model output feature count.
        cid: cluster id to tune.
        search_type: 'grid' (288 combinations) or 'quick' (4 combinations).
        save_dir: directory for the CSV/JSON results (created if missing).

    Returns:
        (df, best_params): the sorted results DataFrame and the best parameter
        dict; (df, None) when every trial failed; (None, None) when no trial
        was recorded.

    Raises:
        ValueError: if ``search_type`` is neither 'grid' nor 'quick'.
    """
    banner = "=" * 70

    os.makedirs(save_dir, exist_ok=True)

    print(banner)
    print(f"HYPERPARAMETER TUNING FOR TRAJECTORY PREDICTOR (Cluster {cid})")
    print(banner)

    # -------- 1) Search space --------
    if search_type not in ('grid', 'quick'):
        raise ValueError(f"Unknown search_type: {search_type}")

    if search_type == 'grid':
        # 2*2*2*2 model settings x 3*2*3 training settings = 288 combinations
        param_grid = {
            # model architecture
            "hidden_dim": [64, 128],
            "num_layers_encoder": [1, 2],
            "num_layers_decoder": [1, 2],
            "attn_dim": [64, 128],
            "batch_size": [256],

            # optimisation
            "lr": [1e-4, 3e-4, 1e-3],
            "weight_decay": [0.0, 1e-4],
            "teacher_forcing": [0.3, 0.5, 0.7],
            "max_norm": [1.0],
            "epochs": [20],
        }
    else:
        # 'quick': only hidden_dim and lr vary -> 2 * 2 = 4 combinations
        param_grid = {
            "hidden_dim": [64, 128],
            "num_layers_encoder": [1],
            "num_layers_decoder": [1],
            "attn_dim": [64],
            "batch_size": [256],

            "lr": [3e-4, 1e-3],
            "weight_decay": [0.0],
            "teacher_forcing": [0.5],
            "max_norm": [1.0],

            "epochs": [10],
        }

    # Cartesian product of all value lists, one dict per trial.
    param_names = list(param_grid)
    param_combinations = [
        dict(zip(param_names, combo))
        for combo in itertools.product(*param_grid.values())
    ]
    n_trials = len(param_combinations)

    print(f"{search_type.capitalize()} Search: {n_trials} combinations")
    print(f"Estimated time: {n_trials * 5} minutes (VERY rough)")
    print(banner)

    results = []
    best_score = float("inf")  # lower validation loss is better
    best_params = None

    start_time = datetime.now()

    # -------- 2) One training run per combination --------
    for trial_no, params in enumerate(param_combinations, start=1):
        print("\n" + banner)
        print(f"Trial {trial_no}/{n_trials}  (Cluster {cid})")
        elapsed_min = (datetime.now() - start_time).total_seconds() / 60
        print(f"Time elapsed: {elapsed_min:.1f} min")
        print(banner)

        print("Testing configuration:")
        print(f"  hidden_dim        = {params['hidden_dim']}")
        print(f"  num_layers_enc    = {params['num_layers_encoder']}")
        print(f"  num_layers_dec    = {params['num_layers_decoder']}")
        print(f"  attn_dim          = {params['attn_dim']}")
        print(f"  batch_size        = {params['batch_size']}")
        print(f"  lr                = {params['lr']}")
        print(f"  weight_decay      = {params['weight_decay']}")
        print(f"  teacher_forcing   = {params['teacher_forcing']}")
        print(f"  max_norm          = {params['max_norm']}")
        print(f"  epochs            = {params['epochs']}")
        print()

        try:
            # Run config = fixed dimensions/device followed by the trial's parameters.
            cfg = {
                "device": device,
                "input_dim": input_dim,
                "output_dim": output_dim,
                **params,
            }

            train_loader, val_loader = make_loaders(cfg["batch_size"], cid)
            best_val_mse = _train_one_run(cfg, train_loader, val_loader, cid)

            score = float(best_val_mse)

            # "score" duplicates "best_val_mse" for consistency with other scripts.
            results.append({
                **params,
                "best_val_mse": score,
                "score": score,
                "trial": trial_no,
            })

            print(f"Results:")
            print(f"  Best Val MSE: {best_val_mse:.6f}")
            print(f"  Score (Val MSE): {score:.6f}")

            if score < best_score:
                best_score, best_params = score, params.copy()
                print("  ⭐ NEW BEST CONFIG FOUND!")

            # Release cached GPU memory between trials.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as exc:
            # Keep searching: record the failure and move on to the next trial.
            print(f"  ❌ FAILED: {str(exc)}")
            traceback.print_exc()
            results.append({**params, "error": str(exc), "trial": trial_no})

    # -------- 3) Persist and summarise --------
    if not results:
        print("\n❌ No trials completed!")
        return None, None

    df = pd.DataFrame(results)

    # The "score" column is absent when every trial failed.
    if "score" in df.columns:
        df = df.sort_values("score", na_position="last")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = os.path.join(save_dir, f"results_cluster{cid}_{timestamp}.csv")
    df.to_csv(csv_path, index=False)

    if best_params is None:
        print("\n❌ All trials failed! Check error messages above.")
        return df, None

    json_path = os.path.join(save_dir, f"best_params_cluster{cid}_{timestamp}.json")
    with open(json_path, "w") as f:
        json.dump(best_params, f, indent=2)

    total_time = (datetime.now() - start_time).total_seconds() / 60
    n_successful = sum("error" not in record for record in results)

    print("\n" + banner)
    print(f"HYPERPARAMETER SEARCH COMPLETE (Cluster {cid})")
    print(banner)
    print(f"Total time: {total_time:.1f} minutes")
    print(f"Successful trials: {n_successful}/{len(results)}")

    print(f"\nBest parameters (score={best_score:.6f}):")
    for name, value in best_params.items():
        print(f"  {name}: {value}")

    print(f"\nResults saved to: {csv_path}")
    print(f"Best params saved to: {json_path}")

    # Show the best-scoring rows (df is already sorted by score).
    top_cols = [
        "hidden_dim",
        "num_layers_encoder",
        "num_layers_decoder",
        "attn_dim",
        "lr",
        "teacher_forcing",
        "weight_decay",
        "best_val_mse",
        "score",
    ]
    print("\nTop 5 configurations:")
    print(df.loc[df["score"].notna(), top_cols].head())

    return df, best_params


# 2 - Run the hyperparameter search

In [28]:
import itertools

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Input and output sequences share the same 5 features,
# e.g. [UTM_x, UTM_y, SOG, v_east, v_north].
input_dim = output_dim = 5

# Tune cluster 1 as an example; use search_type="grid" for the full sweep.
df_traj, best_params_traj = hyperparameter_search_trajectory(
    device=device,
    input_dim=input_dim,
    output_dim=output_dim,
    cid=1,
    search_type="quick",
)


HYPERPARAMETER TUNING FOR TRAJECTORY PREDICTOR (Cluster 1)
Quick Search: 4 combinations
Estimated time: 20 minutes (VERY rough)

Trial 1/4  (Cluster 1)
Time elapsed: 0.0 min
Testing configuration:
  hidden_dim        = 64
  num_layers_enc    = 1
  num_layers_dec    = 1
  attn_dim          = 64
  batch_size        = 256
  lr                = 0.0003
  weight_decay      = 0.0
  teacher_forcing   = 0.5
  max_norm          = 1.0
  epochs            = 10


Tensor shapes:
  X_train_t: torch.Size([16725, 30, 5])
  X_val_t:   torch.Size([5235, 30, 5])
  y_train_t: torch.Size([16725, 30, 5])
  y_val_t:   torch.Size([5235, 30, 5])


Cluster 1 Epoch 1/10:  24%|██▍       | 16/66 [00:05<00:16,  3.12it/s]


KeyboardInterrupt: 