<a href="https://colab.research.google.com/github/fishee82oo/nfs-oil-price-prediction/blob/main/GDELT_Model_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Temporal GNN Training for Oil Price Prediction

This notebook contains the full training workflow for predicting oil price changes from temporal graph snapshots. All dataset utilities, model definitions, and training helpers live here so you can run end-to-end experiments from a single place.

## Environment setup

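Install the required Python packages before running the rest of the notebook. The command below expects `requirements.txt` to be present in the current working directory; if you opened the notebook on its own (for example via the Colab badge), clone the repository or upload the file first.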

In [None]:
!pip install -q -r requirements.txt

In [None]:
from __future__ import annotations

import json
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable, List, Optional, Sequence, Tuple

import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, Subset
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GraphNorm, SAGEConv, global_mean_pool

## Dataset utilities

`GraphSnapshotDataset` loads snapshot metadata, enforces a configurable training window (5–10 years by default, via `min_years`/`max_years`), and materialises `torch_geometric.data.Data` objects on demand.

In [None]:
"""Dataset helpers for temporal graph snapshots with oil price labels."""

@dataclass(frozen=True)
class SnapshotMetadata:
    """Immutable record describing a single graph snapshot.

    ``frozen=True`` prevents re-assigning the fields after construction; note
    that the ``extra`` dictionary itself is still a mutable object.

    Attributes
    ----------
    timestamp:
        Datetime of the snapshot. Used to build temporal splits. When produced
        by :class:`GraphSnapshotDataset` the value is timezone-aware (the CSV
        column is parsed with ``utc=True``).
    graph_path:
        Path on disk pointing to a serialized :class:`~torch_geometric.data.Data`
        object. ``torch.load`` will be used to materialise the object when the
        dataset item is requested.
    label:
        Target value representing the observed oil price change aligned with the
        graph snapshot. Regression targets are expected to be floats; however, it
        is up to the caller to ensure the underlying data matches the learning
        task.
    extra:
        Optional additional metadata loaded from the CSV file (every column
        other than the time, graph-path, and label columns). Retained for
        downstream logging/debugging without affecting training.
    """

    timestamp: datetime
    graph_path: Path
    label: float
    extra: dict


class GraphSnapshotDataset(Dataset):
    """Graph dataset aligned with oil price changes.

    The dataset expects a metadata CSV with at least three columns:

    ``timestamp``
        Datetime (ISO 8601, YYYY-MM-DD, or any pandas-compatible format)
        indicating when the snapshot was recorded.
    ``graph_path``
        Path to a serialized :class:`torch_geometric.data.Data` object. Paths can
        be absolute or relative to the CSV location.
    ``price_change``
        Numeric value representing the observed change in oil price for the
        snapshot (target label).

    Additional columns are preserved inside :class:`SnapshotMetadata.extra` for
    logging or advanced filtering.

    Parameters
    ----------
    metadata_csv:
        Location of the metadata CSV file.
    start_year, end_year:
        Inclusive calendar-year window. When one or both are omitted, the
        missing bound is derived from the available data and ``max_years``.
    min_years, max_years:
        Minimum and maximum number of calendar years the window may span.
    label_column, time_column, graph_path_column:
        Names of the CSV columns holding the target, timestamp, and graph path.
    transform:
        Optional callable applied to every ``Data`` object in ``__getitem__``.

    Raises
    ------
    FileNotFoundError
        If the CSV or any referenced snapshot file does not exist.
    ValueError
        If required columns are missing, timestamps cannot be parsed, the year
        window is invalid, or fewer than two snapshots remain after filtering.
    """

    def __init__(
        self,
        metadata_csv: Path | str,
        *,
        start_year: Optional[int] = None,
        end_year: Optional[int] = None,
        min_years: int = 5,
        max_years: int = 10,
        label_column: str = "price_change",
        time_column: str = "timestamp",
        graph_path_column: str = "graph_path",
        transform: Optional[Callable[[Data], Data]] = None,
    ) -> None:
        super().__init__()
        self._metadata_path = Path(metadata_csv)
        if not self._metadata_path.exists():
            raise FileNotFoundError(f"Metadata CSV not found: {self._metadata_path}")

        self.transform = transform
        self._label_column = label_column
        self._time_column = time_column
        self._graph_column = graph_path_column

        df = pd.read_csv(self._metadata_path)
        if time_column not in df.columns:
            raise ValueError(f"Missing time column '{time_column}' in metadata")
        if graph_path_column not in df.columns:
            raise ValueError(
                f"Missing graph path column '{graph_path_column}' in metadata"
            )
        if label_column not in df.columns:
            raise ValueError(f"Missing label column '{label_column}' in metadata")
        if df.empty:
            # Without this guard the year computation below fails with an
            # opaque "cannot convert float NaN to integer" error.
            raise ValueError("Metadata CSV contains no snapshot rows.")

        df[time_column] = pd.to_datetime(df[time_column], utc=True, errors="coerce")
        if df[time_column].isna().any():
            raise ValueError("Some timestamps could not be parsed. Ensure ISO-8601 format.")

        # Chronological order is relied upon by ``split_indices`` and by the
        # ``start_timestamp`` / ``end_timestamp`` properties.
        df = df.sort_values(time_column)
        available_years = df[time_column].dt.year
        min_available_year = int(available_years.min())
        max_available_year = int(available_years.max())

        # Fill in whichever bound(s) the caller left out, anchored on the data
        # and never exceeding ``max_years``.
        if start_year is None and end_year is None:
            end_year = max_available_year
            start_year = max(min_available_year, end_year - max_years + 1)
        elif start_year is None:
            start_year = max(min_available_year, int(end_year) - max_years + 1)
        elif end_year is None:
            end_year = min(max_available_year, int(start_year) + max_years - 1)

        if start_year > end_year:
            raise ValueError(
                f"Invalid year range start={start_year}, end={end_year}. start must be <= end."
            )

        year_span = end_year - start_year + 1
        if year_span < min_years:
            raise ValueError(
                "Year range must span at least "
                f"{min_years} years; received {year_span} years ({start_year}-{end_year})."
            )
        if year_span > max_years:
            raise ValueError(
                "Year range must be within the configured maximum. "
                f"Got {year_span} years, maximum allowed is {max_years}."
            )

        mask = (df[time_column].dt.year >= start_year) & (df[time_column].dt.year <= end_year)
        df = df.loc[mask]
        if df.empty:
            raise ValueError(
                "No snapshots found for the requested time window. "
                f"Available years: {min_available_year}-{max_available_year}."
            )

        base_dir = self._metadata_path.parent
        reserved_columns = {time_column, graph_path_column, label_column}
        extra_columns = [column for column in df.columns if column not in reserved_columns]

        metadata: List[SnapshotMetadata] = []
        # Rows are read as dictionaries keyed by the real column names.
        # ``itertuples`` renames columns that are not valid Python identifiers
        # (spaces, leading underscores, keywords), which breaks name lookups.
        for record in df.to_dict(orient="records"):
            timestamp = record[time_column]
            label = float(record[label_column])
            # An absolute path on the right-hand side replaces ``base_dir``.
            graph_path = (base_dir / record[graph_path_column]).resolve()
            if not graph_path.exists():
                raise FileNotFoundError(
                    f"Graph snapshot referenced in metadata does not exist: {graph_path}"
                )

            metadata.append(
                SnapshotMetadata(
                    timestamp=timestamp.to_pydatetime(),
                    graph_path=graph_path,
                    label=label,
                    extra={column: record[column] for column in extra_columns},
                )
            )

        if len(metadata) < 2:
            raise ValueError("Dataset must contain at least two snapshots for training.")

        self._metadata = tuple(metadata)

    @property
    def start_timestamp(self) -> datetime:
        """Timestamp of the earliest snapshot inside the selected window."""
        return self._metadata[0].timestamp

    @property
    def end_timestamp(self) -> datetime:
        """Timestamp of the latest snapshot inside the selected window."""
        return self._metadata[-1].timestamp

    @property
    def metadata(self) -> Sequence[SnapshotMetadata]:
        """Chronologically ordered snapshot records."""
        return self._metadata

    def __len__(self) -> int:  # type: ignore[override]
        return len(self._metadata)

    def __getitem__(self, index: int) -> Data:  # type: ignore[override]
        """Load the snapshot at ``index`` and attach label and metadata.

        Raises
        ------
        TypeError
            If the file does not deserialize to a ``torch_geometric`` ``Data``.
        """
        meta = self._metadata[index]
        # Snapshots are pickled ``Data`` objects, which the ``weights_only=True``
        # default of recent PyTorch releases refuses to load.
        # SECURITY: this unpickles arbitrary objects, so only point the
        # metadata CSV at snapshot files from a trusted source.
        data = torch.load(meta.graph_path, weights_only=False)
        if not isinstance(data, Data):
            raise TypeError(
                f"Snapshot at {meta.graph_path} is not a torch_geometric.data.Data object"
            )

        data.y = torch.tensor([meta.label], dtype=torch.float)
        data.snapshot_timestamp = meta.timestamp  # type: ignore[attr-defined]
        data.snapshot_metadata = meta.extra  # type: ignore[attr-defined]

        if self.transform is not None:
            data = self.transform(data)

        return data

    def split_indices(
        self,
        train_ratio: float = 0.7,
        val_ratio: float = 0.15,
        *,
        shuffle_within_year: bool = False,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[List[int], List[int], List[int]]:
        """Create chronological splits for train/val/test.

        Parameters
        ----------
        train_ratio, val_ratio:
            Fractions of the dataset assigned to training and validation sets.
            The remainder is assigned to the test set. Ratios refer to the number
            of snapshots (not years) and will preserve chronological ordering.
        shuffle_within_year:
            When enabled, shuffles the order of snapshots captured in the same
            calendar year before applying the split. This can help reduce bias if
            multiple snapshots exist per year. Years themselves stay in
            chronological order.
        generator:
            Optional random generator for deterministic shuffling. When omitted,
            PyTorch's global random state is used.

        Returns
        -------
        Three lists of dataset indices: ``(train, val, test)``.

        Raises
        ------
        ValueError
            If the ratios are invalid or any split would end up empty.
        """

        if train_ratio <= 0 or val_ratio < 0:
            raise ValueError("train_ratio must be > 0 and val_ratio must be >= 0")
        if train_ratio + val_ratio >= 1:
            raise ValueError("train_ratio + val_ratio must be < 1 to leave room for test")

        ordered_indices: List[int] = list(range(len(self._metadata)))
        if shuffle_within_year:
            by_year: dict[int, List[int]] = {}
            for idx, meta in enumerate(self._metadata):
                by_year.setdefault(meta.timestamp.year, []).append(idx)
            ordered_indices = []
            for year in sorted(by_year):
                indices = by_year[year]
                # ``generator=None`` falls back to the global RNG, so the
                # shuffle happens whether or not a generator was supplied.
                perm = torch.randperm(len(indices), generator=generator).tolist()
                ordered_indices.extend(indices[i] for i in perm)

        n_total = len(ordered_indices)
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)
        n_test = n_total - n_train - n_val
        if n_train == 0 or n_val == 0 or n_test == 0:
            raise ValueError(
                "Not enough samples to satisfy split ratios. "
                f"Dataset size {n_total}, train {n_train}, val {n_val}, test {n_test}."
            )

        train_idx = ordered_indices[:n_train]
        val_idx = ordered_indices[n_train : n_train + n_val]
        test_idx = ordered_indices[n_train + n_val :]
        return train_idx, val_idx, test_idx


## Metrics and training helpers

Regression metrics, early stopping, and the epoch-level training loop live in this section.

In [None]:
@dataclass
class RegressionMetrics:
    """Summary regression scores for one evaluation pass."""

    # Mean absolute error between targets and predictions.
    mae: float
    # Root mean squared error between targets and predictions.
    rmse: float
    # Coefficient of determination; ``compute_regression_metrics`` reports 0.0
    # when the targets have zero variance.
    r2: float


@dataclass
class TrainingHistoryEntry:
    """Per-epoch record produced by ``train_model``."""

    # Epoch number, starting at 1.
    epoch: int
    # Mean of the per-batch training losses for the epoch.
    train_loss: float
    # Mean of the per-batch validation losses for the epoch.
    val_loss: float
    # Regression metrics computed over all validation predictions of the epoch.
    val_metrics: RegressionMetrics


class EarlyStopping:
    """Simple early stopping helper that tracks the lowest score seen.

    Parameters
    ----------
    patience:
        Number of consecutive non-improving ``step`` calls tolerated before
        stopping is requested. Must be greater than zero.
    min_delta:
        Minimum decrease of the score (relative to the best so far) that counts
        as an improvement.

    Raises
    ------
    ValueError
        If ``patience`` is not positive.
    """

    def __init__(self, patience: int = 20, min_delta: float = 0.0) -> None:
        if patience <= 0:
            raise ValueError("patience must be > 0")
        self.patience = patience
        self.min_delta = min_delta
        self._best_score: Optional[float] = None
        self._best_state_dict: Optional[dict[str, torch.Tensor]] = None
        self._counter = 0

    @property
    def best_state_dict(self) -> Optional[dict[str, torch.Tensor]]:
        """CPU copy of the model weights at the best score, or ``None``."""
        return self._best_state_dict

    def step(self, score: float, model: nn.Module) -> bool:
        """Record ``score`` for the current epoch; return ``True`` to stop.

        Lower scores are better. On improvement the model's ``state_dict`` is
        snapshotted (detached, on CPU) and the patience counter resets.
        A NaN score never counts as an improvement: otherwise a diverged first
        epoch would become an unbeatable "best" and its weights would be the
        ones restored after training.
        """
        # NaN is the only float value that is not equal to itself.
        score_is_nan = score != score
        improved = not score_is_nan and (
            self._best_score is None or score < self._best_score - self.min_delta
        )
        if improved:
            self._best_score = score
            self._counter = 0
            self._best_state_dict = {
                k: v.detach().cpu().clone() for k, v in model.state_dict().items()
            }
            return False

        self._counter += 1
        return self._counter >= self.patience


def compute_regression_metrics(
    y_true: torch.Tensor, y_pred: torch.Tensor
) -> RegressionMetrics:
    """Compute MAE, RMSE and R² between targets and predictions.

    Both tensors are flattened to one dimension before comparison, so
    ``(N,)`` and ``(N, 1)`` inputs can be mixed freely.

    Parameters
    ----------
    y_true:
        Ground-truth values.
    y_pred:
        Predicted values with the same number of elements as ``y_true``.

    Returns
    -------
    RegressionMetrics
        ``r2`` is reported as ``0.0`` when the targets have zero variance,
        because the usual ratio is undefined in that case.
    """
    # ``reshape`` also handles non-contiguous inputs (e.g. transposed or
    # sliced tensors), on which ``view`` raises.
    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1)

    # ``torch.mean`` rejects integer/bool tensors; promote them to the default
    # floating dtype while leaving existing float inputs untouched.
    if not y_true.is_floating_point():
        y_true = y_true.to(torch.get_default_dtype())
    if not y_pred.is_floating_point():
        y_pred = y_pred.to(torch.get_default_dtype())

    residual = y_true - y_pred
    mae = torch.mean(torch.abs(residual)).item()
    rmse = torch.sqrt(torch.mean(residual ** 2)).item()

    ss_tot = torch.sum((y_true - torch.mean(y_true)) ** 2)
    ss_res = torch.sum(residual ** 2)
    if ss_tot == 0:
        r2 = 0.0
    else:
        r2 = (1 - ss_res / ss_tot).item()

    return RegressionMetrics(mae=mae, rmse=rmse, r2=r2)


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    *,
    device: Optional[torch.device] = None,
    scheduler: Optional[torch.optim.lr_scheduler.ReduceLROnPlateau] = None,
    max_epochs: int = 200,
    early_stopping: Optional[EarlyStopping] = None,
) -> List[TrainingHistoryEntry]:
    """Fit ``model`` with an MSE objective and return one record per epoch.

    Each epoch runs a full optimisation pass over ``train_loader`` followed by
    a gradient-free pass over ``val_loader``. The mean validation loss drives
    both the optional ``scheduler`` and the optional ``early_stopping`` helper.
    When early stopping is supplied and has captured a best state, those
    weights are loaded back into ``model`` before returning.

    Parameters
    ----------
    model:
        Module called as ``model(batch)``; moved to ``device`` in place.
    train_loader, val_loader:
        Loaders yielding batches that expose ``.to(device)`` and ``.y``.
    optimizer:
        Optimizer already bound to ``model``'s parameters.
    device:
        Target device; defaults to CUDA when available, otherwise CPU.
    scheduler:
        Optional scheduler stepped with the mean validation loss.
    max_epochs:
        Upper bound on the number of epochs.
    early_stopping:
        Optional helper that may end training early.

    Raises
    ------
    ValueError
        If ``val_loader`` yields no batches.
    """

    if not device:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    loss_fn = nn.MSELoss()

    def _average(values: List[float]) -> float:
        # Averaged through a tensor so the arithmetic matches torch's float
        # handling; an empty list maps to 0.0.
        return float(torch.tensor(values).mean().item()) if values else 0.0

    def _optimise_one_epoch() -> float:
        # One gradient-descent pass over the training loader.
        model.train()
        batch_losses: List[float] = []
        for graphs in train_loader:
            graphs = graphs.to(device)
            optimizer.zero_grad()
            outputs = model(graphs)
            expected = graphs.y.to(device).view_as(outputs)
            batch_loss = loss_fn(outputs, expected)
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.item())
        return _average(batch_losses)

    def _validate() -> Tuple[float, RegressionMetrics]:
        # Gradient-free pass over the validation loader.
        model.eval()
        batch_losses: List[float] = []
        targets: List[torch.Tensor] = []
        outputs_seen: List[torch.Tensor] = []
        with torch.no_grad():
            for graphs in val_loader:
                graphs = graphs.to(device)
                outputs = model(graphs)
                expected = graphs.y.to(device).view_as(outputs)
                batch_losses.append(loss_fn(outputs, expected).item())
                targets.append(expected.detach().cpu())
                outputs_seen.append(outputs.detach().cpu())

        if not targets or not outputs_seen:
            raise ValueError("Validation loader produced no batches; cannot compute metrics.")

        mean_loss = _average(batch_losses)
        metrics = compute_regression_metrics(torch.cat(targets), torch.cat(outputs_seen))
        return mean_loss, metrics

    records: List[TrainingHistoryEntry] = []
    for epoch in range(1, max_epochs + 1):
        epoch_train_loss = _optimise_one_epoch()
        epoch_val_loss, epoch_val_metrics = _validate()
        records.append(
            TrainingHistoryEntry(
                epoch=epoch,
                train_loss=epoch_train_loss,
                val_loss=epoch_val_loss,
                val_metrics=epoch_val_metrics,
            )
        )

        if scheduler is not None:
            scheduler.step(epoch_val_loss)

        if early_stopping is not None and early_stopping.step(epoch_val_loss, model):
            break

    # Restore the best weights captured by early stopping, if any.
    best_weights = None if early_stopping is None else early_stopping.best_state_dict
    if best_weights is not None:
        model.load_state_dict(best_weights)

    return records


## Graph neural network architecture

A lightweight GraphSAGE-based regressor with optional timestamp encoding summarises each snapshot.

In [None]:
@dataclass
class OilPriceTemporalGNNConfig:
    """Hyper-parameters consumed by ``OilPriceTemporalGNN``."""

    # Number of input features per node (input width of the first SAGEConv).
    input_channels: int
    # Width of every SAGEConv layer and of the regressor's hidden layer.
    hidden_channels: int = 128
    # Output width of the final linear layer.
    output_channels: int = 1
    # Number of SAGEConv + GraphNorm layers; the model requires at least 1.
    num_layers: int = 3
    # Dropout probability used after each conv layer and inside the regressor.
    dropout: float = 0.2
    # When True, the readout is widened by one feature for the timestamp encoding.
    use_timestamp_encoding: bool = True


def _encode_timestamps(batch: Batch | Data) -> Optional[torch.Tensor]:
    """Return a ``(batch_size, 1)`` tensor with normalized timestamps if present.

    The ``snapshot_timestamp`` attribute of ``batch`` may be a tensor, a single
    datetime-like object, or a list/tuple of them (anything exposing
    ``.timestamp()``). Values are standardised to zero mean and unit variance
    *within the batch*; a batch whose timestamps are all identical (including a
    single-graph batch) therefore encodes to zeros.

    Returns
    -------
    ``None`` when the batch carries no ``snapshot_timestamp`` attribute,
    otherwise a float32 tensor placed on the device of ``batch.x`` when node
    features are available.
    """

    timestamps = getattr(batch, "snapshot_timestamp", None)
    if timestamps is None:
        return None

    # Epoch seconds (~1.6e9) only have about 128 s resolution in float32, which
    # collapses nearby snapshots before normalisation. Do the arithmetic in
    # float64 on the CPU (some accelerators lack float64) and cast afterwards.
    if isinstance(timestamps, torch.Tensor):
        target_device = timestamps.device
        ts_tensor = timestamps.detach().to(device="cpu", dtype=torch.float64).reshape(-1, 1)
    else:
        target_device = torch.device("cpu")
        if not isinstance(timestamps, (list, tuple)):
            timestamps = [timestamps]
        ts_tensor = torch.tensor(
            [ts.timestamp() for ts in timestamps], dtype=torch.float64
        ).reshape(-1, 1)

    # Follow the node features' device when they exist; ``x`` may be absent or
    # ``None`` on feature-less graphs.
    node_features = getattr(batch, "x", None)
    if isinstance(node_features, torch.Tensor):
        target_device = node_features.device

    ts_mean = ts_tensor.mean()
    ts_std = ts_tensor.std(unbiased=False)
    if ts_std == 0:
        # Avoid division by zero when every timestamp is identical.
        ts_std = ts_std + 1.0
    normalized = (ts_tensor - ts_mean) / ts_std
    return normalized.to(device=target_device, dtype=torch.float32)


class OilPriceTemporalGNN(nn.Module):
    """A lightweight GraphSAGE-based model for oil price change prediction.

    The network stacks ``num_layers`` blocks of ``SAGEConv`` -> ``GraphNorm``
    -> ReLU -> dropout, mean-pools node embeddings per graph, optionally
    appends a one-dimensional timestamp encoding, and feeds the result through
    a two-layer MLP regressor.

    Parameters
    ----------
    config:
        Hyper-parameters describing layer sizes, depth, dropout, and whether
        the timestamp feature is used.

    Raises
    ------
    ValueError
        If ``config.num_layers`` is smaller than 1.
    """

    def __init__(self, config: OilPriceTemporalGNNConfig) -> None:
        super().__init__()
        if config.num_layers < 1:
            raise ValueError("num_layers must be >= 1")

        self.config = config
        layers = []
        norms = []
        in_channels = config.input_channels
        for _ in range(config.num_layers):
            layers.append(SAGEConv(in_channels, config.hidden_channels))
            norms.append(GraphNorm(config.hidden_channels))
            # Every layer after the first consumes the hidden width.
            in_channels = config.hidden_channels

        self.convs = nn.ModuleList(layers)
        self.norms = nn.ModuleList(norms)
        self.dropout = nn.Dropout(config.dropout)

        # The timestamp encoding contributes exactly one extra readout feature.
        readout_dim = config.hidden_channels
        if config.use_timestamp_encoding:
            readout_dim += 1

        self.regressor = nn.Sequential(
            nn.Linear(readout_dim, config.hidden_channels),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_channels, config.output_channels),
        )

    def forward(self, batch: Batch | Data) -> torch.Tensor:
        """Predict one output row per graph in ``batch``.

        Raises
        ------
        RuntimeError
            If timestamp encoding is enabled but the batch carries no
            timestamps, or the number of timestamps does not match the number
            of graphs.
        """
        x, edge_index, batch_index = batch.x, batch.edge_index, batch.batch
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x, edge_index)
            x = norm(x, batch_index)
            x = torch.relu(x)
            x = self.dropout(x)

        pooled = global_mean_pool(x, batch_index)
        if self.config.use_timestamp_encoding:
            ts_encoding = _encode_timestamps(batch)
            if ts_encoding is None:
                # The regressor was built for ``hidden_channels + 1`` inputs;
                # continuing would fail with an opaque matmul shape error.
                raise RuntimeError(
                    "Timestamp encoding is enabled but the batch has no "
                    "`snapshot_timestamp` attribute. Attach timestamps to each "
                    "graph or construct the model with use_timestamp_encoding=False."
                )
            ts_encoding = ts_encoding.to(pooled.device)
            if ts_encoding.shape[0] != pooled.shape[0]:
                raise RuntimeError(
                    "Timestamp encoding size mismatch. Ensure each graph has a timestamp."
                )
            pooled = torch.cat([pooled, ts_encoding], dim=1)
        return self.regressor(pooled)


## End-to-end training helper

Use `run_training` to load the dataset, train the model with early stopping, evaluate on the held-out test split, and optionally persist metrics to disk.

In [None]:
def run_training(
    metadata_csv: Path | str,
    *,
    start_year: Optional[int] = None,
    end_year: Optional[int] = None,
    min_years: int = 5,
    max_years: int = 10,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    batch_size: int = 32,
    num_layers: int = 3,
    hidden_channels: int = 128,
    dropout: float = 0.2,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-4,
    epochs: int = 200,
    patience: int = 25,
    min_delta: float = 1e-4,
    disable_timestamp_encoding: bool = False,
    num_workers: int = 0,
    shuffle_within_year: bool = False,
    output_dir: Optional[Path | str] = None,
    seed: Optional[int] = None,
) -> tuple[list[dict], dict]:
    """Train and evaluate the oil-price GNN end to end.

    Loads the snapshot dataset, creates chronological train/val/test splits,
    trains ``OilPriceTemporalGNN`` with AdamW, ``ReduceLROnPlateau`` and early
    stopping, evaluates on the test split, and optionally writes the results
    as JSON.

    Parameters
    ----------
    metadata_csv:
        Path to the snapshot metadata CSV.
    start_year, end_year, min_years, max_years:
        Year-window settings forwarded to ``GraphSnapshotDataset``.
    train_ratio, val_ratio:
        Split fractions; the remainder forms the test set.
    batch_size, num_workers:
        Data loader settings.
    num_layers, hidden_channels, dropout, disable_timestamp_encoding:
        Model hyper-parameters.
    learning_rate, weight_decay:
        AdamW settings.
    epochs, patience, min_delta:
        Training length and early-stopping settings. The LR scheduler's
        patience is ``max(5, patience // 3)``.
    shuffle_within_year:
        Shuffle snapshots inside each calendar year before splitting.
    output_dir:
        When given, ``training_history.json`` and ``test_metrics.json`` are
        written into this directory (created if necessary).
    seed:
        Optional seed. When given, PyTorch's global RNG and the within-year
        shuffle generator are seeded with it.

    Returns
    -------
    ``(history_records, metrics_dict)``: one dictionary per epoch, and a
    dictionary of test metrics plus dataset information.

    Raises
    ------
    ValueError
        If snapshots lack node features or the test loader yields no batches.
    """
    if seed is not None:
        torch.manual_seed(seed)

    dataset = GraphSnapshotDataset(
        metadata_csv,
        start_year=start_year,
        end_year=end_year,
        min_years=min_years,
        max_years=max_years,
    )

    # ``split_indices`` only permutes within a year when it receives a
    # generator, so one must be supplied for ``shuffle_within_year`` to have
    # any effect. An unseeded ``torch.Generator`` starts from its initial
    # default seed, which keeps the split reproducible across runs.
    split_generator: Optional[torch.Generator] = None
    if shuffle_within_year:
        split_generator = torch.Generator()
        if seed is not None:
            split_generator.manual_seed(seed)

    train_idx, val_idx, test_idx = dataset.split_indices(
        train_ratio=train_ratio,
        val_ratio=val_ratio,
        shuffle_within_year=shuffle_within_year,
        generator=split_generator,
    )
    train_dataset = Subset(dataset, train_idx)
    val_dataset = Subset(dataset, val_idx)
    test_dataset = Subset(dataset, test_idx)

    # Peek at one snapshot to infer the node-feature width for the model.
    sample = dataset[0]
    if sample.x is None:
        raise ValueError("Graph snapshots must contain node features in `data.x`.")

    config = OilPriceTemporalGNNConfig(
        input_channels=sample.num_node_features,
        hidden_channels=hidden_channels,
        num_layers=num_layers,
        dropout=dropout,
        use_timestamp_encoding=not disable_timestamp_encoding,
    )
    model = OilPriceTemporalGNN(config)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=max(5, patience // 3), factor=0.5
    )
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

    history = train_model(
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler=scheduler,
        max_epochs=epochs,
        early_stopping=early_stopping,
    )

    # ``train_model`` moved the model to its training device; evaluate there.
    device = next(model.parameters()).device
    criterion = nn.MSELoss()
    model.eval()
    all_true: List[torch.Tensor] = []
    all_pred: List[torch.Tensor] = []
    test_losses: List[float] = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            preds = model(batch)
            target = batch.y.to(device).view_as(preds)
            loss = criterion(preds, target)
            all_true.append(target.cpu())
            all_pred.append(preds.cpu())
            test_losses.append(loss.item())

    if not all_true or not all_pred:
        raise ValueError("Test loader produced no samples. Check dataset configuration.")

    test_loss = float(torch.tensor(test_losses).mean().item())
    test_metrics = compute_regression_metrics(torch.cat(all_true), torch.cat(all_pred))

    def _split_span(indices: List[int]) -> dict:
        # Earliest and latest snapshot timestamps covered by one split.
        stamps = [dataset.metadata[i].timestamp for i in indices]
        return {"start": min(stamps).isoformat(), "end": max(stamps).isoformat()}

    history_records = [
        {
            "epoch": entry.epoch,
            "train_loss": entry.train_loss,
            "val_loss": entry.val_loss,
            "val_mae": entry.val_metrics.mae,
            "val_rmse": entry.val_metrics.rmse,
            "val_r2": entry.val_metrics.r2,
        }
        for entry in history
    ]
    metrics_dict = {
        "test_loss": test_loss,
        "test_mae": test_metrics.mae,
        "test_rmse": test_metrics.rmse,
        "test_r2": test_metrics.r2,
        "num_snapshots": len(dataset),
        # Kept for backward compatibility: despite their names these two hold
        # the first and last snapshot timestamps of the whole dataset.
        "train_range": dataset.start_timestamp.isoformat(),
        "test_range": dataset.end_timestamp.isoformat(),
        # Actual time span covered by each split.
        "split_ranges": {
            "train": _split_span(train_idx),
            "val": _split_span(val_idx),
            "test": _split_span(test_idx),
        },
    }

    if output_dir is not None:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        history_path = output_path / "training_history.json"
        metrics_path = output_path / "test_metrics.json"
        with history_path.open("w", encoding="utf-8") as fp:
            json.dump(history_records, fp, indent=2)
        with metrics_path.open("w", encoding="utf-8") as fp:
            json.dump(metrics_dict, fp, indent=2)
        print(f"Training history saved to {history_path}")
        print(f"Test metrics saved to {metrics_path}")

    return history_records, metrics_dict


### Example usage

Uncomment and edit the following cell with the path to your snapshot metadata to launch training.

In [None]:
# Example: run a full training job
# history, metrics = run_training(
#     "snapshots.csv",
#     start_year=2015,
#     end_year=2024,
#     batch_size=16,
#     epochs=150,
#     output_dir="runs/oil-price-gnn",
# )
# metrics