# Global Config

## Drive

In [None]:
# Mount Google Drive so that ROOT_DIR (under /content/drive/MyDrive) is readable.
from google.colab import drive
drive.mount('/content/drive')

## config module

In [None]:
# Install segmentation-models-pytorch (imported as `smp`); IPython shell magic, notebook only.
!pip install segmentation-models-pytorch

In [None]:
# Imports
from pathlib import Path
import torch
import os
import xarray as xr
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import segmentation_models_pytorch as smp

# Project root on the mounted Google Drive; data files are resolved from here.
ROOT_DIR = Path("/content/drive/MyDrive/Pós-Graduações/Computer_Vision_Master/Materias/EII/Final_Project")

# NetCDF file holding the WMO heatwave-conditions series.
DATASET_PATH = ROOT_DIR.joinpath("WMO_heatwave_conditions_1961-present.nc")

# (start, end) dates of the reference period used for the monthly climatology.
HISTORICAL_PERIOD = ("1961-01-01", "1990-12-31")

# Use the first available backend: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

print(f"GPU is available? {torch.cuda.is_available()}")

print(f"Using {DEVICE} device")

# Model

## Data Exploration

In [None]:
# Open the NetCDF dataset and drop its final time step from every later step.
# NOTE(review): presumably the last month is incomplete/provisional — confirm.
ds = xr.open_dataset(DATASET_PATH)
ds = ds.isel(time=slice(None, -1))
# Quick look: map of the time-averaged HWC field.
ds.HWC.mean(dim='time').plot()

In [None]:
# Dimension names and sizes of the dataset (shown as the cell output).
ds.dims

In [None]:
# Last five timestamps, to check where the series ends after dropping the final step.
ds.time.tail(5)

## preprocessing module

In [None]:
class ZScoreNormalizer:
    """
    Global (scalar) z-score scaler for tensors.

    One mean and one standard deviation are estimated over *all* elements of
    the fitting tensor. The same two scalars are then used to standardize any
    tensor and to map standardized values back to the original units.

    Attributes
    ----------
    mu : torch.Tensor or None
        Scalar mean learned by ``fit``; ``None`` until fitted.
    sigma : torch.Tensor or None
        Scalar standard deviation plus ``eps``; ``None`` until fitted.
    eps : float
        Stabilizer added to the standard deviation.
    """

    def __init__(self, eps=1e-8):
        """
        Create an unfitted normalizer.

        Parameters
        ----------
        eps : float, optional
            Added to the fitted standard deviation so a constant input does
            not cause division by zero. Default is 1e-8.
        """
        self.eps = eps
        self.mu = None
        self.sigma = None

    def _ensure_fitted(self):
        """Raise ``RuntimeError`` unless ``fit`` has populated mu and sigma."""
        if self.mu is None or self.sigma is None:
            raise RuntimeError("Normalizer must be fitted first.")

    def fit(self, X):
        """
        Learn the global mean and standard deviation of ``X``.

        Parameters
        ----------
        X : torch.Tensor
            Reference (training) tensor of any shape, e.g. (N, T, C, H, W).

        Returns
        -------
        ZScoreNormalizer
            ``self``, so calls can be chained.
        """
        center = X.mean()
        spread = X.std()
        self.mu, self.sigma = center, spread + self.eps
        return self

    def transform(self, X):
        """
        Standardize ``X`` with the fitted statistics.

        Parameters
        ----------
        X : torch.Tensor
            Tensor in original units.

        Returns
        -------
        torch.Tensor
            ``(X - mu) / sigma``.

        Raises
        ------
        RuntimeError
            If the normalizer has not been fitted.
        """
        self._ensure_fitted()
        return (X - self.mu) / self.sigma

    def fit_transform(self, X):
        """
        Fit on ``X`` and return its standardized version.

        Parameters
        ----------
        X : torch.Tensor
            Tensor used both to learn the statistics and to be transformed.

        Returns
        -------
        torch.Tensor
            Standardized ``X``.
        """
        return self.fit(X).transform(X)

    def inverse_transform(self, X):
        """
        Map standardized values back to the original units.

        Parameters
        ----------
        X : torch.Tensor
            Standardized tensor. It must live on the same device as
            ``mu``/``sigma``.

        Returns
        -------
        torch.Tensor
            ``X * sigma + mu``.

        Raises
        ------
        RuntimeError
            If the normalizer has not been fitted.
        """
        self._ensure_fitted()
        return X * self.sigma + self.mu


