# Temporal Fusion Transformer (TFT) - Complete Tutorial

This notebook provides a comprehensive guide to using the Temporal Fusion Transformer for time series forecasting. We'll cover:

1. **Understanding TFT Architecture** - What makes TFT special
2. **Data Preparation** - Generating and preparing time series data
3. **Model Configuration** - Setting up the model parameters
4. **Training** - Training the model with callbacks
5. **Prediction** - Making forecasts
6. **Interpretability** - Understanding model decisions through attention and variable importance

---

## 1. TFT Architecture Overview

The Temporal Fusion Transformer (TFT) is an attention-based deep learning architecture for **interpretable multi-horizon time series forecasting**, introduced by Lim et al. (2021).

### Key Features:
- **Multi-horizon forecasting**: Predicts multiple future time steps simultaneously
- **Probabilistic predictions**: Outputs quantiles for uncertainty estimation
- **Interpretability**: Built-in attention mechanisms and variable selection
- **Heterogeneous inputs**: Handles static, known future, and observed-only features

### Architecture Flow:

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                         TEMPORAL FUSION TRANSFORMER                          │
└─────────────────────────────────────────────────────────────────────────────┘

              ┌──────────────┐   ┌──────────────┐   ┌──────────────┐
              │   STATIC     │   │ KNOWN FUTURE │   │  OBSERVED    │
              │  FEATURES    │   │   FEATURES   │   │   FEATURES   │
              │  (e.g., ID)  │   │ (e.g., time) │   │ (e.g., temp) │
              └──────┬───────┘   └──────┬───────┘   └──────┬───────┘
                     │                  │                  │
                     ▼                  │                  │
        ┌────────────────────────┐      │                  │
        │   STATIC COVARIATE     │      │                  │
        │       ENCODER          │      │                  │
        │  (Creates 4 context    │      │                  │
        │   vectors for gating)  │      │                  │
        └───────────┬────────────┘      │                  │
                    │                   │                  │
         Context    │    ┌──────────────┴──────────────────┤
         Vectors    │    │                                 │
                    ▼    ▼                                 ▼
        ┌─────────────────────────────────────────────────────────────┐
        │              VARIABLE SELECTION NETWORKS                     │
        │   ┌─────────────────────┐   ┌─────────────────────┐         │
        │   │  Encoder Variables  │   │  Decoder Variables  │         │
        │   │  (Historical data)  │   │  (Future known)     │         │
        │   │                     │   │                     │         │
        │   │  Softmax weights    │   │  Softmax weights    │         │
        │   │  select important   │   │  select important   │         │
        │   │  features           │   │  features           │         │
        │   └──────────┬──────────┘   └──────────┬──────────┘         │
        └──────────────┼──────────────────────────┼───────────────────┘
                       │                          │
                       ▼                          ▼
        ┌─────────────────────────────────────────────────────────────┐
        │                    LSTM ENCODER-DECODER                      │
        │   ┌─────────────────────┐   ┌─────────────────────┐         │
        │   │    LSTM Encoder     │──▶│    LSTM Decoder     │         │
        │   │  (Processes past)   │   │  (Processes future) │         │
        │   └─────────────────────┘   └─────────────────────┘         │
        └──────────────────────────────┬──────────────────────────────┘
                                       │
                                       ▼
        ┌─────────────────────────────────────────────────────────────┐
        │              STATIC ENRICHMENT LAYER                         │
        │         (Enhances temporal features with                     │
        │          static context via GRN)                             │
        └──────────────────────────────┬──────────────────────────────┘
                                       │
                                       ▼
        ┌─────────────────────────────────────────────────────────────┐
        │           INTERPRETABLE MULTI-HEAD ATTENTION                 │
        │                                                              │
        │   • Heads share values; weights are averaged                │
        │   • Allows inspection of temporal relationships             │
        │   • Shows which past time steps matter most                 │
        │                                                              │
        └──────────────────────────────┬──────────────────────────────┘
                                       │
                                       ▼
        ┌─────────────────────────────────────────────────────────────┐
        │              POSITION-WISE FEED-FORWARD                      │
        │               (GRN for each position)                        │
        └──────────────────────────────┬──────────────────────────────┘
                                       │
                                       ▼
        ┌─────────────────────────────────────────────────────────────┐
        │                 QUANTILE OUTPUT HEADS                        │
        │                                                              │
        │   ┌─────────┐  ┌─────────┐  ┌─────────┐                     │
        │   │  Q=0.1  │  │  Q=0.5  │  │  Q=0.9  │   ...               │
        │   │ (Lower) │  │(Median) │  │ (Upper) │                     │
        │   └────┬────┘  └────┬────┘  └────┬────┘                     │
        └────────┼────────────┼────────────┼──────────────────────────┘
                 │            │            │
                 ▼            ▼            ▼
        ┌─────────────────────────────────────────────────────────────┐
        │                    PREDICTIONS                               │
        │        (Multi-horizon probabilistic forecasts)               │
        └─────────────────────────────────────────────────────────────┘
```

### Key Building Blocks:

| Component | Purpose |
|-----------|--------|
| **GRN (Gated Residual Network)** | Non-linear processing with skip connections |
| **GLU (Gated Linear Unit)** | Controls information flow through gating |
| **Variable Selection Network** | Learns which features are important |
| **Multi-Head Attention** | Captures temporal dependencies |
| **Quantile Outputs** | Provides prediction intervals |
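
To make the first two rows concrete, here is a minimal PyTorch sketch of a GLU and a GRN in the general form described in the TFT paper: dense layer, ELU, dense layer, gate, residual connection, layer normalization. The class names, the layer sizes, and the omission of the optional static-context input are assumptions for illustration; this is not the code in `tft.models`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedLinearUnit(nn.Module):
    """GLU: sigmoid(gate) * value, both computed from one linear layer."""

    def __init__(self, d_model: int):
        super().__init__()
        self.fc = nn.Linear(d_model, 2 * d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.fc(x).chunk(2, dim=-1)
        return torch.sigmoid(gate) * value


class GatedResidualNetwork(nn.Module):
    """GRN without the optional context input: LayerNorm(x + GLU(f(x)))."""

    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_model)
        self.fc2 = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.glu = GatedLinearUnit(d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.elu(self.fc1(x))
        hidden = self.dropout(self.fc2(hidden))
        # The gate can shrink the non-linear branch toward zero,
        # leaving (approximately) only the skip connection.
        return self.norm(x + self.glu(hidden))


grn = GatedResidualNetwork(d_model=8).eval()
demo_out = grn(torch.randn(4, 10, 8))  # (batch, time, features)
print(demo_out.shape)
```

The output has the same shape as the input, which is what lets GRNs be stacked and reused throughout the architecture.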

## 2. Setup and Imports

Let's start by importing the necessary modules.

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# TFT imports
from tft.models import TemporalFusionTransformer
from tft.data import create_dataloaders
from tft.training import TFTTrainer, EarlyStopping, ModelCheckpoint
from tft.utils import TFTConfig, get_device, print_device_info
from tft.utils.visualization import (
    plot_predictions, 
    plot_training_history,
    plot_attention_weights,
    plot_variable_importance,
    plot_forecast_fan,
)
from tft.interpret import (
    extract_attention_weights,
    average_attention_over_heads,
    plot_attention_heatmap,
    plot_attention_by_head,
    analyze_attention_focus,
    extract_variable_selection_weights,
    rank_variables_by_importance,
    plot_temporal_variable_importance,
)

# Set style for plots
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Check available device
print("=" * 60)
print("Device Information")
print("=" * 60)
print_device_info()
device = get_device('auto')
print(f"\nUsing device: {device}")

## 3. Data Generation

We'll create synthetic time series data that demonstrates TFT's ability to handle:
- **Static features**: Time-invariant characteristics (e.g., series ID)
- **Known future features**: Features known at forecast time (e.g., time of day)
- **Observed features**: Historical-only features (e.g., temperature)
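
Calendar features such as hour of day are usually given to the model in a cyclical encoding, so that 23:00 and 00:00 end up close together. The sketch below uses the same sine/cosine formula as the data generator in this section; the helper name `encode_cyclical` is ours, for illustration only.

```python
import numpy as np


def encode_cyclical(values: np.ndarray, period: float) -> np.ndarray:
    """Map a periodic feature onto the unit circle as (sin, cos) pairs."""
    angle = 2 * np.pi * values / period
    return np.stack([np.sin(angle), np.cos(angle)], axis=-1)


hours = np.array([0.0, 1.0, 12.0, 23.0])
encoded_hours = encode_cyclical(hours, period=24)

# Euclidean distance of each encoded hour from hour 0
dist_to_midnight = np.linalg.norm(encoded_hours - encoded_hours[0], axis=1)
print(dict(zip(hours, dist_to_midnight.round(3))))
```

In the encoded space, hours 1 and 23 are equally close to midnight, while hour 12 is the farthest away. The raw hour value would instead place 23 far from 0.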

### Data Structure:
```
┌─────────────────────────────────────────────────────────────┐
│                    TIME SERIES DATA                          │
├─────────────────────────────────────────────────────────────┤
│  series_id │ time_idx │ target │ hour_sin │ temperature │   │
│  (static)  │          │        │ (known)  │ (observed)  │   │
├────────────┼──────────┼────────┼──────────┼─────────────┤   │
│     0      │    0     │  5.2   │   0.26   │    21.3     │   │
│     0      │    1     │  5.8   │   0.50   │    22.1     │   │
│     0      │    2     │  6.1   │   0.71   │    23.0     │   │
│    ...     │   ...    │  ...   │   ...    │    ...      │   │
└─────────────────────────────────────────────────────────────┘
```

In [None]:
def generate_synthetic_data(
    num_samples: int = 1000,
    num_series: int = 5,
    noise_level: float = 0.1,
    seed: int = 42,
) -> pd.DataFrame:
    """
    Generate synthetic time series data with trend, seasonality, and noise.
    
    The data simulates multiple related time series (e.g., different stores/sensors)
    with common patterns but individual characteristics.
    """
    np.random.seed(seed)
    data_list = []

    for series_id in range(num_series):
        # Time index
        time_idx = np.arange(num_samples)

        # Components of the target signal
        trend = 0.02 * time_idx  # Linear trend
        seasonality = 10 * np.sin(2 * np.pi * time_idx / 50)  # Seasonal pattern (period=50)
        noise = noise_level * np.random.randn(num_samples)  # Random noise

        # Target value: combination of components + series-specific offset
        target = trend + seasonality + noise + series_id * 2

        # Known future features (time-based, known at forecast time)
        hour = (time_idx % 24).astype(float)
        day_of_week = ((time_idx // 24) % 7).astype(float)
        
        # Cyclical encoding for hour
        hour_sin = np.sin(2 * np.pi * hour / 24)
        hour_cos = np.cos(2 * np.pi * hour / 24)

        # Observed feature (correlated with target, only known historically)
        temperature = 20 + 5 * np.sin(2 * np.pi * time_idx / 50) + np.random.randn(num_samples) * 0.5

        series_df = pd.DataFrame({
            'series_id': series_id,
            'time_idx': time_idx,
            'target': target,
            'hour': hour,
            'day_of_week': day_of_week,
            'hour_sin': hour_sin,
            'hour_cos': hour_cos,
            'temperature': temperature,
        })
        data_list.append(series_df)

    return pd.concat(data_list, axis=0, ignore_index=True)


# Generate data
df = generate_synthetic_data(num_samples=1000, num_series=5, noise_level=0.5, seed=42)

print(f"Dataset shape: {df.shape}")
print(f"Number of series: {df['series_id'].nunique()}")
print(f"Time steps per series: {df.groupby('series_id').size().iloc[0]}")
print(f"\nColumns: {list(df.columns)}")
print(f"\nData types:\n{df.dtypes}")

In [None]:
# Preview the data
df.head(10)

### Visualize the Data

Let's visualize the synthetic time series to understand its components.

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Target for all series
ax1 = axes[0, 0]
for series_id in df['series_id'].unique():
    series_data = df[df['series_id'] == series_id]
    ax1.plot(series_data['time_idx'], series_data['target'], label=f'Series {series_id}', alpha=0.7)
ax1.set_xlabel('Time Index')
ax1.set_ylabel('Target Value')
ax1.set_title('Target Time Series (All Series)')
ax1.legend()

# Plot 2: Single series with components
ax2 = axes[0, 1]
series_0 = df[df['series_id'] == 0].head(200)
ax2.plot(series_0['time_idx'], series_0['target'], 'b-', label='Target', linewidth=2)
ax2.plot(series_0['time_idx'], series_0['temperature'], 'r--', label='Temperature', alpha=0.7)
ax2.set_xlabel('Time Index')
ax2.set_ylabel('Value')
ax2.set_title('Series 0: Target vs Temperature (Correlated)')
ax2.legend()

# Plot 3: Known future features (cyclical encoding)
ax3 = axes[1, 0]
sample = df[df['series_id'] == 0].head(100)
ax3.plot(sample['time_idx'], sample['hour_sin'], 'g-', label='Hour (sin)', linewidth=2)
ax3.plot(sample['time_idx'], sample['hour_cos'], 'orange', label='Hour (cos)', linewidth=2)
ax3.set_xlabel('Time Index')
ax3.set_ylabel('Encoded Value')
ax3.set_title('Known Future Features (Cyclical Hour Encoding)')
ax3.legend()

# Plot 4: Distribution of target
ax4 = axes[1, 1]
for series_id in df['series_id'].unique():
    series_data = df[df['series_id'] == series_id]
    ax4.hist(series_data['target'], bins=30, alpha=0.5, label=f'Series {series_id}')
ax4.set_xlabel('Target Value')
ax4.set_ylabel('Frequency')
ax4.set_title('Distribution of Target Values')
ax4.legend()

plt.tight_layout()
plt.show()

## 4. Data Splitting

We split the data by time rather than at random, so the model is always evaluated on time steps later than any it was trained on:

```
├────────────────────────────────────────────────────────────────────┤
│              TRAIN (70%)            │   VAL (15%)   │  TEST (15%)  │
├────────────────────────────────────────────────────────────────────┤
                                      ↑               ↑
                               train_end_idx    val_end_idx
```

In [None]:
def prepare_data(df: pd.DataFrame) -> tuple:
    """Split data temporally: 70% train, 15% val, 15% test."""
    unique_times = sorted(df['time_idx'].unique())
    n_times = len(unique_times)

    train_end = int(n_times * 0.7)
    val_end = int(n_times * 0.85)

    train_times = unique_times[:train_end]
    val_times = unique_times[train_end:val_end]
    test_times = unique_times[val_end:]

    train_df = df[df['time_idx'].isin(train_times)].copy()
    val_df = df[df['time_idx'].isin(val_times)].copy()
    test_df = df[df['time_idx'].isin(test_times)].copy()

    return train_df, val_df, test_df


train_df, val_df, test_df = prepare_data(df)

print("Data Split Summary:")
print(f"  Train: {len(train_df):,} samples ({len(train_df)/len(df)*100:.1f}%)")
print(f"  Val:   {len(val_df):,} samples ({len(val_df)/len(df)*100:.1f}%)")
print(f"  Test:  {len(test_df):,} samples ({len(test_df)/len(df)*100:.1f}%)")
print(f"\nTime ranges:")
print(f"  Train: {train_df['time_idx'].min()} - {train_df['time_idx'].max()}")
print(f"  Val:   {val_df['time_idx'].min()} - {val_df['time_idx'].max()}")
print(f"  Test:  {test_df['time_idx'].min()} - {test_df['time_idx'].max()}")

## 5. Model Configuration

TFT requires careful configuration of:
1. **Input features** - Which features are static, known future, or observed-only
2. **Sequence lengths** - How much history to use (encoder) and how far to forecast (decoder)
3. **Architecture** - Hidden dimensions, attention heads, LSTM layers
4. **Quantiles** - Which prediction intervals to output

In [None]:
config = TFTConfig(
    # === Input Features ===
    # Static: Time-invariant features (constant for each series)
    static_variables=['series_id'],
    
    # Known Future: Features known at forecast time (e.g., calendar features)
    known_future=['hour_sin', 'hour_cos', 'day_of_week'],
    
    # Observed Only: Features only available historically (not known in future)
    observed_only=['temperature'],
    
    # Target variable to forecast
    target='target',

    # === Sequence Configuration ===
    encoder_length=50,   # Look back 50 time steps
    decoder_length=10,   # Forecast 10 steps ahead

    # === Architecture ===
    hidden_size=64,      # Hidden dimension (larger = more capacity)
    num_heads=4,         # Number of attention heads
    num_lstm_layers=1,   # LSTM depth
    dropout=0.1,         # Regularization

    # === Quantiles for Probabilistic Forecasting ===
    quantiles=[0.1, 0.5, 0.9],  # 10th, 50th (median), 90th percentile

    # === Training ===
    batch_size=32,
    learning_rate=1e-3,
    max_epochs=20,  # Upper bound; early stopping may end training sooner
    gradient_clip_val=1.0,
)

print("TFT Configuration:")
print("=" * 50)
print(f"\nInput Features:")
print(f"  Static variables:    {config.static_variables}")
print(f"  Known future:        {config.known_future}")
print(f"  Observed only:       {config.observed_only}")
print(f"  Target:              {config.target}")
print(f"\nSequence Configuration:")
print(f"  Encoder length:      {config.encoder_length} steps")
print(f"  Decoder length:      {config.decoder_length} steps")
print(f"  Total window:        {config.encoder_length + config.decoder_length} steps")
print(f"\nArchitecture:")
print(f"  Hidden size:         {config.hidden_size}")
print(f"  Attention heads:     {config.num_heads}")
print(f"  LSTM layers:         {config.num_lstm_layers}")
print(f"  Dropout:             {config.dropout}")
print(f"\nQuantiles:             {config.quantiles}")

### Understanding Encoder/Decoder Split

```
┌────────────────────────────────────────────────────────────────────────────┐
│                         INPUT SEQUENCE WINDOW                              │
├────────────────────────────────────┬───────────────────────────────────────┤
│         ENCODER (Historical)       │        DECODER (Forecast)             │
│         encoder_length = 50        │        decoder_length = 10            │
├────────────────────────────────────┼───────────────────────────────────────┤
│                                    │                                       │
│  ← Static features available →     │  ← Static features available →        │
│  ← Known future features →         │  ← Known future features →            │
│  ← Observed features →             │  ✗ Observed features NOT available    │
│  ← Target (historical) →           │  ← Target (to predict) →              │
│                                    │                                       │
│     t-49  t-48  ...  t-1  t=0      │     t+1  t+2  ...  t+9  t+10         │
└────────────────────────────────────┴───────────────────────────────────────┘
```

## 6. Create Data Loaders

Data loaders handle windowing the time series into sequences suitable for the model.
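
As a rough picture of what windowing means here, the sketch below slices one series into overlapping (encoder, decoder) pairs with a stride of one step. The helper `make_windows` and the stride are assumptions for illustration; `create_dataloaders` may window the data differently.

```python
import numpy as np


def make_windows(series: np.ndarray, encoder_length: int, decoder_length: int):
    """Return (encoder, decoder) arrays of overlapping windows, stride 1."""
    total = encoder_length + decoder_length
    n_windows = len(series) - total + 1
    if n_windows <= 0:
        raise ValueError("series is shorter than one full window")
    # Row i holds the indices i, i+1, ..., i+total-1
    idx = np.arange(total)[None, :] + np.arange(n_windows)[:, None]
    windows = series[idx]                      # (n_windows, total)
    return windows[:, :encoder_length], windows[:, encoder_length:]


toy_series = np.arange(150, dtype=float)       # e.g. one series in a 150-step split
enc, dec = make_windows(toy_series, encoder_length=50, decoder_length=10)
print(enc.shape, dec.shape)                    # (91, 50) (91, 10)
```

Under this scheme, a split of T steps yields T - (encoder_length + decoder_length) + 1 windows per series, so a split shorter than 60 steps would produce none.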

In [None]:
train_loader, val_loader, test_loader = create_dataloaders(
    train_data=train_df,
    val_data=val_df,
    test_data=test_df,
    config=config,
    batch_size=config.batch_size,
    num_workers=0,
)

print("DataLoader Summary:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches:   {len(val_loader)}")
print(f"  Test batches:  {len(test_loader)}")

# Inspect a batch
sample_batch = next(iter(train_loader))
print(f"\nSample batch contents:")
for key, value in sample_batch.items():
    if isinstance(value, torch.Tensor):
        print(f"  {key}: shape={value.shape}, dtype={value.dtype}")

## 7. Create and Inspect the Model

Let's create the TFT model and examine its architecture.

In [None]:
model = TemporalFusionTransformer(config)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model Summary:")
print("=" * 50)
print(f"Total parameters:     {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"\nModel Architecture:")
print(model)

## 8. Training

We'll train the model with:
- **Early Stopping**: Stop if validation loss doesn't improve
- **Model Checkpoint**: Save the best model

The loss function is **Quantile Loss** (Pinball Loss), which trains the model to predict multiple quantiles simultaneously.
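
For reference, the pinball loss for quantile level q weights under-prediction by q and over-prediction by 1 - q. The NumPy sketch below averages it over all elements and quantiles. The function name and the plain mean are assumptions; the loss inside `TFTTrainer` may be reduced or weighted differently.

```python
import numpy as np


def pinball_loss(y_true: np.ndarray, y_pred: np.ndarray, quantiles) -> float:
    """Mean pinball loss.

    y_true: (batch, horizon); y_pred: (batch, horizon, n_quantiles).
    """
    q = np.asarray(quantiles, dtype=float)     # (n_quantiles,)
    error = y_true[..., None] - y_pred         # > 0 means under-prediction
    loss = np.maximum(q * error, (q - 1.0) * error)
    return float(loss.mean())


y_true_demo = np.array([[10.0, 12.0]])
under = np.full((1, 2, 1), 8.0)                # predicts too low
over = np.full((1, 2, 1), 14.0)                # predicts too high

# For q = 0.9, predicting too low costs more than predicting too high.
print(pinball_loss(y_true_demo, under, [0.9]), pinball_loss(y_true_demo, over, [0.9]))
```

This asymmetry is what pushes the Q0.9 output above most observations and the Q0.1 output below them.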

In [None]:
# Create trainer
trainer = TFTTrainer(
    model=model,
    config=config,
    device=device,
)

# Setup callbacks
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=5,
        min_delta=1e-4,
        verbose=True,
    ),
    ModelCheckpoint(
        filepath='tft_best_model.pth',
        monitor='val_loss',
        save_best_only=True,
        verbose=True,
    ),
]

print("Starting training...")
print("=" * 60)

In [None]:
# Train the model
trainer.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=config.max_epochs,
    callbacks=callbacks,
)

print("\n" + "=" * 60)
print("Training completed!")
print(f"Best validation loss: {trainer.best_val_loss:.6f}")

### Visualize Training History

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
ax1 = axes[0]
epochs = range(1, len(trainer.history['train_loss']) + 1)
ax1.plot(epochs, trainer.history['train_loss'], 'b-', label='Training Loss', linewidth=2)
ax1.plot(epochs, trainer.history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Quantile Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Loss difference (overfitting indicator)
ax2 = axes[1]
loss_diff = np.array(trainer.history['val_loss']) - np.array(trainer.history['train_loss'])
colors = ['green' if d < 0.1 else 'orange' if d < 0.3 else 'red' for d in loss_diff]
ax2.bar(epochs, loss_diff, color=colors, alpha=0.7)
ax2.axhline(y=0, color='black', linestyle='--', linewidth=1)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Val Loss - Train Loss')
ax2.set_title('Generalization Gap (Lower is Better)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Evaluation and Prediction

Let's evaluate the model on the test set and generate predictions.

In [None]:
# Evaluate on test set
print("Evaluating on test set...")
test_metrics = trainer.validate(test_loader)

print("\nTest Set Metrics:")
print("=" * 40)
for metric, value in test_metrics.items():
    print(f"  {metric}: {value:.6f}")

In [None]:
# Generate predictions with attention and variable selection weights
print("Generating predictions...")
results = trainer.predict(
    test_loader,
    return_attention=True,
    return_variable_selection=True,
)

predictions = results['predictions']
targets = results['targets']

print(f"\nPredictions shape: {predictions.shape}")
print(f"  - Batch dimension: {predictions.shape[0]} samples")
print(f"  - Time dimension: {predictions.shape[1]} forecast steps")
print(f"  - Quantile dimension: {predictions.shape[2]} quantiles {config.quantiles}")
print(f"\nTargets shape: {targets.shape}")

### Visualize Predictions

TFT outputs **quantile predictions**, providing uncertainty estimates:
- **Q0.1**: 10th percentile (lower bound)
- **Q0.5**: 50th percentile (median prediction)
- **Q0.9**: 90th percentile (upper bound)
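
Because each quantile is a separate output, nothing in the pinball loss forces Q0.1 ≤ Q0.5 ≤ Q0.9 at every step, so it can be worth checking for crossings before plotting intervals. Here is a minimal sketch on made-up numbers (the array is illustrative, not model output). Sorting along the quantile axis is one simple repair.

```python
import numpy as np

# (samples, horizon, quantiles) with quantile order [0.1, 0.5, 0.9]
toy_pred = np.array([
    [[4.0, 5.0, 6.0],
     [5.5, 5.0, 6.5]],   # Q0.1 > Q0.5 here: a crossing
])

# A forecast step is "crossed" if any adjacent pair of quantiles decreases.
crossed = np.any(np.diff(toy_pred, axis=-1) < 0, axis=-1)
print(f"Crossing rate: {crossed.mean():.0%}")

# Re-order each forecast step so the quantiles are non-decreasing.
toy_fixed = np.sort(toy_pred, axis=-1)
```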

In [None]:
# Plot predictions for multiple samples
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for idx, ax in enumerate(axes.flat):
    sample_idx = idx * 10  # Sample every 10th prediction
    
    pred = predictions[sample_idx].cpu().numpy()
    target = targets[sample_idx].cpu().numpy()
    
    time_steps = np.arange(len(target))
    
    # Plot actual values
    ax.plot(time_steps, target, 'k-', label='Actual', linewidth=2, marker='o', markersize=4)
    
    # Plot median prediction
    median_idx = config.quantiles.index(0.5)
    ax.plot(time_steps, pred[:, median_idx], 'b-', label='Predicted (Q0.5)', linewidth=2, marker='s', markersize=4)
    
    # Plot prediction interval
    ax.fill_between(
        time_steps,
        pred[:, 0],  # Q0.1
        pred[:, -1],  # Q0.9
        alpha=0.3,
        color='blue',
        label='80% Prediction Interval'
    )
    
    ax.set_xlabel('Forecast Step')
    ax.set_ylabel('Value')
    ax.set_title(f'Sample {sample_idx}')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

plt.suptitle('TFT Predictions with Uncertainty Intervals', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Calculate prediction accuracy metrics
pred_np = predictions.cpu().numpy()
target_np = targets.cpu().numpy()

# Median predictions
median_idx = config.quantiles.index(0.5)
median_pred = pred_np[:, :, median_idx]

# Calculate metrics
mae = np.mean(np.abs(median_pred - target_np))
rmse = np.sqrt(np.mean((median_pred - target_np) ** 2))
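# Guard the denominator by magnitude so near-zero targets of either sign cannot divide by zero
mape = np.mean(np.abs(median_pred - target_np) / np.maximum(np.abs(target_np), 1e-8)) * 100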

# Coverage: What % of actuals fall within prediction interval
lower = pred_np[:, :, 0]  # Q0.1
upper = pred_np[:, :, -1]  # Q0.9
coverage = np.mean((target_np >= lower) & (target_np <= upper)) * 100

print("Prediction Quality Metrics:")
print("=" * 40)
print(f"  MAE:  {mae:.4f}")
print(f"  RMSE: {rmse:.4f}")
print(f"  MAPE: {mape:.2f}%")
print(f"\n  80% Interval Coverage: {coverage:.1f}% (target: 80%)")

## 10. Model Interpretability

One of TFT's key advantages is **built-in interpretability**. Let's examine:
1. **Attention Weights**: Which past time steps influence predictions?
2. **Variable Importance**: Which features matter most?

### 10.1 Attention Weights Analysis

Attention weights show how the model weighs different historical time steps when making predictions.
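
What makes these weights readable is how TFT combines its heads. In the formulation of the original paper, each head has its own query and key projections but all heads share one value projection, so the per-head attention matrices can be averaged into a single matrix. The NumPy sketch below illustrates that idea with random weights and a causal mask. All names and shapes are assumptions, the final output projection is omitted, and this is not the `tft` package's implementation.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def interpretable_attention(x, w_q, w_k, w_v, causal=True):
    """x: (T, d); w_q, w_k: (heads, d, d_head); w_v: (d, d_head), shared."""
    q = np.einsum("td,hde->hte", x, w_q)
    k = np.einsum("td,hde->hte", x, w_k)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(w_q.shape[-1])  # (heads, T, T)
    if causal:
        # Position t may only attend to positions <= t.
        future = np.triu(np.ones(scores.shape[-2:], dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    per_head = softmax(scores, axis=-1)
    mean_attn = per_head.mean(axis=0)          # one (T, T) matrix for all heads
    return mean_attn @ (x @ w_v), mean_attn


rng = np.random.default_rng(0)
T, d, heads, d_head = 6, 8, 4, 2
attn_out, attn_weights = interpretable_attention(
    rng.normal(size=(T, d)),
    rng.normal(size=(heads, d, d_head)),
    rng.normal(size=(heads, d, d_head)),
    rng.normal(size=(d, d_head)),
)
print(attn_weights.shape, attn_weights.sum(axis=-1))
```

Each row of the averaged matrix still sums to one, so it can be read directly as a distribution over source positions.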

In [None]:
# Check if attention weights are available
if 'attention_weights' in results:
    attention = results['attention_weights']
    print(f"Attention weights shape: {attention.shape}")
    print(f"  - Samples: {attention.shape[0]}")
    print(f"  - Attention heads: {attention.shape[1]}")
    print(f"  - Target positions: {attention.shape[2]}")
    print(f"  - Source positions: {attention.shape[3]}")
else:
    print("Attention weights not available in results")
    attention = None

In [None]:
if attention is not None:
    # Average attention across heads for visualization
    avg_attention = attention.mean(dim=1)  # Average over heads
    
    # Plot attention heatmap for a sample
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Single sample attention heatmap
    sample_idx = 0
    sample_attention = avg_attention[sample_idx].cpu().numpy()
    
    ax1 = axes[0]
    im1 = ax1.imshow(sample_attention, cmap='viridis', aspect='auto')
    ax1.set_xlabel('Source Position (Historical)')
    ax1.set_ylabel('Target Position (Forecast)')
    ax1.set_title(f'Attention Heatmap (Sample {sample_idx})')
    plt.colorbar(im1, ax=ax1, label='Attention Weight')
    
    # Average attention pattern across all samples
    mean_attention = avg_attention.mean(dim=0).cpu().numpy()  # Average over samples
    
    ax2 = axes[1]
    im2 = ax2.imshow(mean_attention, cmap='viridis', aspect='auto')
    ax2.set_xlabel('Source Position (Historical)')
    ax2.set_ylabel('Target Position (Forecast)')
    ax2.set_title('Average Attention Pattern (All Samples)')
    plt.colorbar(im2, ax=ax2, label='Attention Weight')
    
    plt.tight_layout()
    plt.show()

In [None]:
if attention is not None:
    # Analyze which historical positions get the most attention
    # avg_attention is already averaged over heads; now average over samples and target positions
    temporal_importance = avg_attention.mean(dim=(0, 1)).cpu().numpy()
    
    fig, ax = plt.subplots(figsize=(12, 4))
    
    positions = np.arange(len(temporal_importance))
    colors = plt.cm.viridis(temporal_importance / temporal_importance.max())
    
    ax.bar(positions, temporal_importance, color=colors)
    ax.set_xlabel('Historical Time Step')
    ax.set_ylabel('Average Attention Weight')
    ax.set_title('Temporal Attention Distribution: Which Past Time Steps Matter Most?')
    ax.grid(True, alpha=0.3, axis='y')
    
    # Mark the most important positions
    top_k = 5
    top_indices = np.argsort(temporal_importance)[-top_k:][::-1]
    for idx in top_indices:
        ax.annotate(f't-{len(temporal_importance)-idx-1}', 
                   xy=(idx, temporal_importance[idx]),
                   xytext=(idx, temporal_importance[idx] + 0.01),
                   ha='center', fontsize=9)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nTop {top_k} most attended historical positions:")
    for i, idx in enumerate(top_indices):
        print(f"  {i+1}. Position {idx} (t-{len(temporal_importance)-idx-1}): {temporal_importance[idx]:.4f}")

### 10.2 Variable Importance Analysis

TFT's Variable Selection Networks learn which input features are most important for predictions.
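
Conceptually, a variable selection network produces one softmax weight per input variable at each time step and uses those weights to form a weighted sum of the per-variable embeddings. The sketch below shows only that weighting step, with random numbers standing in for learned quantities (in TFT both the logits and the embeddings come from GRNs). Names and shapes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_steps, n_vars, d_model = 50, 5, 16
var_names = ["temperature", "hour_sin", "hour_cos", "day_of_week", "target"]

# Random stand-ins for learned quantities
embeddings = rng.normal(size=(n_steps, n_vars, d_model))  # per-variable features
logits = rng.normal(size=(n_steps, n_vars))               # selection scores

# Softmax over the variable axis: weights are >= 0 and sum to 1 per step.
exp_logits = np.exp(logits - logits.max(axis=-1, keepdims=True))
selection_weights = exp_logits / exp_logits.sum(axis=-1, keepdims=True)

# Weighted combination of variables: one d_model vector per time step.
selected = (selection_weights[..., None] * embeddings).sum(axis=1)

# Averaging the weights over time gives a per-variable importance score.
toy_importance = dict(zip(var_names, selection_weights.mean(axis=0).round(3)))
print(selected.shape, toy_importance)
```

Because the weights sum to one across variables, the averaged importances also sum to one, so they are best read as relative shares rather than absolute effects.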

In [None]:
# Check available variable selection outputs
print("Available keys in results:")
for key in results.keys():
    if isinstance(results[key], torch.Tensor):
        print(f"  {key}: {results[key].shape}")
    else:
        print(f"  {key}: {type(results[key])}")

In [None]:
# Variable importance visualization
# We'll compute importance based on the encoder variable selection

if 'encoder_variable_selection' in results:
    encoder_weights = results['encoder_variable_selection']
    
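    # Encoder variable names (observed + known future + target).
    # This ordering is an assumption; verify it against the model's input ordering before trusting the labels.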
    encoder_vars = config.observed_only + config.known_future + [config.target]
    
    # Average importance
    avg_importance = encoder_weights.mean(dim=(0, 1)).cpu().numpy()
    
    # Create importance DataFrame
    importance_df = pd.DataFrame({
        'Variable': encoder_vars[:len(avg_importance)],
        'Importance': avg_importance[:len(encoder_vars)]
    }).sort_values('Importance', ascending=True)
    
    fig, ax = plt.subplots(figsize=(10, 6))
    
    colors = plt.cm.RdYlGn(importance_df['Importance'] / importance_df['Importance'].max())
    ax.barh(importance_df['Variable'], importance_df['Importance'], color=colors)
    ax.set_xlabel('Selection Weight (Importance)')
    ax.set_ylabel('Variable')
    ax.set_title('Encoder Variable Importance\n(Learned by Variable Selection Network)')
    ax.grid(True, alpha=0.3, axis='x')
    
    plt.tight_layout()
    plt.show()
    
    print("\nVariable Importance Ranking:")
    for i, (_, row) in enumerate(importance_df.iloc[::-1].iterrows()):
        print(f"  {i+1}. {row['Variable']}: {row['Importance']:.4f}")
else:
    print("Variable selection weights not available in results")

## 11. Summary

### What We Learned:

1. **TFT Architecture**: A sophisticated model combining LSTMs, attention, and gating mechanisms for interpretable forecasting

2. **Data Handling**: TFT gracefully handles three types of inputs:
   - Static features (time-invariant)
   - Known future features (available at forecast time)
   - Observed features (historical only)

3. **Probabilistic Forecasting**: Quantile predictions provide uncertainty estimates, not just point forecasts

4. **Interpretability**: Built-in mechanisms reveal:
   - Which historical time steps influence predictions (attention)
   - Which features are most important (variable selection)

### Key Takeaways:

| Aspect | TFT Advantage |
|--------|---------------|
| **Accuracy** | Strong results on the benchmarks reported in the original TFT paper |
| **Uncertainty** | Quantile outputs for prediction intervals |
| **Interpretability** | Attention + variable importance |
| **Flexibility** | Handles mixed input types |
| **Scalability** | Works with multiple time series |

In [None]:
# Final summary
print("=" * 60)
print("TFT TUTORIAL COMPLETE")
print("=" * 60)
print(f"\nModel Performance:")
print(f"  Final Training Loss:   {trainer.history['train_loss'][-1]:.6f}")
print(f"  Final Validation Loss: {trainer.history['val_loss'][-1]:.6f}")
print(f"  Best Validation Loss:  {trainer.best_val_loss:.6f}")
print(f"  Test Loss:             {test_metrics['loss']:.6f}")
print(f"\nModel Configuration:")
print(f"  Parameters:      {total_params:,}")
print(f"  Encoder Length:  {config.encoder_length}")
print(f"  Decoder Length:  {config.decoder_length}")
print(f"  Hidden Size:     {config.hidden_size}")
print(f"\nFiles Saved:")
print(f"  Model Checkpoint: tft_best_model.pth")