In [1]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
from pathlib import Path
from tqdm import tqdm

# Plot appearance: matplotlib's seaborn-like dark-grid style sheet plus the "husl" colour palette
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Compute device: use CUDA when torch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

ModuleNotFoundError: No module named 'seaborn'

## 1 - Load Data and Models

In [None]:
# Data paths
BASE = "../../data/aisdk/processed"
TEST_DATA_PATH = os.path.join(BASE, "windows/test_trajectories.npz")
MODELS_DIR = "../../checkpoints/trained_models"

# Load test data.
# np.load on an .npz archive returns an NpzFile that keeps the file open and
# reads each array on access; the context manager closes the handle as soon as
# the three arrays have been pulled into memory.
print("Loading test data...")
with np.load(TEST_DATA_PATH) as test_data:
    X_test = test_data["past"]           # (N, 30, 5) - past trajectories
    y_test = test_data["future"]         # (N, 30, 5) - future trajectories
    c_test = test_data["cluster"]        # (N,) - cluster labels

# Later cells iterate over the three arrays in lockstep; fail early and loudly
# if the archive is inconsistent instead of silently truncating.
assert len(X_test) == len(y_test) == len(c_test), (
    f"Sample count mismatch: past={len(X_test)}, future={len(y_test)}, cluster={len(c_test)}"
)

print(f"Test set shapes:")
print(f"  X_test (past):   {X_test.shape}")
print(f"  y_test (future): {y_test.shape}")
print(f"  c_test (labels): {c_test.shape}")
print(f"\nUnique clusters: {np.unique(c_test)}")
print(f"Cluster distribution:\n{pd.Series(c_test).value_counts().sort_index()}")

In [None]:
# Import model classes
import sys

# Put the current working directory first on the import path so the two
# project modules below can be resolved from it.
working_dir = os.getcwd()
sys.path.insert(0, working_dir)

from classification_rnn import ClassificationRNN
from trajectory_predictor import TrajectoryPredictor, trajectory_loss

# Checkpoint locations (update these with your actual model paths)
classification_model_path = os.path.join(MODELS_DIR, "classification_rnn_best.pt")
trajectory_model_dir = os.path.join(MODELS_DIR, "trajectory_predictors")  # per-cluster models

# Probe the filesystem once, then report what was found
classifier_found = os.path.exists(classification_model_path)
trajectory_dir_found = os.path.isdir(trajectory_model_dir)

print(f"Classification model path: {classification_model_path}")
print(f"Trajectory models dir: {trajectory_model_dir}")
print(f"\nModel paths exist:")
print(f"  Classification: {classifier_found}")
print(f"  Trajectory dir: {trajectory_dir_found}")

## 2 - Classification RNN Evaluation

In [None]:
banner = "=" * 70
print("\n" + banner)
print("CLASSIFICATION RNN EVALUATION")
print(banner)

# Model dimensions are taken from the test arrays themselves.
# NOTE(review): num_clusters counts the labels present in the test set; this
# presumably matches the checkpoint's output size - confirm against training.
input_dim = X_test.shape[-1]
num_clusters = len(np.unique(c_test))

classifier = ClassificationRNN(
    input_dim=input_dim,
    hidden_dim=64,
    num_layers=1,
    output_dim=num_clusters
)
classifier = classifier.to(device)

# Restore the trained weights and switch to evaluation mode
state_dict = torch.load(classification_model_path, map_location=device)
classifier.load_state_dict(state_dict)
classifier.eval()
print(f"-> Loaded classification model")

# Whole test set goes through the classifier in a single forward pass
print("\nRunning inference on test set...")
X_test_t = torch.tensor(X_test, dtype=torch.float32)
X_test_t = X_test_t.to(device)
c_test_t = torch.tensor(c_test, dtype=torch.long)
c_test_t = c_test_t.to(device)

with torch.no_grad():
    logits = classifier(X_test_t)  # (N, num_clusters)
    probs = torch.softmax(logits, dim=1).cpu().numpy()
    c_pred = torch.argmax(logits, dim=1).cpu().numpy()

print(f"Predictions shape: {c_pred.shape}")
print(f"Probabilities shape: {probs.shape}")

In [None]:
# Compute metrics (support-weighted averages over all clusters)
accuracy = accuracy_score(c_test, c_pred)
precision = precision_score(c_test, c_pred, average='weighted', zero_division=0)
recall = recall_score(c_test, c_pred, average='weighted', zero_division=0)
f1 = f1_score(c_test, c_pred, average='weighted', zero_division=0)

