# ACC → GRF Transformer: Prediction Visualization

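This notebook compares the transformer's predicted ground reaction force (GRF) curves with the measured ones, compares derived biomechanical metrics (jump height, peak power), and visualizes the model's attention weights.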

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

from src.data_loader import CMJDataLoader
from src.transformer import SignalTransformer
from src.evaluate import evaluate_model, print_evaluation_summary
from src.biomechanics import compute_jump_height, compute_peak_power

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')

## 1. Load Data and Model

In [None]:
# Load the CMJ recordings (use_resultant=True) and split them into train / validation sets
loader = CMJDataLoader(use_resultant=True)
train_ds, val_ds, info = loader.create_datasets(test_size=0.2, batch_size=32)

# Report dataset sizes and tensor shapes from the loader's info dict
for label, key in (
    ("Training samples", "n_train_samples"),
    ("Validation samples", "n_val_samples"),
    ("Input shape", "input_shape"),
    ("Output shape", "output_shape"),
):
    print(f"{label}: {info[key]}")

In [None]:
# Load a trained model (update MODEL_PATH as needed) ...
MODEL_PATH = '../outputs/checkpoints/best_model.keras'

# ... or train a quick model for testing. Set to False to use the checkpoint instead.
TRAIN_NEW_MODEL = True

# Input geometry the quick model is built for, and its training budget
SEQ_LEN, INPUT_DIM = 500, 1
QUICK_TRAIN_EPOCHS = 10

if TRAIN_NEW_MODEL:
    from src.transformer import build_signal_transformer

    # Fail fast if the loader's inputs disagree with the hard-coded geometry,
    # instead of erroring deep inside fit(). Unknown dims in the dataset spec
    # are treated as compatible.
    input_spec = train_ds.element_spec[0]
    if not input_spec.shape[-2:].is_compatible_with((SEQ_LEN, INPUT_DIM)):
        raise ValueError(
            f"Dataset inputs have shape {input_spec.shape}, "
            f"expected (..., {SEQ_LEN}, {INPUT_DIM})"
        )

    model = build_signal_transformer(
        seq_len=SEQ_LEN,
        input_dim=INPUT_DIM,
        d_model=64,
        num_heads=4,
        num_layers=3,
    )
    # Quick training for visualization only; not a substitute for the real checkpoint
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=QUICK_TRAIN_EPOCHS,
        verbose=1,
    )
else:
    model = keras.models.load_model(MODEL_PATH)

model.summary()

## 2. Get Predictions

In [None]:
# Materialize the whole validation split as NumPy arrays. Inputs and targets are
# taken from the same pass over val_ds, so each (X, y) pair stays aligned.
val_batches = [(X_batch.numpy(), y_batch.numpy()) for X_batch, y_batch in val_ds]
X_val = np.concatenate([xb for xb, _ in val_batches], axis=0)
y_val = np.concatenate([yb for _, yb in val_batches], axis=0)

# Model outputs are normalized; map targets and predictions back to GRF in BW units
y_pred_normalized = model.predict(X_val, verbose=0)
y_val_bw, y_pred_bw = (loader.denormalize_grf(arr) for arr in (y_val, y_pred_normalized))

print(f"Predictions shape: {y_pred_bw.shape}")

## 3. Visualize Predicted vs Actual GRF Curves

In [None]:
def plot_grf_comparison(y_true, y_pred, indices=None, n_samples=5):
    """Plot predicted vs actual GRF curves, one subplot per sample.

    Each subplot title shows the per-sample RMSE and the jump height computed
    (via ``compute_jump_height``) from both the actual and predicted curves.

    Args:
        y_true: Measured GRF curves (BW), indexed by sample along axis 0.
        y_pred: Predicted GRF curves (BW), same layout as ``y_true``.
        indices: Sample indices to plot. If None, ``n_samples`` distinct indices
            are drawn at random (capped at the number of available samples).
        n_samples: Number of random samples to draw when ``indices`` is None.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: If there are no samples to plot.
    """
    if indices is None:
        indices = np.random.choice(len(y_true), min(n_samples, len(y_true)), replace=False)
    indices = np.atleast_1d(indices)
    if indices.size == 0:
        raise ValueError("No samples to plot: 'indices' is empty.")

    # squeeze=False always returns a 2-D array of Axes, so one row needs no special case
    fig, axes = plt.subplots(len(indices), 1, figsize=(12, 3 * len(indices)), squeeze=False)

    for ax, idx in zip(axes[:, 0], indices):
        actual = np.asarray(y_true[idx]).flatten()
        predicted = np.asarray(y_pred[idx]).flatten()
        # Derive the time axis from the data rather than assuming 500 samples
        time = np.arange(actual.size)

        ax.plot(time, actual, 'b-', linewidth=2, label='Actual', alpha=0.8)
        ax.plot(time, predicted, 'r--', linewidth=2, label='Predicted', alpha=0.8)

        # Per-sample metrics shown in the title
        rmse = np.sqrt(np.mean((actual - predicted) ** 2))
        jh_actual = compute_jump_height(actual)
        jh_pred = compute_jump_height(predicted)

        ax.set_xlabel('Sample')
        ax.set_ylabel('GRF (BW)')
        # Jump height is multiplied by 100 for display in cm (assumes metres from compute_jump_height)
        ax.set_title(f'Sample {idx} | RMSE: {rmse:.4f} BW | JH Actual: {jh_actual*100:.1f}cm, Pred: {jh_pred*100:.1f}cm')
        ax.legend(loc='upper right')
        # Reference lines at 1 BW (body weight) and 0 BW
        ax.axhline(y=1.0, color='gray', linestyle=':', alpha=0.5)
        ax.axhline(y=0.0, color='gray', linestyle=':', alpha=0.5)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Overlay prediction and ground truth for a handful of randomly chosen validation samples
comparison_fig = plot_grf_comparison(y_val_bw, y_pred_bw, n_samples=5)
plt.show()

## 4. Biomechanical Metrics Scatter Plots

In [None]:
from src.biomechanics import compute_jump_metrics_batch


def r_squared(actual, predicted):
    """Coefficient of determination (R²) of ``predicted`` against ``actual``.

    Pairs where either value is non-finite (NaN/inf) are ignored, so a single
    failed metric extraction does not turn the whole score into NaN.
    Returns 0 when no finite pairs remain or the finite ``actual`` values have
    zero variance (R² is undefined there).
    """
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    finite = np.isfinite(actual) & np.isfinite(predicted)
    actual, predicted = actual[finite], predicted[finite]
    if actual.size == 0:
        return 0
    ss_res = np.sum((actual - predicted) ** 2)
    ss_tot = np.sum((actual - actual.mean()) ** 2)
    return 1 - ss_res / ss_tot if ss_tot > 0 else 0


def plot_metric_agreement(ax, actual, predicted, label, unit, scale=1.0):
    """Scatter predicted vs actual values of one metric on ``ax``.

    Draws the y = x identity line over the finite data range and titles the
    panel with R². ``scale`` only changes display units (e.g. 100 for m -> cm);
    R² is computed on the unscaled values (it is scale-invariant).

    Returns:
        The R² value shown in the title.
    """
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    ax.scatter(actual * scale, predicted * scale,
               alpha=0.6, edgecolors='black', linewidth=0.5)

    # Identity line spans the finite range only, so one NaN cannot erase it
    both = np.concatenate([actual, predicted]) * scale
    both = both[np.isfinite(both)]
    if both.size:
        lo, hi = both.min(), both.max()
        ax.plot([lo, hi], [lo, hi], 'r--', linewidth=2)

    r2 = r_squared(actual, predicted)
    ax.set_xlabel(f'Actual {label} ({unit})')
    ax.set_ylabel(f'Predicted {label} ({unit})')
    ax.set_title(f'{label} (R² = {r2:.3f})')
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')
    return r2


# Compute metrics for all samples
actual_metrics = compute_jump_metrics_batch(y_val_bw)
pred_metrics = compute_jump_metrics_batch(y_pred_bw)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Jump height is scaled by 100 for display in cm (assumes the metrics are in metres)
r2_jh = plot_metric_agreement(axes[0], actual_metrics['jump_height'], pred_metrics['jump_height'],
                              label='Jump Height', unit='cm', scale=100)
r2_pp = plot_metric_agreement(axes[1], actual_metrics['peak_power'], pred_metrics['peak_power'],
                              label='Peak Power', unit='W/kg')

plt.tight_layout()
plt.show()

## 5. Attention Weight Visualization

In [None]:
def visualize_attention(model, x_sample, layer_idx=-1, head_idx=0):
    """Plot one head's attention map for a single input sample.

    Args:
        model: Model exposing ``get_attention_weights(layer_idx)``; the weights
            are read right after a forward pass on ``x_sample``.
        x_sample: A single (unbatched) input array; a batch axis is added here.
        layer_idx: Which layer's attention to show (passed straight to
            ``get_attention_weights``; presumably -1 means the last layer).
        head_idx: Attention head to show.

    Returns:
        (fig, attn): the Figure and the plotted 2-D attention matrix as a NumPy array.
    """
    # Forward pass so the model holds attention weights for this sample
    _ = model(x_sample[np.newaxis, ...], training=False)

    # Get attention weights from the specified layer
    attn_weights = model.get_attention_weights(layer_idx)

    # Assumed layout: (1, num_heads, seq_len, seq_len) — TODO confirm.
    # np.asarray works whether a tf.Tensor or an np.ndarray is returned (.numpy() does not).
    attn = np.asarray(attn_weights[0, head_idx])

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(attn, cmap='viridis', aspect='auto')

    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')
    ax.set_title(f'Attention Weights (Layer {layer_idx}, Head {head_idx})')
    fig.colorbar(im, ax=ax, label='Attention Weight')

    return fig, attn

# Draw one validation sample at random and show its last-layer, head-0 attention map
sample_idx = np.random.randint(len(X_val))
fig, attn = visualize_attention(model, X_val[sample_idx], layer_idx=-1, head_idx=0)
plt.show()

# Sanity check: every query row of the attention map is expected to sum to ~1
row_sums = attn.sum(axis=-1)
print(f"\nAttention weights shape: {attn.shape}")
print(f"Attention weights sum (should be ~1 per row): {row_sums[:5]}...")

In [None]:
def plot_attention_over_signal(model, x_sample, y_true, y_pred, query_position=250):
    """Plot last-layer attention for one query position alongside the signals.

    Three stacked panels share the x-axis: the input signal, the head-averaged
    attention distribution of ``query_position`` over all key positions, and the
    actual vs predicted GRF. A dashed red line marks ``query_position`` on each.

    Args:
        model: Model exposing ``get_attention_weights(layer_idx)``, read right
            after a forward pass on ``x_sample``.
        x_sample: A single (unbatched) input array; a batch axis is added here.
        y_true: Actual GRF curve (BW); flattened before plotting.
        y_pred: Predicted GRF curve (BW); flattened before plotting.
        query_position: Query index whose attention row is plotted.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: If ``query_position`` is outside the attention matrix
            (negative indices would otherwise wrap silently while the marker
            line is drawn off-axis).
    """
    # Forward pass so the model holds attention weights for this sample
    _ = model(x_sample[np.newaxis, ...], training=False)
    attn_weights = model.get_attention_weights(-1)

    # Assumed layout: (1, num_heads, seq_len, seq_len) — TODO confirm.
    # Average across heads of the single sample -> (seq_len, seq_len).
    attn_avg = np.asarray(attn_weights[0]).mean(axis=0)

    n_queries = attn_avg.shape[0]
    if not 0 <= query_position < n_queries:
        raise ValueError(
            f"query_position={query_position} is out of range for "
            f"{n_queries} attention positions"
        )
    attn_query = attn_avg[query_position]

    signal = np.asarray(x_sample).flatten()
    y_true = np.asarray(y_true).flatten()
    y_pred = np.asarray(y_pred).flatten()

    fig, axes = plt.subplots(3, 1, figsize=(14, 8), sharex=True)

    # Each panel uses its own length instead of a hard-coded 500 samples.
    # The panels only line up if attention is per input time step, as before.

    # Input signal
    ax = axes[0]
    ax.plot(np.arange(signal.size), signal, 'b-', linewidth=1.5)
    ax.axvline(x=query_position, color='red', linestyle='--', alpha=0.7)
    ax.set_ylabel('ACC (normalized)')
    ax.set_title(f'Input Signal (query position: {query_position})')
    ax.grid(True, alpha=0.3)

    # Attention weights
    ax = axes[1]
    ax.fill_between(np.arange(attn_query.size), 0, attn_query, alpha=0.7)
    ax.axvline(x=query_position, color='red', linestyle='--', alpha=0.7)
    ax.set_ylabel('Attention Weight')
    ax.set_title('Attention Distribution (averaged across heads)')
    ax.grid(True, alpha=0.3)

    # Output comparison
    ax = axes[2]
    ax.plot(np.arange(y_true.size), y_true, 'b-', linewidth=2, label='Actual', alpha=0.8)
    ax.plot(np.arange(y_pred.size), y_pred, 'r--', linewidth=2, label='Predicted', alpha=0.8)
    ax.axvline(x=query_position, color='red', linestyle='--', alpha=0.7)
    ax.set_xlabel('Sample')
    ax.set_ylabel('GRF (BW)')
    ax.set_title('Output Comparison')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Head-averaged attention for one query position of a randomly chosen validation sample.
# 300 is intended to fall in the propulsion phase (depends on how trials are aligned).
QUERY_POSITION = 300
sample_idx = np.random.randint(len(X_val))
y_true_curve = y_val_bw[sample_idx].flatten()
y_pred_curve = y_pred_bw[sample_idx].flatten()
plot_attention_over_signal(model, X_val[sample_idx], y_true_curve, y_pred_curve,
                           query_position=QUERY_POSITION)
plt.show()

## 6. Full Evaluation Summary

In [None]:
# Comprehensive evaluation on the validation split. Note: the *normalized*
# targets (y_val) are passed together with the loader, not the BW-scaled ones.
results = evaluate_model(
    model,
    X_val,
    y_val,
    loader,
)
print_evaluation_summary(results)

## 7. Best and Worst Predictions

In [None]:
# Number of best / worst samples to list and plot
N_EXTREMES = 5

# Per-sample RMSE (BW): average the squared error over every non-sample axis
# in one vectorized pass instead of a Python loop.
error_axes = tuple(range(1, y_val_bw.ndim))
sample_rmse = np.sqrt(np.mean((y_val_bw - y_pred_bw) ** 2, axis=error_axes))

# A single ascending sort serves both ends of the ranking
rmse_order = np.argsort(sample_rmse)

# Best predictions (lowest RMSE first)
best_indices = rmse_order[:N_EXTREMES]
print("Best predictions (lowest RMSE):")
for idx in best_indices:
    print(f"  Sample {idx}: RMSE = {sample_rmse[idx]:.4f} BW")

fig = plot_grf_comparison(y_val_bw, y_pred_bw, indices=best_indices)
fig.suptitle('Best Predictions', y=1.02, fontsize=14)
plt.show()

# Worst predictions, listed from the highest RMSE down (argsort is ascending,
# so the tail is reversed to match the heading)
worst_indices = rmse_order[::-1][:N_EXTREMES]
print("\nWorst predictions (highest RMSE):")
for idx in worst_indices:
    print(f"  Sample {idx}: RMSE = {sample_rmse[idx]:.4f} BW")

fig = plot_grf_comparison(y_val_bw, y_pred_bw, indices=worst_indices)
fig.suptitle('Worst Predictions', y=1.02, fontsize=14)
plt.show()