class DHWDataset(Dataset):
    """
    PyTorch Dataset of (input sequence, next-step target) pairs for heatwave
    forecasting.

    Sample ``i`` is ``(X[i], y[i])``: a temporal sequence of spatial fields
    and the field at the following time step.
    """

    def __init__(self, X, y):
        """
        Initialize the dataset.

        Parameters
        ----------
        X : torch.Tensor
            Inputs with shape (N, T, C, H, W), where N is the number of
            samples and T is the sequence length.
        y : torch.Tensor
            Targets with shape (N, C, H, W).

        Raises
        ------
        ValueError
            If ``X`` and ``y`` do not have the same number of samples. Without
            this check the mismatch would only surface later as an IndexError
            or as silently dropped targets.
        """
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"X and y must have the same number of samples, "
                f"got {X.shape[0]} and {y.shape[0]}"
            )
        self.X = X
        self.y = y

    def __len__(self):
        """Return the number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """
        Retrieve one sample.

        Parameters
        ----------
        idx : int
            Index of the sample.

        Returns
        -------
        tuple
            ``(X_i, y_i)`` with ``X_i`` of shape (T, C, H, W) and ``y_i`` of
            shape (C, H, W).
        """
        return self.X[idx], self.y[idx]


def make_sequences(data, seq_len=12):
    """
    Create input/target pairs for supervised next-step forecasting.

    Sliding windows of ``seq_len`` consecutive time steps are the inputs, and
    the time step right after each window is its target. Sample ``i`` uses
    ``data[i:i + seq_len]`` as input and ``data[i + seq_len]`` as target.

    Parameters
    ----------
    data : numpy.ndarray
        Time-ordered array with shape (time, latitude, longitude). Any number
        of trailing spatial axes is accepted.
    seq_len : int, optional
        Length of each input window. Default is 12.

    Returns
    -------
    X : numpy.ndarray
        Input windows with shape (N, seq_len, 1, H, W), where
        N = max(len(data) - seq_len, 0).
    y : numpy.ndarray
        Targets with shape (N, 1, H, W).

    Raises
    ------
    ValueError
        If ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    data = np.asarray(data)
    spatial_shape = data.shape[1:]
    n_samples = len(data) - seq_len

    # Too few time steps for a single window: return empty arrays that still
    # have the documented rank. Indexing an empty np.array([]) with 5 axes
    # would raise IndexError.
    if n_samples <= 0:
        X = np.empty((0, seq_len, 1, *spatial_shape), dtype=data.dtype)
        y = np.empty((0, 1, *spatial_shape), dtype=data.dtype)
        return X, y

    # Window ending before t -> input; step t -> target.
    X = np.stack([data[t - seq_len:t] for t in range(seq_len, len(data))])
    y = np.array(data[seq_len:])  # copy, so y does not alias `data`

    # Insert the singleton channel axis expected by the model.
    X = X[:, :, np.newaxis]
    y = y[:, np.newaxis]

    return X, y

## model module