print(f"\n Classification Metrics:")
print(f"  Accuracy:  {accuracy:.4f}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall:    {recall:.4f}")
print(f"  F1-Score:  {f1:.4f}")

# Per-cluster metrics.
# Precision for a cluster depends on samples from OTHER clusters that were
# wrongly assigned to it, so it must be computed on the full label arrays.
# Scoring only the rows whose true label is `cid` hides every false positive
# and makes precision degenerate (always 1.0 or 0). `average=None` with an
# explicit `labels` list returns one score per cluster, in that order.
cluster_ids = np.unique(c_test)
per_cluster_precision = precision_score(c_test, c_pred, labels=cluster_ids, average=None, zero_division=0)
per_cluster_recall = recall_score(c_test, c_pred, labels=cluster_ids, average=None, zero_division=0)
per_cluster_f1 = f1_score(c_test, c_pred, labels=cluster_ids, average=None, zero_division=0)

print(f"\n Per-Cluster Performance:")
for cid, prec, rec, f1_c in zip(cluster_ids, per_cluster_precision, per_cluster_recall, per_cluster_f1):
    mask = c_test == cid
    # Accuracy restricted to the samples of this cluster (numerically the same
    # quantity as the cluster's recall; kept for the familiar output layout).
    acc = accuracy_score(c_test[mask], c_pred[mask])
    count = mask.sum()
    print(f"  Cluster {cid}: Acc={acc:.3f}, Prec={prec:.3f}, Rec={rec:.3f}, F1={f1_c:.3f} (n={count})")

