# Multi-Target Forecasting with Temporal Fusion Transformer
This notebook demonstrates how to use PyTorch Forecasting's **Temporal Fusion Transformer** to jointly forecast five variables measured by a (synthetic) ocean buoy — sea temperature, salinity, wind speed, wind direction, and tide height — for the next 24 hours, using the previous 24 hours of hourly readings as input.

In [None]:
!pip install pytorch-lightning pytorch-forecasting --quiet

In [None]:
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import lightning.pytorch as pl
# NOTE: pytorch-forecasting >= 1.0 models subclass `lightning.pytorch.LightningModule`,
# which `pytorch_lightning.Trainer` rejects; train with `pl.Trainer` instead.
from pytorch_lightning import Trainer
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline
from pytorch_forecasting.data import GroupNormalizer, MultiNormalizer
from pytorch_forecasting.metrics import MultiLoss, QuantileLoss

In [None]:
# Generate synthetic IoT ocean data: 7 days of hourly readings from a single buoy.
# Every signal is a sinusoid plus Gaussian noise drawn from NumPy's global RNG, so
# re-seeding and keeping the draw order below reproduces exactly the same data.
SEED = 42
START_TIME = datetime(2025, 7, 4)
hours = 24 * 7
timestamps = [START_TIME + timedelta(hours=i) for i in range(hours)]
np.random.seed(SEED)

# Features (pH is generated for completeness but is not one of the forecasting targets)
temperature = 26 + 2 * np.sin(np.linspace(0, 3*np.pi, hours)) + np.random.normal(0, 0.3, hours)
salinity = 35 + 0.5 * np.sin(np.linspace(0, 2*np.pi, hours)) + np.random.normal(0, 0.1, hours)
pH = 8 + 0.1 * np.cos(np.linspace(0, 4*np.pi, hours)) + np.random.normal(0, 0.05, hours)
wind_speed = 3 + np.abs(np.sin(np.linspace(0, 3*np.pi, hours))) + np.random.normal(0, 0.2, hours)
# Wind direction is circular (degrees wrapped into [0, 360)); the wrap-around makes it
# discontinuous when modelled as an ordinary real-valued series.
wind_direction = (np.linspace(0, 360, hours) + np.random.normal(0, 10, hours)) % 360
tide_height = 1 + 0.5 * np.sin(np.linspace(0, 4*np.pi, hours)) + np.random.normal(0, 0.05, hours)

data = pd.DataFrame({
    'timestamp': timestamps,
    'group': 'buoy_1',
    'temperature': temperature,
    'salinity': salinity,
    'pH': pH,
    'wind_speed': wind_speed,
    'wind_direction': wind_direction,
    'tide_height': tide_height,
    # Consecutive integer time index for TimeSeriesDataSet (one step = one hour).
    'time_idx': np.arange(hours),
})

# Invariants the dataset cell relies on: one row per hour, gap-free index, no missing values.
assert len(data) == hours
assert data['time_idx'].is_unique and data['time_idx'].is_monotonic_increasing
assert not data.isna().any().any()

data.head()

In [None]:
# Prepare TimeSeriesDataSet: each sample maps 24 past hours (encoder) to 24 future hours (decoder).
max_encoder_length = 24  # past 24 hours
max_prediction_length = 24  # future 24 hours
TARGETS = ['temperature', 'salinity', 'wind_speed', 'wind_direction', 'tide_height']

# Hold out the last `max_prediction_length` hours: no training window reaches into them.
training_cutoff = data['time_idx'].max() - max_prediction_length
training = TimeSeriesDataSet(
    data[data.time_idx <= training_cutoff],
    time_idx='time_idx',
    target=TARGETS,
    group_ids=['group'],
    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,
    min_prediction_length=max_prediction_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=['group'],
    time_varying_known_reals=['time_idx'],
    # Past values of every target are also fed to the encoder as observed covariates.
    time_varying_unknown_reals=TARGETS,
    # A list of targets needs one normalizer per target wrapped in a MultiNormalizer;
    # a single GroupNormalizer is rejected for multi-target datasets.
    # NOTE(review): the softplus transform assumes strictly positive values; wind_direction
    # can approach 0 degrees, where the inverse softplus diverges - consider sin/cos encoding.
    target_normalizer=MultiNormalizer([
        GroupNormalizer(groups=['group'], transformation='softplus', center=False)
        for _ in TARGETS
    ]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# Create dataloaders. `predict=True` keeps only the last window of each group, i.e. the
# held-out 24 hours; with a single buoy this validation set contains one sample.
train_dataloader = training.to_dataloader(train=True, batch_size=16, num_workers=0)
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
val_dataloader = validation.to_dataloader(train=False, batch_size=16, num_workers=0)

In [None]:
# Define Temporal Fusion Transformer.
# Multi-target models need one loss per target (MultiLoss) and one output head per target;
# each head emits one value per quantile at every forecast step.
QUANTILES = [0.1, 0.5, 0.9]
n_targets = len(training.target_names)

# Seed Python, NumPy and torch so weight initialisation and batch shuffling are reproducible.
pl.seed_everything(42)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=[len(QUANTILES)] * n_targets,  # per-target head sizes, not a single int
    loss=MultiLoss([QuantileLoss(quantiles=QUANTILES) for _ in range(n_targets)]),
)

# Train with the `lightning.pytorch` Trainer: pytorch-forecasting >= 1.0 models are
# `lightning.pytorch` modules, which the legacy `pytorch_lightning.Trainer` rejects.
trainer = pl.Trainer(max_epochs=10, gradient_clip_val=0.1)
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
# Predict the held-out 24 hours and plot the temperature forecast.
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

# pytorch-forecasting >= 1.0 returns a single `Prediction` object instead of a tuple;
# the raw network output and the collated inputs are its `.output` and `.x` fields.
prediction_result = best_tft.predict(val_dataloader, mode='raw', return_x=True)
raw_predictions, x = prediction_result.output, prediction_result.x

# For a multi-target model, both `prediction` and `decoder_target` are lists with one
# entry per target (in dataset order): select the target first, then the sample.
target_idx = validation.target_names.index('temperature')
quantile_preds = raw_predictions['prediction'][target_idx].detach().cpu()  # expected (samples, steps, quantiles)
actuals = x['decoder_target'][target_idx].detach().cpu()  # expected (samples, steps)

# Use the quantile closest to 0.5 as the point forecast; the outer quantiles form a band.
quantiles = best_tft.loss.metrics[target_idx].quantiles
median_idx = min(range(len(quantiles)), key=lambda i: abs(quantiles[i] - 0.5))
steps = np.arange(quantile_preds.shape[1])

# The `predict=True` validation set can hold fewer than 3 windows (one per group).
n_examples = min(3, quantile_preds.shape[0])
for sample_idx in range(n_examples):
    fig, ax = plt.subplots()
    ax.plot(steps, actuals[sample_idx].numpy(), label='Actual Temperature')
    ax.plot(steps, quantile_preds[sample_idx, :, median_idx].numpy(), label='Predicted Temperature')
    ax.fill_between(
        steps,
        quantile_preds[sample_idx, :, 0].numpy(),
        quantile_preds[sample_idx, :, -1].numpy(),
        alpha=0.3,
        label=f'{quantiles[0]:.0%}-{quantiles[-1]:.0%} quantile band',
    )
    ax.set(title=f'Example {sample_idx + 1}', xlabel='Hour', ylabel='Temperature')
    ax.legend()
    plt.show()
    plt.close(fig)  # free figure memory when plotting in a loop