In [None]:
class TemporalUnetTransformer(nn.Module):
    """
    Spatiotemporal U-Net with a temporal Transformer at the bottleneck.

    A shared U-Net encoder (from ``segmentation_models_pytorch``) is applied
    independently to each time step. At every bottleneck pixel, the sequence
    of T bottleneck feature vectors goes through a Transformer encoder, and
    the output at the last position becomes the decoder's bottleneck input.
    The U-Net decoder uses the skip connections of the last time step only.

    The output is a single spatial field: the prediction for the next time
    step.

    NOTE(review): no positional encoding is added before the Transformer, so
    self-attention treats the T bottleneck vectors as an unordered set. Only
    the query taken at the last position is distinguished. Consider a learned
    temporal embedding if ordering matters.
    """

    def __init__(
                 self,
                 in_channels=1,
                 out_channels=1,
                 encoder_name="resnet18",
                 encoder_weights=None,
                 n_heads=4,
                 num_layers=1,
                 ):
        """
        Initialize the Temporal U-Net Transformer model.

        Parameters
        ----------
        in_channels : int, optional
            Number of input channels per time step. Default is 1.
        out_channels : int, optional
            Number of output channels. Default is 1.
        encoder_name : str, optional
            Encoder backbone name understood by ``segmentation_models_pytorch``.
            Default is "resnet18".
        encoder_weights : str or None, optional
            Pretrained weights for the encoder (e.g. "imagenet") or None for
            random initialization. Default is None.
        n_heads : int, optional
            Number of attention heads in the temporal Transformer. It must
            divide the encoder's bottleneck channel count. Default is 4.
        num_layers : int, optional
            Number of stacked Transformer encoder layers. Default is 1.
        """

        super().__init__()

        # U-Net backbone. Its parts are also exposed as attributes below. The
        # attribute names determine the state_dict keys, so keep them stable
        # for checkpoint compatibility.
        self.unet = smp.Unet(
                             encoder_name=encoder_name,
                             encoder_weights=encoder_weights,
                             in_channels=in_channels,
                             classes=out_channels,
                             activation=None,
                             )

        self.encoder = self.unet.encoder
        self.decoder = self.unet.decoder
        self.head = self.unet.segmentation_head

        # Channel count of the deepest encoder feature map (the bottleneck).
        bottleneck_channels = self.encoder.out_channels[-1]

        # Temporal Transformer over the T bottleneck vectors of each pixel.
        encoder_layer = nn.TransformerEncoderLayer(
                                                   d_model=bottleneck_channels,
                                                   nhead=n_heads,
                                                   dim_feedforward=4 * bottleneck_channels,
                                                   batch_first=True,
                                                   dropout=0.1,
                                                   )

        self.temporal_transformer = nn.TransformerEncoder(
                                                          encoder_layer,
                                                          num_layers=num_layers,
                                                          )

    def forward(self, x):
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (B, T, C, H, W): batch, time steps (T >= 1),
            channels, and spatial height/width.
            NOTE(review): smp U-Nets presumably need H and W divisible by the
            encoder's total stride (32 for a 5-stage ResNet) — confirm for the
            data grid.

        Returns
        -------
        torch.Tensor
            Tensor of shape (B, out_channels, H', W') with the predicted field
            for the next time step.

        Raises
        ------
        ValueError
            If ``x`` is not 5-dimensional or has no time steps.
        """
        if x.dim() != 5 or x.shape[1] < 1:
            raise ValueError(
                "Expected input of shape (B, T, C, H, W) with T >= 1, "
                f"got {tuple(x.shape)}"
            )

        n_steps = x.shape[1]
        bottlenecks = []
        last_skips = None

        # Shared encoder applied to each time step independently.
        for t in range(n_steps):
            feats = self.encoder(x[:, t])   # multi-scale feature maps
            bottlenecks.append(feats[-1])   # bottleneck features (B, C, h, w)
            last_skips = feats[:-1]         # only the last step's skips are decoded

        # Stack bottleneck features along time: (B, T, C, h, w)
        Z = torch.stack(bottlenecks, dim=1)
        B, T, C, h, w = Z.shape

        # Treat every bottleneck pixel as an independent sequence of length T.
        Z = Z.permute(0, 3, 4, 1, 2)      # (B, h, w, T, C)
        Z = Z.reshape(B * h * w, T, C)    # (B*h*w, T, C)

        Z = self.temporal_transformer(Z)  # (B*h*w, T, C)
        Z = Z[:, -1]                      # keep the last temporal position

        # Restore the spatial layout expected by the decoder.
        Z = Z.reshape(B, h, w, C)
        Z = Z.permute(0, 3, 1, 2)         # (B, C, h, w)

        # Decoder input: last step's skip connections + attended bottleneck.
        features = list(last_skips) + [Z]

        y = self.decoder(features)
        y = self.head(y)

        return y


## train_eval_loop module

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer, device):
    """
    Execute one training epoch.

    Runs forward and backward passes over every batch, updates the model
    parameters, and returns the per-sample average of the training loss.

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        Yields ``(X, y)`` batches with X of shape (B, T, C, H, W) and y of
        shape (B, C, H, W).
    model : torch.nn.Module
        Model to train.
    loss_fn : callable
        Mean-reduced loss (e.g. ``torch.nn.MSELoss()``). Each batch's value
        is weighted by the batch size, so a partial last batch does not bias
        the epoch average.
    optimizer : torch.optim.Optimizer
        Optimizer used to update the model parameters.
    device : torch.device or str
        Device for the computation ("cpu", "cuda", or "mps").

    Returns
    -------
    float
        Training loss averaged over all samples seen in the epoch.

    Raises
    ------
    ValueError
        If the dataloader yields no batches.
    """

    model.train()
    running_loss = 0.0
    n_seen = 0

    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)

        # Forward
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once (one device sync) and weight it by batch size.
        batch_loss = loss.item()
        batch_len = X.size(0)
        running_loss += batch_loss * batch_len
        n_seen += batch_len

        if batch % 50 == 0:
            print(f"batch {batch:4d} | loss {batch_loss:.6f}")

    if n_seen == 0:
        raise ValueError("dataloader yielded no batches")

    return running_loss / n_seen


def eval_loop(dataloader, model, loss_fn, device):
    """
    Evaluate the model on a validation or test dataset.

    Runs the model in eval mode without gradient tracking and returns the
    per-sample mean of the loss (MSE) and its square root (RMSE).

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        Yields ``(X, y)`` batches with X of shape (B, T, C, H, W) and y of
        shape (B, C, H, W).
    model : torch.nn.Module
        Model to evaluate.
    loss_fn : callable
        Mean-reduced loss (e.g. ``torch.nn.MSELoss()``). Each batch's value
        is weighted by the batch size, so a partial last batch does not bias
        the average.
    device : torch.device or str
        Device for the computation ("cpu", "cuda", or "mps").

    Returns
    -------
    avg_mse : float
        Loss averaged over all samples.
    rmse : float
        Square root of ``avg_mse``.

    Raises
    ------
    ValueError
        If the dataloader yields no batches.
    """

    model.eval()
    total_mse = 0.0
    n_seen = 0

    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)

            pred = model(X)
            mse = loss_fn(pred, y)

            batch_len = X.size(0)
            total_mse += mse.item() * batch_len
            n_seen += batch_len

    if n_seen == 0:
        raise ValueError("dataloader yielded no batches")

    avg_mse = total_mse / n_seen
    rmse = np.sqrt(avg_mse)

    return avg_mse, rmse


## main

In [None]:
# Heatwave conditions (HWC) variable of the dataset.
dhw = ds.HWC

# Monthly climatology over HISTORICAL_PERIOD (1961–1990): one mean map per calendar month.
clim = dhw.sel(time=slice(HISTORICAL_PERIOD[0], HISTORICAL_PERIOD[1])).groupby("time.month").mean("time")

# Monthly anomalies: each time step minus the climatology of its calendar month.
dhw_anom = dhw.groupby("time.month") - clim

In [None]:
# Anomalies as a NumPy array ordered (time, latitude, longitude).
dhw_np = dhw_anom.transpose("time", "latitude", "longitude").values

In [None]:
# Shape of the anomaly array: (time, latitude, longitude).
dhw_np.shape

In [None]:
# Format X, y for the PyTorch datasets: each sample is a window of 12
# consecutive months (X) followed by the next month (y).
X, y = make_sequences(dhw_np, seq_len=12)

In [None]:
# Expected: X (N, 12, 1, latitude, longitude) and y (N, 1, latitude, longitude),
# where N = number of time steps - 12.
print(X.shape)
print(y.shape)

In [None]:
# Convert the NumPy arrays to float32 torch tensors.
X = torch.from_numpy(X).float()
y = torch.from_numpy(y).float()

In [None]:
# Z-score statistics come from the first 80% of samples only, the same cut as
# the training split (int(0.8 * n)), so validation/test data do not influence them.
train_end_norm = int(0.8 * X.shape[0])

normalizer = ZScoreNormalizer()

# In-place slice assignment: the training part defines mu/sigma, the rest reuses them.
X[:train_end_norm] = normalizer.fit_transform(X[:train_end_norm])
X[train_end_norm:] = normalizer.transform(X[train_end_norm:])

# Targets use the same statistics, so predictions can be mapped back with
# normalizer.inverse_transform.
y = normalizer.transform(y)

In [None]:
# Pair each normalized input window with its next-month target.
dataset = DHWDataset(X, y)

In [None]:
# Chronological 80/10/10 split into contiguous index ranges (no shuffling), so
# validation and test targets come after the training targets in time.
n = len(dataset)

train = int(0.8 * n)
val   = int(0.9 * n)

train_ds = torch.utils.data.Subset(dataset, range(0, train))
val_ds   = torch.utils.data.Subset(dataset, range(train, val))
test_ds  = torch.utils.data.Subset(dataset, range(val, n))

In [None]:
# batch_size is hard-coded to 1 here; the `batch_size` hyperparameter defined in
# the model-config cell is not used by these loaders.
# NOTE(review): the training loader is not shuffled, so every epoch visits the
# samples in time order — confirm this is intended.
train_loader = DataLoader(train_ds, batch_size=1, shuffle=False)
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False)

In [None]:
# Validate tensor dimensions on one training batch.
xb, yb = next(iter(train_loader))
names_dim_X = ["batch", "time", "channel", "latitude", "Longitude"]
# y has no time axis: make_sequences inserts a singleton channel axis at dim 1.
names_dim_y = ["batch", "channel", "latitude", "Longitude"]

print("xb shape:", xb.shape)
print(names_dim_X)
print("yb shape:", yb.shape)
print(names_dim_y)

In [None]:
# Smoke test: build the model with default settings and run one batch without
# gradients; the printed shape should match the target batch `yb`.
model = TemporalUnetTransformer().to(DEVICE)

xb, yb = next(iter(train_loader))
xb = xb.to(DEVICE)

with torch.no_grad():
    yhat = model(xb)

print(yhat.shape)

In [None]:
# Model configs
# Hyperparameters
learning_rate = 1e-4      # transformer + UNet -> lower LR
batch_size    = 1         # NOTE(review): not used; the DataLoaders hard-code batch_size=1
epochs        = 50
weight_decay  = 1e-5

# loss function (mean squared error, mean-reduced)
loss_fn = nn.MSELoss()

# optimizer
optimizer = torch.optim.AdamW(
                              model.parameters(),
                              lr=learning_rate,
                              weight_decay=weight_decay
                              )

# scheduler: halve the LR after 5 epochs without validation improvement
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                                                       optimizer,
                                                       mode="min",
                                                       factor=0.5,
                                                       patience=5,
                                                       )

# early stopping: stop after 7 consecutive epochs without improvement
early_stopping_patience = 7
early_stopping_counter = 0
best_val_loss = float("inf")

In [None]:
# Train Loop
# Reset both pieces of early-stopping state so re-running this cell starts
# clean. Without resetting the counter, a re-run after an early stop would
# break after a single epoch.
best_val_loss = float("inf")
early_stopping_counter = 0

for epoch in range(1, epochs + 1):
    print(f"\nEpoch {epoch}/{epochs}")
    print("-" * 30)

    train_loss = train_loop(train_loader, model, loss_fn, optimizer, DEVICE)

    val_mse, val_rmse = eval_loop(val_loader, model, loss_fn, DEVICE)

    # Scheduler --> validation loss
    scheduler.step(val_mse)

    print(
          f"Epoch {epoch} | "
          f"Train MSE: {train_loss:.6f} | "
          f"Val MSE: {val_mse:.6f} | "
          f"Val RMSE: {val_rmse:.6f}"
          )

    if val_mse < best_val_loss:
        best_val_loss = val_mse
        early_stopping_counter = 0
        # Save the weights together with the normalizer statistics needed to
        # map predictions back to physical units. The relative path resolves
        # against the working directory, not ROOT_DIR on Drive.
        torch.save(
                   {
                    "model_state": model.state_dict(),
                    "mu": normalizer.mu,
                    "sigma": normalizer.sigma,
                    },
                   "best_temporal_unet_transformer.pt"
                   )

    else:
        early_stopping_counter += 1

    if early_stopping_counter >= early_stopping_patience:
        print("Early stopping")
        break


# Test and Inference

In [None]:
# test: restore the best checkpoint (weights + normalizer statistics)
ckpt = torch.load("best_temporal_unet_transformer.pt", map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])
normalizer.mu = ckpt["mu"].to(DEVICE)
normalizer.sigma = ckpt["sigma"].to(DEVICE)

model.eval()

test_mse, test_rmse = eval_loop(test_loader, model, loss_fn, DEVICE)

print(f"Test MSE (normalized):  {test_mse:.6f}")
print(f"Test RMSE (normalized): {test_rmse:.6f}")

# Undo the z-scaling: errors scale with sigma (mu cancels in a difference).
# Units are those of the HWC variable.
test_rmse_phys = test_rmse * normalizer.sigma.item()
print(f"Test RMSE (physical units): {test_rmse_phys:.6f}")

# Test targets back in physical units. They are *anomalies* (HWC minus the
# monthly climatology `clim`), because the model is trained on dhw_anom.
y_true_phys = []

for idx in range(len(test_ds)):
    _, yb = test_ds[idx]
    yb = normalizer.inverse_transform(yb.unsqueeze(0).to(DEVICE))
    y_true_phys.append(yb[0, 0].cpu().numpy())

y_true_phys = np.stack(y_true_phys)  # (N_test, lat, lon)

# Baseline: the monthly climatology forecast. In anomaly space it predicts
# zero everywhere, so its error equals the anomaly itself. Comparing the
# anomalies with the absolute `clim` values would mix two different quantities.
y_climatology_anom = np.zeros_like(y_true_phys)

MSE_baseline = np.mean((y_true_phys - y_climatology_anom) ** 2)
print(f"MSE baseline (climatology): {MSE_baseline:.6f}")

MSE_model = test_rmse_phys ** 2
print(f"MSE model (physical units): {MSE_model:.6f}")

# Skill > 0 means the model beats climatology; 1 would be a perfect forecast.
Skill = 1.0 - (MSE_model / MSE_baseline)
print(f"Skill score: {Skill:.3f}")

In [None]:
# Predict one test sample and wrap prediction and truth as labelled xarray maps.
model.eval()
idx = 0  # position within test_ds (0 = first test month, not necessarily January; its date is `t` below)

xb, yb = test_ds[idx]
xb = xb.unsqueeze(0).to(DEVICE)
yb = yb.unsqueeze(0).to(DEVICE)

# Back to physical (anomaly) units. Note: this rebinds y_true_phys, replacing
# the array built in the baseline cell with this single sample's tensor.
with torch.no_grad():
    y_hat_phys = normalizer.inverse_transform(model(xb))
    y_true_phys = normalizer.inverse_transform(yb)

y_hat_map = y_hat_phys[0, 0].cpu().numpy()
y_true_map = y_true_phys[0, 0].cpu().numpy()

# Target timestamps: sample i predicts ds.time[i + seq_len]; the test split starts at val_end.
seq_len = 12
all_times = ds.time[seq_len:]

n_samples = X.shape[0]
train_end = int(0.8 * n_samples)
val_end   = int(0.9 * n_samples)

test_times = all_times[val_end:]

t = test_times[idx]

lat = ds.latitude
lon = ds.longitude

da_pred = xr.DataArray(
                       y_hat_map,
                       dims=("latitude", "longitude"),
                       coords={
                               "latitude": lat,
                               "longitude": lon,
                               },
                       name="HWC_pred"
                       )

da_true = xr.DataArray(
                       y_true_map,
                       dims=("latitude", "longitude"),
                       coords={
                               "latitude": lat,
                               "longitude": lon,
                               },
                       name="HWC_true"
                       )

da_pred.attrs["time"] = str(t.values)
da_true.attrs["time"] = str(t.values)

In [None]:
# Error map (truth − prediction), diverging colormap centred on zero.
(da_true - da_pred).plot(
                         cmap="coolwarm_r",
                         center=0
                         )