In [None]:
# Confusion matrix.
# Fix the row/column order explicitly (every label seen in truth or prediction,
# sorted) and reuse it for the tick labels, so the axes show real cluster IDs
# rather than positional indices 0..k-1.
cm_labels = np.unique(np.concatenate([c_test, c_pred]))
cm = confusion_matrix(c_test, c_pred, labels=cm_labels)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    ax=ax,
    xticklabels=cm_labels,
    yticklabels=cm_labels,
    cbar_kws={'label': 'Count'},
)
ax.set_xlabel('Predicted Cluster')
ax.set_ylabel('True Cluster')
ax.set_title('Classification RNN: Confusion Matrix on Test Set')
plt.tight_layout()
plt.savefig('classification_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()
print("-> Confusion matrix saved")

In [None]:
# Prediction confidence analysis
max_probs = probs.max(axis=1)   # confidence = probability of the predicted class
correct = (c_pred == c_test)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Confidence distribution.
# Both histograms use the same 30 bins over [0, 1]; with auto-ranged bins the
# two groups would get different bin widths and their bars could not be compared.
axes[0].hist(max_probs[correct], bins=30, range=(0, 1), alpha=0.7, label='Correct', color='green')
axes[0].hist(max_probs[~correct], bins=30, range=(0, 1), alpha=0.7, label='Incorrect', color='red')
axes[0].set_xlabel('Max Probability')
axes[0].set_ylabel('Count')
axes[0].set_title('Prediction Confidence Distribution')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Accuracy vs confidence: ten equal-width confidence bins
confidence_bins = np.linspace(0, 1, 11)
num_bins = len(confidence_bins) - 1
bin_accs = []
bin_counts = []
for i in range(num_bins):
    lower, upper = confidence_bins[i], confidence_bins[i + 1]
    if i == num_bins - 1:
        # Close the last bin on the right: a saturated softmax yields a max
        # probability of exactly 1.0, which a strict `<` would silently drop.
        mask = (max_probs >= lower) & (max_probs <= upper)
    else:
        mask = (max_probs >= lower) & (max_probs < upper)
    if mask.sum() > 0:
        bin_acc = (c_pred[mask] == c_test[mask]).mean()
        bin_accs.append(bin_acc)
        bin_counts.append(mask.sum())
    else:
        # Empty bin: recorded as 0 so the lists stay aligned with the bin centers
        bin_accs.append(0)
        bin_counts.append(0)

bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2
axes[1].bar(bin_centers, bin_accs, width=0.08, color='steelblue', alpha=0.7)
axes[1].axhline(y=accuracy, color='r', linestyle='--', label=f'Overall Accuracy: {accuracy:.3f}')
axes[1].set_xlabel('Confidence Bin')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Accuracy vs Prediction Confidence')
axes[1].set_ylim([0, 1.05])
axes[1].legend()
axes[1].grid(alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('classification_confidence_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print("->Confidence analysis saved")

## 3 - Trajectory Predictor Evaluation

In [None]:
banner = "=" * 70
print("\n" + banner)
print("TRAJECTORY PREDICTOR EVALUATION")
print(banner)

# Maps cluster ID -> its trajectory predictor; clusters whose checkpoint file
# is missing are reported and left out of the mapping.
trajectory_models = {}
print("\nLoading per-cluster trajectory models...")

for cid in np.unique(c_test):
    model_path = os.path.join(trajectory_model_dir, f"trajectory_cluster_{cid}.pt")

    if os.path.exists(model_path):
        # Build the network, then restore its weights onto the active device
        traj_model = TrajectoryPredictor(
            input_dim=input_dim,
            hidden_dim=64,
            output_dim=input_dim,
            num_layers_encoder=1,
            num_layers_decoder=1,
            attn_dim=64
        )
        traj_model = traj_model.to(device)

        weights = torch.load(model_path, map_location=device)
        traj_model.load_state_dict(weights)
        traj_model.eval()

        trajectory_models[cid] = traj_model
        print(f"-> Cluster {cid} model loaded")
    else:
        print(f"Cluster {cid}: Model not found at {model_path}")

print(f"\nLoaded {len(trajectory_models)} trajectory models")

In [None]:
# Run trajectory predictions
print("\nRunning trajectory predictions on test set...")

y_pred_list = []
valid_indices = []   # positions in the test set that received a prediction

# Forecast exactly as many steps as the ground-truth futures contain, instead
# of assuming a fixed horizon.
target_length = y_test.shape[1]

# Each sample is routed to the predictor of its ground-truth cluster label.
# A single no_grad context covers the whole loop.
with torch.no_grad():
    for i, (X_i, y_i, c_i) in enumerate(zip(X_test, y_test, c_test)):
        if c_i not in trajectory_models:
            continue  # Skip if model is not available

        model = trajectory_models[c_i]
        X_i_t = torch.tensor(X_i, dtype=torch.float32).unsqueeze(0).to(device)  # add batch axis

        y_pred_i = model(
            X_i_t,
            target_length=target_length,
            targets=None,
            teacher_forcing_ratio=0.0  # No teacher forcing for inference
        )

        # Drop only the batch axis added above; a bare squeeze() would also
        # collapse any other size-1 axis (e.g. a single feature or timestep).
        y_pred_list.append(y_pred_i.cpu().numpy().squeeze(0))
        valid_indices.append(i)

y_pred = np.array(y_pred_list)  # one row per predicted sample, M <= N
y_test_valid = y_test[valid_indices]
c_test_valid = c_test[valid_indices]

print(f"\nPredictions computed for {len(y_pred)}/{len(y_test)} samples")
print(f"  y_pred shape: {y_pred.shape}")
print(f"  y_test_valid shape: {y_test_valid.shape}")

In [None]:
# Compute trajectory metrics
def compute_trajectory_metrics(y_true, y_pred):
    """Compute trajectory prediction metrics.

    Args:
        y_true: Ground-truth trajectories, array-like indexed as
            (sample, timestep, feature).
        y_pred: Predicted trajectories, broadcast-compatible with ``y_true``.

    Returns:
        Tuple ``(mse, rmse, mae, timestep_mse)``. The first three are scalars
        averaged over every element; ``timestep_mse`` is a 1-D array with one
        MSE per timestep, averaged over the sample and feature axes.
    """
    # np.asarray lets plain nested lists be passed; ndarrays pass through as-is.
    residual = np.asarray(y_true) - np.asarray(y_pred)
    squared_error = residual ** 2   # computed once, reused for MSE and per-timestep MSE

    mse = np.mean(squared_error)
    rmse = np.sqrt(mse)
    mae = np.mean(np.abs(residual))

    # Per-timestep error: collapse the sample (0) and feature (2) axes
    timestep_mse = np.mean(squared_error, axis=(0, 2))

    return mse, rmse, mae, timestep_mse

# Metrics pooled over every predicted sample
overall_metrics = compute_trajectory_metrics(y_test_valid, y_pred)
mse, rmse, mae, ts_mse = overall_metrics

print(f"\n Trajectory Prediction Metrics (Test Set):")
print(f"MSE:  {mse:.6f}")
print(f"RMSE: {rmse:.6f}")
print(f"MAE:  {mae:.6f}")

# Same metrics again, restricted to the samples of one cluster at a time
print(f"\n Per-Cluster Trajectory Performance:")
for cid in np.unique(c_test_valid):
    mask = (c_test_valid == cid)
    cluster_true = y_test_valid[mask]
    cluster_pred = y_pred[mask]
    c_mse, c_rmse, c_mae, _ = compute_trajectory_metrics(cluster_true, cluster_pred)
    count = mask.sum()
    print(f"  Cluster {cid}: MSE={c_mse:.6f}, RMSE={c_rmse:.6f}, MAE={c_mae:.6f} (n={count})")

In [None]:
# Per-timestep error analysis: four panels summarising where the error comes from
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Squared residuals, computed once and reduced along different axes below
squared_error = (y_test_valid - y_pred) ** 2

# Timestep MSE
axes[0, 0].plot(ts_mse, marker='o', linewidth=2, markersize=4)
axes[0, 0].set_xlabel('Timestep (minutes)')
axes[0, 0].set_ylabel('MSE')
axes[0, 0].set_title('Prediction Error by Timestep')
axes[0, 0].grid(alpha=0.3)

# Per-feature error (averaged over samples and timesteps)
feature_names = ['utm_x', 'utm_y', 'SOG', 'v_east', 'v_north']
feature_mse = np.mean(squared_error, axis=(0, 1))  # one value per feature
axes[0, 1].bar(range(len(feature_names)), feature_mse, color='steelblue', alpha=0.7)
axes[0, 1].set_xticks(range(len(feature_names)))
axes[0, 1].set_xticklabels(feature_names, rotation=45)
axes[0, 1].set_ylabel('MSE')
axes[0, 1].set_title('Error by Feature')
axes[0, 1].grid(alpha=0.3, axis='y')

# Error distribution
errors = np.sqrt(np.mean(squared_error, axis=(1, 2)))  # (M,) - RMSE per sample
# The marker is labelled "Mean RMSE", so draw it at the mean of the per-sample
# RMSE values shown in this histogram. The pooled `rmse` (square root of the
# overall MSE) is a different statistic and is generally not equal to it.
mean_sample_rmse = errors.mean()
axes[1, 0].hist(errors, bins=30, color='steelblue', alpha=0.7, edgecolor='black')
axes[1, 0].axvline(mean_sample_rmse, color='r', linestyle='--', linewidth=2, label=f'Mean RMSE: {mean_sample_rmse:.4f}')
axes[1, 0].set_xlabel('RMSE per Sample')
axes[1, 0].set_ylabel('Count')
axes[1, 0].set_title('Distribution of Prediction Errors')
axes[1, 0].legend()
axes[1, 0].grid(alpha=0.3, axis='y')

# Per-cluster comparison
cluster_mses = []
cluster_labels = []
for cid in np.unique(c_test_valid):
    mask = c_test_valid == cid
    c_mse, _, _, _ = compute_trajectory_metrics(y_test_valid[mask], y_pred[mask])
    cluster_mses.append(c_mse)
    cluster_labels.append(f'C{cid}')

axes[1, 1].bar(range(len(cluster_labels)), cluster_mses, color='steelblue', alpha=0.7)
axes[1, 1].set_xticks(range(len(cluster_labels)))
axes[1, 1].set_xticklabels(cluster_labels)
axes[1, 1].set_ylabel('MSE')
axes[1, 1].set_title('Prediction Error by Cluster')
axes[1, 1].grid(alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('trajectory_error_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print("-> Error analysis saved")

In [None]:
# Visualize sample predictions: one row per sampled trajectory, one column per feature
# Never ask for more rows than there are predictions; np.random.choice with
# replace=False raises when the population is smaller than the sample size.
num_samples = min(6, len(y_pred))
num_features = len(feature_names)

if num_samples == 0:
    print("No predictions available - skipping sample visualization")
else:
    # squeeze=False keeps `axes` two-dimensional even for a single row
    fig, axes = plt.subplots(num_samples, num_features, figsize=(16, 12), squeeze=False)

    # Fixed seed so the same samples are drawn on every run
    np.random.seed(42)
    sample_indices = np.random.choice(len(y_pred), num_samples, replace=False)

    for row, idx in enumerate(sample_indices):
        y_true_sample = y_test_valid[idx]  # (timesteps, features)
        y_pred_sample = y_pred[idx]        # (timesteps, features)
        # Time axis taken from the data rather than a fixed horizon
        timesteps = np.arange(y_true_sample.shape[0])

        for col in range(num_features):
            ax = axes[row, col]

            ax.plot(timesteps, y_true_sample[:, col], 'o-', label='True', linewidth=2, markersize=4)
            ax.plot(timesteps, y_pred_sample[:, col], 's--', label='Pred', linewidth=2, markersize=4)

            if row == 0:
                ax.set_title(feature_names[col], fontsize=12, fontweight='bold')
            if col == 0:
                # idx indexes the predicted subset (y_pred / y_test_valid)
                ax.set_ylabel(f'Sample {idx}', fontsize=11)
            if row == num_samples - 1:
                ax.set_xlabel('Timestep', fontsize=10)

            ax.grid(alpha=0.3)
            # Single legend, in the top-right panel
            if row == 0 and col == num_features - 1:
                ax.legend(loc='upper right')

    plt.tight_layout()
    plt.savefig('trajectory_sample_predictions.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("-> Sample predictions visualization saved")

## 4 - Final Summary Report

In [None]:
# Generate comprehensive report: a plain-text summary of the metrics computed
# in the cells above, printed and written to test_report.txt.
report = f"""
{'='*70}
FINAL MODEL EVALUATION REPORT
{'='*70}

TEST SET OVERVIEW
{'-'*70}
Total samples:        {len(X_test)}
Test samples used:    {len(y_pred)}
Features:             {input_dim}
Sequence length:      {X_test.shape[1]} timesteps
Unique clusters:      {len(np.unique(c_test))}

CLASSIFICATION RNN RESULTS
{'-'*70}
Accuracy:             {accuracy:.4f} ({accuracy*100:.2f}%)
Precision (weighted): {precision:.4f}
Recall (weighted):    {recall:.4f}
F1-Score (weighted):  {f1:.4f}

TRAJECTORY PREDICTOR RESULTS
{'-'*70}
Mean Squared Error:   {mse:.6f}
Root Mean Sq. Error:  {rmse:.6f}
Mean Absolute Error:  {mae:.6f}
Best timestep:        T={np.argmin(ts_mse)} (MSE={np.min(ts_mse):.6f})
Worst timestep:       T={np.argmax(ts_mse)} (MSE={np.max(ts_mse):.6f})

FEATURE-WISE PREDICTION ERROR
{'-'*70}
"""

# One line per feature, name left-aligned in a 12-character column
for fname, fmse in zip(feature_names, feature_mse):
    report += f"{fname:12s}: MSE={fmse:.6f}\n"

# NOTE(review): the units and the 50/100 RMSE thresholds below assume the
# trajectories are in physical units (not normalized) - confirm against the
# preprocessing pipeline.
report += f"""
VISUALIZATIONS GENERATED
{'-'*70}
-> classification_confusion_matrix.png
-> classification_confidence_analysis.png
-> trajectory_error_analysis.png
-> trajectory_sample_predictions.png

INTERPRETATION NOTES
{'-'*70}
• Classification accuracy {('GOOD' if accuracy > 0.7 else 'MODERATE' if accuracy > 0.5 else 'POOR')} ({accuracy:.1%})
• Trajectory RMSE: {rmse:.6f} (units: meters for spatial, m/s for velocity)
• Prediction quality {'improves' if ts_mse[-1] < ts_mse[0] else 'degrades'} over prediction horizon
• Error is relatively {'low' if rmse < 50 else 'moderate' if rmse < 100 else 'high'} for maritime data

{'='*70}
"""

print(report)

# Save report.
# The text contains a non-ASCII bullet character, so the encoding is pinned to
# UTF-8; the platform default encoding may be unable to represent it.
with open('test_report.txt', 'w', encoding='utf-8') as f:
    f.write(report)

print("\n-> Report saved to test_report.txt")

In [None]:
# Persist the headline metrics as a two-column table (metric name, value)
metric_names = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'MSE', 'RMSE', 'MAE']
metric_values = [accuracy, precision, recall, f1, mse, rmse, mae]
results_df = pd.DataFrame({'Metric': metric_names, 'Value': metric_values})

# Write without the index column, then echo the same table to the cell output
results_df.to_csv('test_metrics.csv', index=False)
print("-> Metrics saved to test_metrics.csv")
print("\nFinal Metrics Summary:")
summary_text = results_df.to_string(index=False)
print(summary_text)