# Temporal Fusion Transformer (TFT) Tutorial
## Electricity Load Forecasting with ElectricityLoadDiagrams20112014 Dataset

This notebook provides a comprehensive tutorial on using the **Temporal Fusion Transformer (TFT)** for multi-horizon time series forecasting. We'll use real-world electricity consumption data to demonstrate the complete workflow.

### What is TFT?

The Temporal Fusion Transformer is a state-of-the-art deep learning architecture designed for:
- **Multi-horizon forecasting**: Predict multiple future time steps at once
- **Interpretability**: Understand which features and time steps matter most
- **Mixed inputs**: Handle static, known future, and observed-only features

### Key Components

```
┌─────────────────────────────────────────────────────────────────┐
│                    TFT Architecture                             │
├─────────────────────────────────────────────────────────────────┤
│  Static Features ──► Variable Selection ──► Context Vectors    │
│                                                    │            │
│  Historical Data ──► Variable Selection ──► LSTM Encoder ─┐    │
│                                                            │    │
│  Future Known ────► Variable Selection ──► LSTM Decoder ──┼──► │
│                                                            │    │
│                              Interpretable Multi-Head ◄────┘    │
│                                   Attention                     │
│                                      │                          │
│                              Quantile Outputs                   │
│                          (Probabilistic Forecasts)              │
└─────────────────────────────────────────────────────────────────┘
```
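Each variable selection block in the diagram gives every input variable a softmax weight, then combines the per-variable representations using those weights. In the TFT paper, both the weights and the per-variable transforms come from gated residual networks that also receive a static context vector. This NumPy sketch replaces those networks with plain linear maps so it shows only the weighting step at a single time step. `w_select` and `w_vars` are hypothetical random parameters.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

def variable_selection(inputs, w_select, w_vars):
    """Simplified variable selection for one time step.

    inputs:   (n_vars, d)          one embedding per input variable
    w_select: (n_vars * d, n_vars) maps flattened inputs to selection logits
    w_vars:   (n_vars, d, d)       one linear transform per variable
    Returns the combined (d,) vector and the (n_vars,) selection weights.
    """
    # One weight per variable; the weights are non-negative and sum to 1
    weights = softmax(inputs.reshape(-1) @ w_select)
    # Transform each variable's embedding with its own matrix
    transformed = np.einsum("vd,vde->ve", inputs, w_vars)
    # Weighted sum over variables
    combined = (weights[:, None] * transformed).sum(axis=0)
    return combined, weights

rng = np.random.default_rng(0)
n_vars, d = 4, 3
x = rng.normal(size=(n_vars, d))
combined, weights = variable_selection(
    x, rng.normal(size=(n_vars * d, n_vars)), rng.normal(size=(n_vars, d, d))
)
print(weights.round(3), combined.shape)
```

The learned weights are what TFT reports as variable importance.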

### Dataset

We'll use the **ElectricityLoadDiagrams20112014** dataset from UCI ML Repository:
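- 370 clients' electricity consumption
- 15-minute intervals from 2011-2014
- ~140,000 time steps per client
- Values are in kW for each 15-minute interval (divide by 4 to get kWh)
- Clients created after 2011 have zero consumption before their start date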
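Each reading is a 15-minute average power in kW. The energy in one interval is therefore the reading times 0.25 h. The sketch below converts synthetic readings to kWh and then aggregates them to hourly totals. The column `MT_demo` is made up.

```python
import pandas as pd

# Synthetic stand-in for the raw file: two hours of 15-minute kW readings
idx = pd.date_range("2012-06-01 00:00", periods=8, freq="15min")
df_kw = pd.DataFrame({"MT_demo": [4.0, 4.0, 8.0, 8.0, 2.0, 2.0, 2.0, 2.0]}, index=idx)

# Each reading covers 0.25 h, so energy per interval in kWh = kW / 4
df_kwh = df_kw / 4

# Summing the interval energies gives kWh per hour
hourly_kwh = df_kwh.resample("1h").sum()
print(hourly_kwh)
```

Summing the raw kW values without dividing by 4 inflates the totals by a constant factor of 4. This changes the units but not the shape of the patterns.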

## 1. Setup and Imports

In [None]:
import sys
sys.path.append('..')

import os
import zipfile
import urllib.request
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# TFT imports
from tft.models import TemporalFusionTransformer
from tft.data import create_dataloaders
from tft.training import TFTTrainer, EarlyStopping, ModelCheckpoint
from tft.utils import TFTConfig, get_device, print_device_info

# Set plot style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Check device
print_device_info()
device = get_device('auto')

## 2. Download and Load Data

First, let's download the electricity dataset and explore its structure.

In [None]:
# Configuration
DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00321/LD2011_2014.txt.zip"
DATA_DIR = Path("data")
DATA_FILE = DATA_DIR / "LD2011_2014.txt"

def download_data():
    """Download the electricity dataset."""
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    zip_path = DATA_DIR / "LD2011_2014.txt.zip"
    
    if DATA_FILE.exists():
        print(f"Data already exists at {DATA_FILE}")
        return
    
    print(f"Downloading from {DATA_URL}...")
    urllib.request.urlretrieve(DATA_URL, zip_path)
    
    print("Extracting...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(DATA_DIR)
    zip_path.unlink()
    print("Done!")

download_data()

In [None]:
# Load the raw data
print("Loading data...")
df_raw = pd.read_csv(
    DATA_FILE,
    sep=';',
    decimal=',',
    index_col=0,
    parse_dates=True,
)

print(f"Raw data shape: {df_raw.shape}")
print(f"Date range: {df_raw.index.min()} to {df_raw.index.max()}")
print(f"Number of clients: {len(df_raw.columns)}")
df_raw.head()

## 3. Data Exploration

Let's visualize the electricity consumption patterns to understand our data better.

In [None]:
# Select a few clients for visualization
sample_clients = df_raw.columns[:5].tolist()
df_sample = df_raw[sample_clients]

# Resample to daily totals for readability (raw values are kW per 15 min, so /4 gives kWh)
df_daily = df_sample.resample('1D').sum() / 4

# Plot time series
fig, axes = plt.subplots(2, 1, figsize=(14, 8))

# Full time series
ax1 = axes[0]
for client in sample_clients:
    ax1.plot(df_daily.index, df_daily[client], label=client, alpha=0.7)
ax1.set_title('Daily Electricity Consumption (Full Period)', fontsize=14)
ax1.set_xlabel('Date')
ax1.set_ylabel('Consumption (kWh)')
ax1.legend(loc='upper right')

# Zoom into 1 month
ax2 = axes[1]
df_month = df_sample.loc['2012-06'].resample('1h').sum() / 4  # hourly kWh; .loc selects rows by partial date
for client in sample_clients[:3]:
    ax2.plot(df_month.index, df_month[client], label=client, alpha=0.8)
ax2.set_title('Hourly Consumption (June 2012)', fontsize=14)
ax2.set_xlabel('Date')
ax2.set_ylabel('Consumption (kWh)')
ax2.legend(loc='upper right')

plt.tight_layout()
plt.show()

In [None]:
# Analyze seasonality patterns
client = sample_clients[0]
df_hourly = df_raw[client].resample('1h').sum()

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Hourly pattern (average by hour of day)
hourly_pattern = df_hourly.groupby(df_hourly.index.hour).mean()
axes[0].bar(hourly_pattern.index, hourly_pattern.values, color='steelblue', alpha=0.8)
axes[0].set_title(f'Average Consumption by Hour\n({client})', fontsize=12)
axes[0].set_xlabel('Hour of Day')
axes[0].set_ylabel('Avg Consumption')
axes[0].set_xticks(range(0, 24, 3))

# Daily pattern (average by day of week)
daily_pattern = df_hourly.groupby(df_hourly.index.dayofweek).mean()
days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
axes[1].bar(days, daily_pattern.values, color='coral', alpha=0.8)
axes[1].set_title(f'Average Consumption by Day of Week\n({client})', fontsize=12)
axes[1].set_xlabel('Day of Week')
axes[1].set_ylabel('Avg Consumption')

# Monthly pattern
monthly_pattern = df_hourly.groupby(df_hourly.index.month).mean()
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
axes[2].bar(months, monthly_pattern.values, color='seagreen', alpha=0.8)
axes[2].set_title(f'Average Consumption by Month\n({client})', fontsize=12)
axes[2].set_xlabel('Month')
axes[2].set_ylabel('Avg Consumption')
axes[2].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

print("\nPatterns to look for (their strength varies by client):")
print("- Daily cycle: consumption typically peaks during daytime hours")
print("- Weekly cycle: many clients consume less on weekends")
print("- Seasonal cycle: monthly averages shift with heating/cooling needs")

## 4. Data Preprocessing

Now let's prepare the data for TFT. We need to:
1. Resample to a manageable frequency
2. Convert to long format
3. Add time-based features
4. Split into train/val/test sets
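Steps 2 and 3 can be shown on a toy frame with made-up values for two clients. The sketch names the index `datetime` before calling `reset_index()`, which gives the melted frame a predictable timestamp column. It then adds the per-client `time_idx` counter that later cells rely on.

```python
import pandas as pd

# Hypothetical wide frame: one column per client, timestamps as the index
idx = pd.date_range("2012-01-01", periods=3, freq="4h")
wide = pd.DataFrame({"MT_001": [1.0, 2.0, 3.0], "MT_002": [10.0, 20.0, 30.0]}, index=idx)

# Name the index explicitly so reset_index() creates a 'datetime' column
wide.index.name = "datetime"

# Wide -> long: one row per (timestamp, client)
long_df = wide.reset_index().melt(
    id_vars=["datetime"], var_name="client_id", value_name="consumption"
)

# Integer time index 0, 1, 2, ... within each client
long_df = long_df.sort_values(["client_id", "datetime"]).reset_index(drop=True)
long_df["time_idx"] = long_df.groupby("client_id").cumcount()
print(long_df)
```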

In [None]:
# Configuration for this tutorial (reduced for faster training)
NUM_CLIENTS = 3        # Number of clients to use
RESAMPLE_FREQ = '4h'   # Resample frequency
ENCODER_LENGTH = 42    # ~1 week at 4-hour intervals
DECODER_LENGTH = 6     # ~24 hours ahead

print(f"Configuration:")
print(f"  - Number of clients: {NUM_CLIENTS}")
print(f"  - Resample frequency: {RESAMPLE_FREQ}")
print(f"  - Encoder length: {ENCODER_LENGTH} steps (~{ENCODER_LENGTH * 4} hours)")
print(f"  - Decoder length: {DECODER_LENGTH} steps (~{DECODER_LENGTH * 4} hours)")

In [None]:
def preprocess_data(df_raw, num_clients, resample_freq):
    """
    Preprocess the electricity data for TFT.
    
    Steps:
    1. Select subset of clients
    2. Resample to target frequency
    3. Convert to long format
    4. Add time features
    """
    # Select clients
    selected_clients = df_raw.columns[:num_clients].tolist()
    df = df_raw[selected_clients].copy()
    
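    # Resample (sum of the 15-minute readings in each bin; sum() never yields NaN)
    df = df.resample(resample_freq).sum()
    # Drop periods where every selected client reads zero (e.g. before any of them started);
    # dropping rows mid-series would leave gaps that time_idx does not reflect
    df = df[(df != 0).any(axis=1)]
    
    print(f"After resampling: {df.shape}")
    
    # Convert to long format (name the index so reset_index() yields a 'datetime' column)
    df.index.name = 'datetime'
    df = df.reset_index()
    df = df.melt(
        id_vars=['datetime'],
        var_name='client_id',
        value_name='consumption',
    )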
    
    # Create numeric client ID
    client_mapping = {c: i for i, c in enumerate(selected_clients)}
    df['client_id_num'] = df['client_id'].map(client_mapping)
    
    return df, selected_clients

df, clients = preprocess_data(df_raw, NUM_CLIENTS, RESAMPLE_FREQ)
print(f"\nTotal samples: {len(df)}")
print(f"Clients: {clients}")
df.head(10)

In [None]:
def add_time_features(df):
    """
    Add time-based features for the TFT model.
    
    TFT uses three types of features:
    1. Static: Don't change over time (e.g., client_id)
    2. Known future: Known at prediction time (e.g., hour, day of week)
    3. Observed only: Only known historically (e.g., past consumption)
    """
    df = df.copy()
    
    # Extract time components
    df['hour'] = df['datetime'].dt.hour
    df['day_of_week'] = df['datetime'].dt.dayofweek
    df['day_of_month'] = df['datetime'].dt.day
    df['month'] = df['datetime'].dt.month
    df['week_of_year'] = df['datetime'].dt.isocalendar().week.astype(int)
    
    # Cyclical encoding (important for periodic features!)
    # This helps the model understand that hour 23 is close to hour 0
    df['hour_sin'] = np.sin(2 * np.pi * df['hour'] / 24)
    df['hour_cos'] = np.cos(2 * np.pi * df['hour'] / 24)
    df['day_of_week_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
    df['day_of_week_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)
    df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
    df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
    
    # Create continuous time index per client
    df = df.sort_values(['client_id', 'datetime'])
    df['time_idx'] = df.groupby('client_id').cumcount()
    
    # Binary features
    df['is_weekend'] = (df['day_of_week'] >= 5).astype(float)
    
    return df

df = add_time_features(df)
print(f"Features: {list(df.columns)}")
df.head()

In [None]:
# Visualize cyclical encoding
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Hour encoding
hours = np.arange(24)
hour_sin = np.sin(2 * np.pi * hours / 24)
hour_cos = np.cos(2 * np.pi * hours / 24)

ax1 = axes[0]
ax1.plot(hours, hour_sin, 'b-', label='sin(hour)', linewidth=2)
ax1.plot(hours, hour_cos, 'r-', label='cos(hour)', linewidth=2)
ax1.set_xlabel('Hour of Day')
ax1.set_ylabel('Encoded Value')
ax1.set_title('Cyclical Hour Encoding')
ax1.legend()
ax1.set_xticks(range(0, 24, 3))

# 2D representation
ax2 = axes[1]
scatter = ax2.scatter(hour_cos, hour_sin, c=hours, cmap='viridis', s=100)
for i, h in enumerate(hours):
    ax2.annotate(str(h), (hour_cos[i]+0.05, hour_sin[i]+0.05), fontsize=8)
ax2.set_xlabel('cos(hour)')
ax2.set_ylabel('sin(hour)')
ax2.set_title('Hours in 2D Space\n(Notice: hour 23 is close to hour 0!)')
ax2.set_aspect('equal')
plt.colorbar(scatter, label='Hour')

plt.tight_layout()
plt.show()

print("\nWhy cyclical encoding?")
print("- Hour 23 should be close to hour 0 (they're adjacent in time)")
print("- With simple encoding (0-23), the model sees them as far apart")
print("- Cyclical encoding preserves the circular nature of time")
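The claim that cyclical encoding keeps hour 23 next to hour 0 can be checked numerically. On the unit circle, every pair of adjacent hours is the same Euclidean distance apart (a 15° chord). The raw integers, by contrast, put 23 and 0 twenty-three units apart. A small NumPy sketch:

```python
import numpy as np

def encode_hour(h):
    """Map hour-of-day onto the unit circle as (sin, cos) pairs."""
    angle = 2 * np.pi * np.asarray(h) / 24
    return np.stack([np.sin(angle), np.cos(angle)], axis=-1)

def dist(a, b):
    """Euclidean distance between two encoded points."""
    return float(np.linalg.norm(a - b))

# Raw integer encoding: 23 and 0 look far apart, 0 and 1 look close
print(abs(23 - 0), abs(1 - 0))  # 23 1

# Cyclical encoding: both pairs are one hour apart, so the distances match
e = encode_hour(np.arange(24))
print(round(dist(e[23], e[0]), 4), round(dist(e[0], e[1]), 4))  # 0.2611 0.2611
```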

## 5. Train/Validation/Test Split

For time series, we split chronologically to avoid data leakage.
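Splitting on `time_idx` alone has one side effect. If the dataloader needs `ENCODER_LENGTH` steps of history before the first forecast, the first `ENCODER_LENGTH` steps of the validation and test sets can serve only as context. A common variant prepends that context from the previous split. The sketch below assumes `create_dataloaders` only builds windows that have a full encoder, so the prepended rows are used as inputs and never as targets. `split_with_context` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

def split_with_context(df, encoder_length, train_frac=0.7, val_frac=0.15):
    """Chronological split in which val/test start with `encoder_length`
    steps of history borrowed from the preceding split."""
    times = np.sort(df["time_idx"].unique())
    n = len(times)
    train_end = int(n * train_frac)
    val_end = int(n * (train_frac + val_frac))
    if train_end < encoder_length:
        raise ValueError("training split is shorter than encoder_length")

    t = df["time_idx"]
    train = df[t < times[train_end]]
    # Validation: last `encoder_length` training steps as context + validation steps
    val = df[(t >= times[train_end - encoder_length]) & (t < times[val_end])]
    # Test: last `encoder_length` validation steps as context + test steps
    test = df[t >= times[val_end - encoder_length]]
    return train.copy(), val.copy(), test.copy()

toy = pd.DataFrame({"time_idx": np.arange(100), "consumption": np.arange(100.0)})
tr, va, te = split_with_context(toy, encoder_length=10)
print(tr["time_idx"].agg(["min", "max"]).tolist(),
      va["time_idx"].agg(["min", "max"]).tolist(),
      te["time_idx"].agg(["min", "max"]).tolist())
```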

In [None]:
def split_data(df, train_frac=0.7, val_frac=0.15):
    """
    Split data chronologically into train/val/test sets.
    
    Important: For time series, we MUST split by time, not randomly!
    Random splits would leak future information into training.
    """
    unique_times = sorted(df['time_idx'].unique())
    n_times = len(unique_times)
    
    train_end = int(n_times * train_frac)
    val_end = int(n_times * (train_frac + val_frac))
    
    train_times = set(unique_times[:train_end])
    val_times = set(unique_times[train_end:val_end])
    test_times = set(unique_times[val_end:])
    
    train_df = df[df['time_idx'].isin(train_times)].copy()
    val_df = df[df['time_idx'].isin(val_times)].copy()
    test_df = df[df['time_idx'].isin(test_times)].copy()
    
    return train_df, val_df, test_df

train_df, val_df, test_df = split_data(df)

print(f"Train: {len(train_df):,} samples ({len(train_df)/len(df)*100:.1f}%)")
print(f"Val:   {len(val_df):,} samples ({len(val_df)/len(df)*100:.1f}%)")
print(f"Test:  {len(test_df):,} samples ({len(test_df)/len(df)*100:.1f}%)")

In [None]:
# Visualize the split
fig, ax = plt.subplots(figsize=(14, 4))

client = clients[0]
train_client = train_df[train_df['client_id'] == client]
val_client = val_df[val_df['client_id'] == client]
test_client = test_df[test_df['client_id'] == client]

ax.plot(train_client['datetime'], train_client['consumption'], 
        'b-', alpha=0.7, label='Train')
ax.plot(val_client['datetime'], val_client['consumption'], 
        'orange', alpha=0.7, label='Validation')
ax.plot(test_client['datetime'], test_client['consumption'], 
        'g-', alpha=0.7, label='Test')

ax.set_xlabel('Date')
ax.set_ylabel('Consumption')
ax.set_title(f'Train/Val/Test Split for {client}')
ax.legend(loc='upper right')

plt.tight_layout()
plt.show()

## 6. Model Configuration

Now let's configure the TFT model. The key decisions are:
1. **Input types**: Which features are static, known future, or observed only
2. **Sequence lengths**: How much history and how far to forecast
3. **Architecture**: Hidden size, attention heads, LSTM layers
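Deriving the sequence lengths from durations avoids hard-coding the 4-hour factor. `steps_for` is a hypothetical helper, not part of the `tft` package. It uses `pd.Timedelta` arithmetic and raises an error for durations that are not a whole number of steps.

```python
import pandas as pd

def steps_for(duration: str, freq: str) -> int:
    """Number of `freq`-sized steps that exactly span `duration`."""
    d, f = pd.Timedelta(duration), pd.Timedelta(freq)
    if d % f != pd.Timedelta(0):
        raise ValueError(f"{duration} is not a whole number of {freq} steps")
    return int(d // f)

print(steps_for("7D", "4h"))  # 42: one week of history at 4-hour steps
print(steps_for("1D", "4h"))  # 6: one day ahead
print(steps_for("7D", "1h"))  # 168: one week at hourly resolution
```

With `RESAMPLE_FREQ = '4h'`, one week of history and one day ahead give the `ENCODER_LENGTH = 42` and `DECODER_LENGTH = 6` used above. At hourly resolution, the same horizons need 168 and 24 steps.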

In [None]:
# Create TFT configuration
config = TFTConfig(
    # === Input Features ===
    # Static: Don't change over time (one value per time series)
    static_variables=['client_id_num'],
    
    # Known future: We know these values for future time steps
    # (time-based features are always known in advance)
    known_future=[
        'hour_sin', 'hour_cos',
        'day_of_week_sin', 'day_of_week_cos',
        'month_sin', 'month_cos',
        'is_weekend',
    ],
    
    # Observed only: Only available in historical data
    # (the target itself is observed only)
    observed_only=[],
    
    # Target variable to forecast
    target='consumption',
    
    # === Sequence Configuration ===
    encoder_length=ENCODER_LENGTH,  # How much history to use
    decoder_length=DECODER_LENGTH,  # How far to forecast
    
    # === Architecture ===
    hidden_size=64,      # Size of hidden layers
    num_heads=4,         # Number of attention heads
    num_lstm_layers=1,   # Number of LSTM layers
    dropout=0.1,         # Dropout rate
    
    # === Quantiles for Probabilistic Forecasting ===
    # Output prediction intervals, not just point estimates
    quantiles=[0.1, 0.5, 0.9],  # 10th, 50th (median), 90th percentile
    
    # === Training ===
    batch_size=32,
    learning_rate=1e-3,
    max_epochs=3,  # Reduced for tutorial
    gradient_clip_val=1.0,
)

print("Model Configuration:")
print(f"  Static variables:    {config.static_variables}")
print(f"  Known future:        {config.known_future}")
print(f"  Target:              {config.target}")
print(f"  Encoder length:      {config.encoder_length}")
print(f"  Decoder length:      {config.decoder_length}")
print(f"  Hidden size:         {config.hidden_size}")
print(f"  Attention heads:     {config.num_heads}")
print(f"  Quantiles:           {config.quantiles}")

### Understanding Input Types

```
┌─────────────────────────────────────────────────────────────────┐
│                         Time Axis                               │
│   ◄──────── Encoder (History) ────────►│◄─ Decoder (Future) ─►│
├─────────────────────────────────────────┼───────────────────────┤
│                                         │                       │
│ Static Variables (e.g., client_id)      │                       │
│ ════════════════════════════════════════════════════════════    │
│ Same value across all time steps        │                       │
│                                         │                       │
│ Known Future (e.g., hour, day_of_week)  │                       │
│ ────────────────────────────────────────────────────────────    │
│ Available for both history and future   │                       │
│                                         │                       │
│ Observed Only (e.g., past consumption)  │                       │
│ ────────────────────────────────────────│ ? ? ? ? ? ?           │
│ Only available in history               │ Unknown in future     │
│                                         │                       │
└─────────────────────────────────────────┴───────────────────────┘
```
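The diagram maps directly onto array slicing for a single training window. This is a conceptual sketch, not necessarily how `tft.data.create_dataloaders` builds its batches. `make_window` and its argument layout are hypothetical, and static variables (one value per series) would be passed separately.

```python
import numpy as np

def make_window(known, observed, target, start, enc_len, dec_len):
    """Slice one window starting at `start`.

    known:    (T, n_known)  features available for past and future steps
    observed: (T, n_obs)    features available only historically
    target:   (T,)          series to forecast
    """
    enc = slice(start, start + enc_len)
    dec = slice(start + enc_len, start + enc_len + dec_len)
    # Encoder sees everything from the past, including past target values
    encoder_inputs = np.concatenate(
        [known[enc], observed[enc], target[enc, None]], axis=1
    )
    # Decoder sees only known-future features; the target is what we predict
    decoder_inputs = known[dec]
    decoder_target = target[dec]
    return encoder_inputs, decoder_inputs, decoder_target

T = 60
rng = np.random.default_rng(0)
known = rng.normal(size=(T, 7))          # e.g. the 7 known_future features
observed = np.zeros((T, 0))              # no extra observed-only features here
target = np.arange(T, dtype=float)
x_enc, x_dec, y = make_window(known, observed, target, start=0, enc_len=42, dec_len=6)
print(x_enc.shape, x_dec.shape, y.shape)  # (42, 8) (6, 7) (6,)
```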

## 7. Create DataLoaders and Model

In [None]:
# Create data loaders
print("Creating data loaders...")
train_loader, val_loader, test_loader = create_dataloaders(
    train_data=train_df,
    val_data=val_df,
    test_data=test_df,
    config=config,
    batch_size=config.batch_size,
    num_workers=0,
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches:   {len(val_loader)}")
print(f"Test batches:  {len(test_loader)}")

In [None]:
# Examine a batch
batch = next(iter(train_loader))
print("Batch contents:")
for key, value in batch.items():
    if isinstance(value, torch.Tensor):
        print(f"  {key}: shape={value.shape}, dtype={value.dtype}")
    else:
        print(f"  {key}: {type(value)}")

In [None]:
# Create the TFT model
print("Creating TFT model...")
model = TemporalFusionTransformer(config)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Statistics:")
print(f"  Total parameters:     {num_params:,}")
print(f"  Trainable parameters: {num_trainable:,}")

In [None]:
# Print model architecture
print("\nModel Architecture:")
print(model)

## 8. Training

Now let's train the model. We use:
- **Quantile Loss**: For probabilistic forecasting
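- **Early Stopping**: To stop once validation loss stops improving (with `max_epochs=3` and `patience=5` it cannot trigger in this short run; it matters when you train longer)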
- **Model Checkpointing**: To save the best model
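The quantile (pinball) loss penalizes under-forecasts and over-forecasts asymmetrically. For quantile q, a forecast below the actual value costs q per unit of error, and a forecast above it costs 1 - q per unit. The TFT paper trains on this loss summed over quantiles and averaged over samples and horizons. The NumPy sketch below also averages over quantiles, which only rescales the loss by 1/Q. The `tft` package's own reduction may differ.

```python
import numpy as np

def quantile_loss(y_true, y_pred, quantiles):
    """Pinball loss averaged over samples, horizons and quantiles.

    y_true: (N, H)     actual values
    y_pred: (N, H, Q)  one forecast per quantile
    """
    q = np.asarray(quantiles, dtype=float).reshape(1, 1, -1)
    err = y_true[..., None] - y_pred            # positive = under-forecast
    loss = np.maximum(q * err, (q - 1) * err)
    return float(loss.mean())

y = np.array([[10.0]])
print(round(quantile_loss(y, np.array([[[8.0, 10.0, 12.0]]]), [0.1, 0.5, 0.9]), 4))  # 0.1333
# For the 90th percentile, forecasting too low costs more than forecasting too high
print(quantile_loss(y, np.array([[[8.0]]]), [0.9]), round(quantile_loss(y, np.array([[[12.0]]]), [0.9]), 4))  # 1.8 0.2
```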

In [None]:
# Create trainer
trainer = TFTTrainer(
    model=model,
    config=config,
    device=device,
)

# Setup callbacks
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=5,
        min_delta=1e-4,
        verbose=True,
    ),
    ModelCheckpoint(
        filepath='tft_electricity_best.pth',
        monitor='val_loss',
        save_best_only=True,
        verbose=True,
    ),
]

print("Starting training...")
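The patience rule configured by `EarlyStopping(monitor='val_loss', patience=5, min_delta=1e-4)` usually works as sketched below. `SimpleEarlyStopping` is a hypothetical stand-alone class, not the `tft.training` implementation. Details of the real callback, such as whether it restores the best weights, may differ.

```python
class SimpleEarlyStopping:
    """Stop when val_loss has not improved by more than `min_delta`
    for `patience` consecutive epochs."""

    def __init__(self, patience=5, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True to stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss      # meaningful improvement: reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1      # no meaningful improvement this epoch
        return self.bad_epochs >= self.patience

stopper = SimpleEarlyStopping(patience=2, min_delta=1e-4)
for epoch, loss in enumerate([1.0, 0.8, 0.79995, 0.8, 0.81], start=1):
    if stopper.step(loss):
        print(f"stop after epoch {epoch}")  # stop after epoch 4
        break
```

The 0.79995 epoch counts as "no improvement" because it beats the best loss by less than `min_delta`.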

In [None]:
# Train the model
trainer.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=config.max_epochs,
    callbacks=callbacks,
)

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss curves
ax1 = axes[0]
epochs = range(1, len(trainer.history['train_loss']) + 1)
ax1.plot(epochs, trainer.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
ax1.plot(epochs, trainer.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Learning rate (if available)
ax2 = axes[1]
if 'learning_rate' in trainer.history:
    ax2.plot(epochs, trainer.history['learning_rate'], 'g-', linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Learning Rate')
    ax2.set_title('Learning Rate Schedule')
else:
    # Generalization gap: val - train (positive means val loss exceeds train loss)
    gap = np.array(trainer.history['val_loss']) - np.array(trainer.history['train_loss'])
    colors = ['red' if g > 0 else 'green' for g in gap]
    ax2.bar(epochs, gap, color=colors, alpha=0.7)
    ax2.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Val - Train Loss')
    ax2.set_title('Generalization Gap\n(A growing positive gap suggests overfitting)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal Results:")
print(f"  Train Loss: {trainer.history['train_loss'][-1]:.6f}")
print(f"  Val Loss:   {trainer.history['val_loss'][-1]:.6f}")

## 9. Evaluation

Let's evaluate the model on the test set and analyze the results.
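The metrics plotted below (`mae`, `rmse`, `r2`, `smape`) come from `trainer.validate`, whose exact formulas are not shown here. For reference, this NumPy sketch uses common definitions. SMAPE has several conventions; the one assumed here is bounded between 0% and 200%, and it counts terms with a zero denominator as 0.

```python
import numpy as np

def point_metrics(y_true, y_pred):
    """MAE, RMSE, R² and SMAPE (in percent) over flattened arrays."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    err = y_pred - y_true
    mae = np.abs(err).mean()
    rmse = np.sqrt((err ** 2).mean())
    # R² = 1 - residual sum of squares / total sum of squares
    r2 = 1 - (err ** 2).sum() / ((y_true - y_true.mean()) ** 2).sum()
    # SMAPE: 2|e| / (|y| + |ŷ|), averaged and expressed in percent
    denom = np.abs(y_true) + np.abs(y_pred)
    ratio = np.divide(2 * np.abs(err), denom, out=np.zeros_like(denom), where=denom > 0)
    smape = 100 * ratio.mean()
    return {"mae": mae, "rmse": rmse, "r2": r2, "smape": smape}

print(point_metrics([1, 2, 3, 4], [1, 2, 3, 5]))
```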

In [None]:
# Evaluate on test set
test_metrics = trainer.validate(test_loader)

print("Test Set Metrics:")
print("=" * 50)
for metric, value in test_metrics.items():
    print(f"  {metric:25s}: {value:.6f}")

In [None]:
# Visualize key metrics
metrics_to_plot = {
    'MAE': test_metrics.get('mae', 0),
    'RMSE': test_metrics.get('rmse', 0),
    'R²': test_metrics.get('r2', 0),
    'SMAPE': test_metrics.get('smape', 0),
}

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Error metrics
ax1 = axes[0]
error_metrics = ['MAE', 'RMSE', 'SMAPE']
error_values = [metrics_to_plot[m] for m in error_metrics]
bars = ax1.bar(error_metrics, error_values, color=['steelblue', 'coral', 'seagreen'])
ax1.set_ylabel('Value')
ax1.set_title('Error Metrics (Lower is Better)')
for bar, val in zip(bars, error_values):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
             f'{val:.3f}', ha='center', va='bottom', fontsize=10)

# Quantile coverage
ax2 = axes[1]
q_names = ['10% Quantile', '50% Quantile', '90% Quantile']
q_coverage = [
    test_metrics.get('q10_coverage', 0),
    test_metrics.get('q50_coverage', 0),
    test_metrics.get('q90_coverage', 0),
]
expected = [0.1, 0.5, 0.9]

x = np.arange(len(q_names))
width = 0.35
bars1 = ax2.bar(x - width/2, q_coverage, width, label='Actual', color='steelblue')
bars2 = ax2.bar(x + width/2, expected, width, label='Expected', color='lightgray')
ax2.set_ylabel('Coverage')
ax2.set_title('Quantile Coverage\n(Actual vs Expected)')
ax2.set_xticks(x)
ax2.set_xticklabels(q_names)
ax2.legend()
ax2.set_ylim(0, 1)

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print(f"  R² = {metrics_to_plot['R²']:.3f} means the model explains {metrics_to_plot['R²']*100:.1f}% of variance")
print(f"  SMAPE = {metrics_to_plot['SMAPE']:.1f}% (symmetric mean absolute percentage error)")

## 10. Predictions Visualization

Let's visualize the model's predictions with uncertainty intervals.
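With quantiles `[0.1, 0.5, 0.9]`, the band between the first and last quantile is a nominal 80% prediction interval. For a well-calibrated model, roughly 80% of actual values should fall inside it. `interval_coverage` is a hypothetical helper for this check. It assumes `predictions` has shape `(samples, horizon, quantiles)` with quantiles in ascending order and has been converted to NumPy (e.g. with `.numpy()`).

```python
import numpy as np

def interval_coverage(predictions, targets, lower=0, upper=2):
    """Fraction of actuals inside [q_lower, q_upper] and the mean interval width.

    predictions: (N, H, Q) quantile forecasts, ordered by quantile level
    targets:     (N, H)
    """
    lo = predictions[..., lower]
    hi = predictions[..., upper]
    inside = (targets >= lo) & (targets <= hi)
    return float(inside.mean()), float((hi - lo).mean())

p = np.array([[[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]])
t = np.array([[1.0, 5.0]])
print(interval_coverage(p, t))  # (0.5, 2.0)
```

After the next cell runs, `interval_coverage(predictions.numpy(), targets.numpy())` gives the empirical coverage to compare against the nominal 0.8. A narrower interval with the same coverage is better.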

In [None]:
# Generate predictions
results = trainer.predict(
    test_loader,
    return_attention=True,
    return_variable_selection=True,
)

predictions = results['predictions']  # Shape: (samples, horizon, quantiles)
targets = results['targets']          # Shape: (samples, horizon)

print(f"Predictions shape: {predictions.shape}")
print(f"Targets shape: {targets.shape}")

In [None]:
def plot_forecast(predictions, targets, sample_idx, config):
    """
    Plot a single forecast with prediction intervals.
    """
    fig, ax = plt.subplots(figsize=(12, 5))
    
    horizon = predictions.shape[1]
    time_steps = np.arange(horizon)
    
    pred = predictions[sample_idx].numpy()
    target = targets[sample_idx].numpy()
    
    # Plot prediction intervals
    q10 = pred[:, 0]  # 10th percentile
    q50 = pred[:, 1]  # 50th percentile (median)
    q90 = pred[:, 2]  # 90th percentile
    
    # 80% prediction interval between the 10th and 90th percentiles
    ax.fill_between(time_steps, q10, q90, alpha=0.3, color='blue', 
                    label='80% Prediction Interval')
    
    # Median prediction
    ax.plot(time_steps, q50, 'b-', linewidth=2, label='Median Prediction (q50)')
    
    # Actual values
    ax.plot(time_steps, target, 'ro-', linewidth=2, markersize=8, 
            label='Actual Values')
    
    ax.set_xlabel('Forecast Horizon (steps)')
    ax.set_ylabel('Consumption')
    ax.set_title(f'TFT Forecast - Sample {sample_idx}')
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3)
    
    # Add step labels
    ax.set_xticks(time_steps)
    ax.set_xticklabels([f'+{(i+1)*4}h' for i in time_steps])
    
    plt.tight_layout()
    return fig

# Plot a few samples
for idx in [0, 100, 500]:
    if idx < len(predictions):
        plot_forecast(predictions, targets, idx, config)
        plt.show()

In [None]:
# Plot multiple forecasts together
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

sample_indices = [0, 50, 150, 300]

for ax, idx in zip(axes, sample_indices):
    if idx >= len(predictions):
        continue
        
    horizon = predictions.shape[1]
    time_steps = np.arange(horizon)
    
    pred = predictions[idx].numpy()
    target = targets[idx].numpy()
    
    q10, q50, q90 = pred[:, 0], pred[:, 1], pred[:, 2]
    
    ax.fill_between(time_steps, q10, q90, alpha=0.3, color='blue')
    ax.plot(time_steps, q50, 'b-', linewidth=2, label='Prediction')
    ax.plot(time_steps, target, 'ro-', linewidth=2, markersize=6, label='Actual')
    
    ax.set_xlabel('Horizon')
    ax.set_ylabel('Consumption')
    ax.set_title(f'Sample {idx}')
    ax.legend(loc='best', fontsize=8)
    ax.grid(True, alpha=0.3)

plt.suptitle('TFT Forecasts Across Different Samples', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Scatter plot: Predicted vs Actual
fig, ax = plt.subplots(figsize=(8, 8))

# Use median predictions
pred_median = predictions[:, :, 1].flatten().numpy()
actual = targets.flatten().numpy()

# Sample for visualization (too many points otherwise)
n_sample = min(5000, len(pred_median))
rng = np.random.default_rng(0)  # fixed seed so the plotted subset is reproducible
indices = rng.choice(len(pred_median), n_sample, replace=False)

ax.scatter(actual[indices], pred_median[indices], alpha=0.3, s=10)

# Perfect prediction line
min_val = min(actual.min(), pred_median.min())
max_val = max(actual.max(), pred_median.max())
ax.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect Prediction')

ax.set_xlabel('Actual Consumption')
ax.set_ylabel('Predicted Consumption')
ax.set_title('Predicted vs Actual Values')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_aspect('equal')

plt.tight_layout()
plt.show()

# Calculate correlation
correlation = np.corrcoef(actual, pred_median)[0, 1]
print(f"Correlation between predicted and actual: {correlation:.4f}")

## 11. Error Analysis

In [None]:
# Error by forecast horizon
errors_by_horizon = []
for h in range(predictions.shape[1]):
    pred_h = predictions[:, h, 1].numpy()  # Median prediction
    actual_h = targets[:, h].numpy()
    mae_h = np.abs(pred_h - actual_h).mean()
    errors_by_horizon.append(mae_h)

fig, ax = plt.subplots(figsize=(10, 5))
horizons = range(1, len(errors_by_horizon) + 1)
ax.bar(horizons, errors_by_horizon, color='steelblue', alpha=0.8)
ax.set_xlabel('Forecast Horizon (steps)')
ax.set_ylabel('Mean Absolute Error')
ax.set_title('Prediction Error by Forecast Horizon\n(Error typically increases with horizon)')
ax.set_xticks(horizons)
ax.set_xticklabels([f'+{h*4}h' for h in horizons])

plt.tight_layout()
plt.show()

print("\nObservation: Error tends to increase with forecast horizon")
print("This is expected - predicting further ahead is harder!")

In [None]:
# Error distribution
errors = (predictions[:, :, 1] - targets).flatten().numpy()

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Histogram
ax1 = axes[0]
ax1.hist(errors, bins=50, density=True, alpha=0.7, color='steelblue')
ax1.axvline(x=0, color='red', linestyle='--', linewidth=2, label='Zero Error')
ax1.axvline(x=errors.mean(), color='green', linestyle='--', linewidth=2, 
            label=f'Mean: {errors.mean():.3f}')
ax1.set_xlabel('Prediction Error')
ax1.set_ylabel('Density')
ax1.set_title('Error Distribution')
ax1.legend()

# Q-Q plot
from scipy import stats
ax2 = axes[1]
stats.probplot(errors, dist="norm", plot=ax2)
ax2.set_title('Q-Q Plot (Normality Check)')

plt.tight_layout()
plt.show()

print(f"Error Statistics:")
print(f"  Mean:     {errors.mean():.4f}")
print(f"  Std:      {errors.std():.4f}")
print(f"  Skewness: {stats.skew(errors):.4f}")

## 12. Model Interpretability

One of TFT's key strengths is interpretability. Let's examine:
1. **Variable Selection Weights**: Which features matter most?
2. **Attention Weights**: Which time steps are most important?
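TFT's variable selection networks output softmax weights over the input variables at every time step. Averaging these weights gives a global importance ranking. The `(samples, time, n_vars)` layout and the pairing with a names list are assumptions, so check what `results` actually contains (the next cell lists the available outputs) before applying this hypothetical helper.

```python
import numpy as np

def rank_variables(weights, names):
    """Average selection weights over all axes except the last and rank them.

    weights: (..., n_vars) non-negative weights that sum to 1 over the last axis
    names:   list of n_vars variable names, in the same order as the last axis
    """
    w = np.asarray(weights, dtype=float)
    if w.shape[-1] != len(names):
        raise ValueError("last axis of weights must match len(names)")
    mean_w = w.reshape(-1, w.shape[-1]).mean(axis=0)
    order = np.argsort(mean_w)[::-1]          # most important first
    return [(names[i], float(mean_w[i])) for i in order]

# Toy weights: 2 samples x 2 time steps x 3 variables
toy = np.tile([0.2, 0.5, 0.3], (2, 2, 1))
for name, w in rank_variables(toy, ["hour_sin", "hour_cos", "is_weekend"]):
    print(f"{name:12s} {w:.2f}")
```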

In [None]:
# Check what interpretability outputs are available
print("Available interpretation outputs:")
for key in results.keys():
    if 'attention' in key.lower() or 'variable' in key.lower() or 'weight' in key.lower():
        value = results[key]
        if isinstance(value, torch.Tensor):
            print(f"  {key}: shape={value.shape}")
        elif isinstance(value, dict):
            print(f"  {key}: dict with keys {list(value.keys())}")

In [None]:
# Plot attention weights if available
if 'attention_weights' in results:
    attention = results['attention_weights']
    
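    # Average attention across samples (and any remaining leading axes, e.g. heads),
    # leaving a 2D (query steps x key steps) matrix for imshow
    if isinstance(attention, torch.Tensor):
        avg_attention = attention.float().mean(dim=0).cpu().numpy()
        while avg_attention.ndim > 2:
            avg_attention = avg_attention.mean(axis=0)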
        
        fig, ax = plt.subplots(figsize=(12, 6))
        
        im = ax.imshow(avg_attention, aspect='auto', cmap='Blues')
        ax.set_xlabel('Encoder Time Steps (History)')
        ax.set_ylabel('Decoder Time Steps (Future)')
        ax.set_title('Average Attention Weights\n(Which historical steps matter for each future prediction?)')
        plt.colorbar(im, label='Attention Weight')
        
        plt.tight_layout()
        plt.show()
        
        print("\nInterpretation:")
        print("  Brighter colors = higher attention = more important time steps")
        print("  The model learns which historical patterns are most predictive")
else:
    print("Attention weights not found under the key 'attention_weights'")
    print("Check the keys listed in the previous cell; your version of trainer.predict may use a different name")

## 13. Summary and Next Steps

### What We Learned

1. **TFT Architecture**: Combines LSTMs, attention, and variable selection for interpretable forecasting
2. **Feature Types**: Static, known future, and observed-only features require different handling
3. **Probabilistic Forecasting**: Quantile outputs provide uncertainty estimates
4. **Cyclical Encoding**: Important for time-based features

### Performance Summary

In [None]:
# Final summary
print("=" * 60)
print("                    TRAINING SUMMARY")
print("=" * 60)
print(f"\nDataset: ElectricityLoadDiagrams20112014")
print(f"Clients: {NUM_CLIENTS}")
print(f"Resampling: {RESAMPLE_FREQ}")
print(f"\nModel Configuration:")
print(f"  Encoder length: {ENCODER_LENGTH} steps")
print(f"  Decoder length: {DECODER_LENGTH} steps")
print(f"  Hidden size: {config.hidden_size}")
print(f"  Parameters: {num_params:,}")
print(f"\nTraining Results:")
print(f"  Final train loss: {trainer.history['train_loss'][-1]:.6f}")
print(f"  Final val loss:   {trainer.history['val_loss'][-1]:.6f}")
print(f"\nTest Metrics:")
print(f"  MAE:   {test_metrics.get('mae', 0):.4f}")
print(f"  RMSE:  {test_metrics.get('rmse', 0):.4f}")
print(f"  R²:    {test_metrics.get('r2', 0):.4f}")
print(f"  SMAPE: {test_metrics.get('smape', 0):.2f}%")
print("=" * 60)

### Improving Results

To get better performance, try:

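1. **More data**: Increase `NUM_CLIENTS` and use hourly resampling, scaling the sequence lengths to match (e.g. `ENCODER_LENGTH = 168` and `DECODER_LENGTH = 24` for one week of history and one day ahead)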
2. **Larger model**: Increase `hidden_size` to 128 or 256
3. **More epochs**: Train for 20-50 epochs with early stopping
4. **Additional features**: Add weather data, holidays, etc.
5. **GPU acceleration**: Use CUDA for faster training

### References

- [Original TFT Paper](https://arxiv.org/abs/1912.09363): "Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting"
- [UCI Dataset](https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014)

In [None]:
print("\nTutorial complete! Happy forecasting!")