# Part 2: ENSO Phase Prediction

**Goal**: Predict the ENSO phase (Neutral / El Niño / La Niña) 1–3 months ahead using the ONI computed from ERSSTv6 data.

## 1. What is ENSO?

![enso](imgs/enso.png)

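The **El Niño–Southern Oscillation (ENSO)** is a coupled ocean-atmosphere phenomenon in the tropical Pacific Ocean. It is characterized by irregular, quasi-periodic warming and cooling of sea surface temperatures (SSTs), together with swings in surface air pressure (the Southern Oscillation).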

- **El Niño**: Warmer-than-average SSTs in the central/eastern equatorial Pacific -> droughts in Australia/SE Asia, flooding in South America, warmer winters in North America.
- **La Niña**: Cooler-than-average SSTs in the same region -> opposite effects.
- **Neutral**: Near-average conditions.

ENSO is one of the most important drivers of global climate variability. Predicting it months in advance is a major challenge.

### The Oceanic Niño Index (ONI)
The **ONI** is the standard metric for ENSO monitoring. It is computed as:
1. Average SST anomalies over the **Niño 3.4 region** (5°S–5°N, 170°W–120°W)
2. Apply a **3-month centered running mean** to smooth out noise

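ENSO phase is then defined by thresholds on ONI. NOAA additionally requires a threshold to be met for at least five consecutive overlapping 3-month seasons before declaring an El Niño or La Niña episode; in this notebook, each month is labeled on its own: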
- **El Niño**: ONI ≥ +0.5°C
- **La Niña**: ONI ≤ −0.5°C
- **Neutral**: −0.5°C < ONI < +0.5°C

### Our Task
Given the **last 12 months of ONI values**, predict the **ENSO class** for each of the **next 3 months**. This is a **sequence -> classification** problem.

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import xarray as xr
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

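# Informational only: the Lightning Trainer selects its own accelerator and does not read this variable
device = torch.device(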
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

plt.rcParams["font.family"] = "monospace"

## 2. Load ERSSTv6 Data

We load the combined dataset that was prepared in Part 2A.

In [None]:
ds = xr.open_dataset("../data/processed/ersstv6_combined.nc")
print(f"Time: {str(ds.time.values[0])[:7]} to {str(ds.time.values[-1])[:7]} ({len(ds.time)} months)")
print(f"Grid: {len(ds.lat)} lat × {len(ds.lon)} lon")
print(f"Variables: {list(ds.data_vars)}")

## 3. Compute the Oceanic Niño Index (ONI)

The ONI is computed in two steps:
1. **Spatial average**: Area-weighted mean SSTA over the Niño 3.4 box (5°S–5°N, 170°W–120°W)
2. **Temporal smoothing**: 3-month centered running mean

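To see what the centered running mean does at the edges of the record, here is a toy NumPy sketch on made-up monthly Niño 3.4 anomalies (illustrative values, not ERSSTv6 data; `centered_running_mean` is a hypothetical helper). A 3-point centered mean is only defined where both neighbors exist, so one month is lost at each end. This is the same effect as `rolling(time=3, center=True).mean().dropna('time')` in the next cell.

```python
import numpy as np

def centered_running_mean(x, window=3):
    """Centered running mean; output[i] is centered on x[i + window // 2].

    For an odd window and a NaN-free series this matches xarray's
    rolling(center=True).mean() followed by dropna().
    """
    kernel = np.ones(window) / window
    return np.convolve(x, kernel, mode="valid")

# Illustrative monthly Niño 3.4 anomalies (°C), not real data
nino34 = np.array([0.2, 0.6, 0.9, 1.1, 0.4, -0.3, -0.8, -1.0])
oni_toy = centered_running_mean(nino34)
print(np.round(oni_toy, 2))  # 6 values: one month lost at each end
```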
In [None]:
# Niño 3.4 region: 5°S-5°N, 170°W-120°W
# In 0-360° longitude: 170°W = 190°, 120°W = 240°
nino34 = ds.ssta.sel(lat=slice(-5, 5), lon=slice(190, 240))

# Area-weighted spatial average
weights = np.cos(np.deg2rad(nino34.lat))
nino34_index = nino34.weighted(weights).mean(dim=['lat', 'lon'])

# 3-month centered running mean → ONI
oni = nino34_index.rolling(time=3, center=True).mean().dropna('time')

oni_series = oni.to_series()
oni_series.name = 'ONI'

print(f"ONI time series: {len(oni_series)} months")
print(f"Range: {oni_series.index[0].strftime('%Y-%m')} to {oni_series.index[-1].strftime('%Y-%m')}")
print(f"Mean: {oni_series.mean():.3f}, Std: {oni_series.std():.3f}")

In [None]:
fig, ax = plt.subplots(figsize=(15, 4))

ax.fill_between(oni_series.index, oni_series.values, 0.5,
                where=oni_series.values >= 0.5, color='red', alpha=0.4, label='El Niño')
ax.fill_between(oni_series.index, oni_series.values, -0.5,
                where=oni_series.values <= -0.5, color='blue', alpha=0.4, label='La Niña')
ax.plot(oni_series.index, oni_series.values, color='black', linewidth=0.8)
ax.axhline(0.5, color='red', linestyle='--', alpha=0.5)
ax.axhline(-0.5, color='blue', linestyle='--', alpha=0.5)
ax.axhline(0, color='gray', linewidth=0.5)

ax.set_title('Oceanic Niño Index (ONI) — Computed from ERSSTv6')
ax.set_ylabel('ONI (°C)')
ax.set_xlabel('Year')
ax.legend(loc='upper left')
plt.tight_layout()

## 4. Label Each Month by ENSO Phase

Using the standard ONI thresholds:
- **0 = Neutral**: −0.5 < ONI < 0.5
- **1 = El Niño**: ONI ≥ 0.5
- **2 = La Niña**: ONI ≤ −0.5

In [None]:
# Label each month
def oni_to_label(val):
    if val >= 0.5:
        return 1  # El Niño
    elif val <= -0.5:
        return 2  # La Niña
    else:
        return 0  # Neutral

labels = oni_series.apply(oni_to_label).values

phase_names = ['Neutral', 'El Niño', 'La Niña']
for i, name in enumerate(phase_names):
    count = (labels == i).sum()
    print(f"  {name}: {count} months ({100*count/len(labels):.1f}%)")

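`Series.apply` calls `oni_to_label` once per month, which is fine for a record of this length. An equivalent vectorized version is sketched below with `np.select` (`oni_to_labels_vectorized` is a hypothetical helper name). `np.select` takes the first condition that is true, so it mirrors the `if`/`elif`/`else` chain exactly, including the behavior at ±0.5.

```python
import numpy as np

def oni_to_labels_vectorized(oni, threshold=0.5):
    """0 = Neutral, 1 = El Niño, 2 = La Niña, with the same thresholds as oni_to_label."""
    oni = np.asarray(oni, dtype=float)
    # First matching condition wins, like if/elif/else
    return np.select([oni >= threshold, oni <= -threshold], [1, 2], default=0)

labels_toy = oni_to_labels_vectorized([0.5, 0.49, -0.5, -0.49, 1.2, -1.3])
print(labels_toy)  # [1 0 2 0 1 2]
```

In the notebook this would read `labels = oni_to_labels_vectorized(oni_series.values)`.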
## 5. Prepare Sequences & DataLoaders

We create a sliding window dataset:
- **Input**: 12 consecutive months of ONI values
- **Target**: ENSO class at lead 1, 2, and 3 months ahead

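The same windowing can be written with NumPy's `sliding_window_view`. The sketch below (hypothetical helper `make_windows`) cuts windows of `seq_len + n_leads` months and splits each one into 12 inputs and 3 targets. It yields `N - seq_len - n_leads + 1` windows, one more than the `range(len(oni_values) - seq_len - n_leads)` loop in `ONIDataset`, which leaves out the final possible window.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def make_windows(oni, labels, seq_len=12, n_leads=3):
    """Return inputs (n_windows, seq_len) and targets (n_windows, n_leads)."""
    win = seq_len + n_leads
    x = sliding_window_view(np.asarray(oni), win)[:, :seq_len]     # first seq_len months
    y = sliding_window_view(np.asarray(labels), win)[:, seq_len:]  # the following n_leads months
    return x, y

# Toy series of 20 months where each value equals its month index
x, y = make_windows(np.arange(20.0), np.arange(20))
print(x.shape, y.shape)  # (6, 12) (6, 3)
print(x[0][-1], y[0])    # last input is month 11; targets are months 12, 13, 14
```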
We now define a `Dataset` class that turns the ONI series into (input, target) pairs for the model:

In [None]:
class ONIDataset(Dataset):
    def __init__(self, oni_values, labels, seq_len=12, n_leads=3):
        self.seq_len = seq_len
        self.n_leads = n_leads
        self.samples = []
        self.targets = []
        
        for i in range(len(oni_values) - seq_len - n_leads):
            x = oni_values[i : i + seq_len]
            y = [labels[i + seq_len + lead] for lead in range(n_leads)]
            self.samples.append(x)
            self.targets.append(y)
        
        self.samples = torch.tensor(np.array(self.samples), dtype=torch.float32)
        self.targets = torch.tensor(np.array(self.targets), dtype=torch.long)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.targets[idx]

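We can now split the dataset. It is good practice to use three sets: training (to fit the weights), validation (for early stopping and model selection) and testing (for the final evaluation). Because the samples are overlapping windows of a time series, we split chronologically rather than at random; a random split would place nearly identical windows from neighboring months in different sets.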

In [None]:
SEQ_LEN = 12  # 12 months of ONI history
N_LEADS = 3   # predict 1, 2, 3 months ahead

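# Chronological 70/15/15 split over window start indices

total_samples = len(oni_series) - SEQ_LEN - N_LEADS  # number of sliding windows
train_size = int(0.7 * total_samples)
val_size = int(0.15 * total_samples)
test_size = total_samples - train_size - val_size

# ONIDataset turns a slice of n + SEQ_LEN + N_LEADS months into n windows,
# so neighboring slices share SEQ_LEN + N_LEADS months at their boundaries
train_end = train_size + SEQ_LEN + N_LEADS
val_start = train_size
val_end = val_start + val_size + SEQ_LEN + N_LEADS
test_start = val_start + val_size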

train_dataset = ONIDataset(oni_series[:train_end], labels[:train_end], SEQ_LEN, N_LEADS)
val_dataset = ONIDataset(oni_series[val_start:val_end], labels[val_start:val_end], SEQ_LEN, N_LEADS)
test_dataset = ONIDataset(oni_series[test_start:], labels[test_start:], SEQ_LEN, N_LEADS)

BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples:   {len(val_dataset)}")
print(f"Test samples:  {len(test_dataset)}")
print(f"\nSample input shape: {train_dataset[0][0].shape}  (12 months of ONI)")
print(f"Sample target: {train_dataset[0][1]}  (ENSO class at leads 1,2,3)")

## 6. Models

We compare three architectures on the 1D ONI sequence:

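1. **Linear**: One linear layer per lead time (multinomial logistic regression on the 12 ONI values); the simplest baseline
2. **MLP**: Feedforward network that can learn non-linear patterns
3. **CNN**: 1D convolutional network that learns local temporal filters over the sequence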

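All three models are tiny. Counting trainable parameters by hand from the layer shapes in the next cell (assuming `SEQ_LEN = 12`, 3 classes and 3 leads; the helper functions are hypothetical) gives the numbers below. In the notebook, `sum(p.numel() for p in model.parameters())` should report the same values.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out          # weights + biases

def conv1d_params(c_in, c_out, k):
    return c_in * c_out * k + c_out      # kernels + biases

SEQ_LEN, N_CLASSES, N_LEADS = 12, 3, 3

linear = N_LEADS * linear_params(SEQ_LEN, N_CLASSES)
mlp = (linear_params(SEQ_LEN, 64) + linear_params(64, 32)
       + N_LEADS * linear_params(32, N_CLASSES))
flat = 32 * (SEQ_LEN // 4)               # two MaxPool1d(2) layers: 12 -> 6 -> 3
cnn = (conv1d_params(1, 16, 3) + conv1d_params(16, 32, 3)
       + N_LEADS * linear_params(flat, N_CLASSES))

print(linear, mlp, cnn)  # 117 3209 2505
```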
In [None]:
class LinearModel(nn.Module):
    def __init__(self, seq_len, n_classes=3, n_leads=3):
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(seq_len, n_classes) for _ in range(n_leads)])
    
    def forward(self, x):
        return tuple(head(x) for head in self.heads)


class MLPModel(nn.Module):
    def __init__(self, seq_len, hidden=64, n_classes=3, n_leads=3):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(seq_len, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 32),
            nn.ReLU(),
        )
        self.heads = nn.ModuleList([nn.Linear(32, n_classes) for _ in range(n_leads)])
    
    def forward(self, x):
        features = self.backbone(x)
        return tuple(head(features) for head in self.heads)


class CNN1DModel(nn.Module):
    def __init__(self, seq_len, n_classes=3, n_leads=3):
        super().__init__()
        # Input: (Batch, 1, Seq_Len) -> simple 1D Conv
        self.features = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
            nn.Flatten()
        )
        
        self.flatten_dim = 32 * (seq_len // 4) 
        
        self.heads = nn.ModuleList([nn.Linear(self.flatten_dim, n_classes) for _ in range(n_leads)])
    
    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.features(x)
        return tuple(head(x) for head in self.heads)


## 7. Training with PyTorch Lightning

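The total loss is the sum of the cross-entropy losses over the three lead times, so all heads are trained together with equal weight. Early stopping on `val_loss` halts training once the validation loss stops improving, which limits overfitting.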

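Written out, with $\mathbf{z}_{k,b} \in \mathbb{R}^3$ the logits that head $k$ produces for sample $b$, $y_{k,b}$ the true class at lead $k$, and $B$ the batch size (`nn.CrossEntropyLoss` averages over the batch by default), `_compute_loss` minimizes:

$$
\mathcal{L} = \sum_{k=1}^{3} \frac{1}{B} \sum_{b=1}^{B} \left[ -\log \frac{\exp\left(z_{k,b,\,y_{k,b}}\right)}{\sum_{c=0}^{2} \exp\left(z_{k,b,\,c}\right)} \right]
$$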
In [None]:
class ENSOPredictor(pl.LightningModule):
    def __init__(self, model, learning_rate=0.001):
        super().__init__()
        self.model = model
        self.lr = learning_rate
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def _compute_loss(self, batch):
        x, y = batch
        outputs = self(x)
        loss = sum(self.criterion(outputs[i], y[:, i]) for i in range(len(outputs)))
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

### 7.1 Linear Model

In [None]:
linear_model = LinearModel(SEQ_LEN)
pl_linear = ENSOPredictor(linear_model)

early_stop = EarlyStopping(monitor='val_loss', patience=5, mode='min', verbose=False)

trainer_linear = pl.Trainer(
    max_epochs=50,
    callbacks=[early_stop],
    enable_progress_bar=True,
    logger=False,
    enable_checkpointing=False
)

trainer_linear.fit(pl_linear, train_loader, val_loader)
trained_models = {}
results = {}
trained_models['Linear'] = linear_model

# Evaluate
linear_model.eval()
correct = [0, 0, 0]
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = linear_model(data)
        total += data.size(0)
        for i in range(3):
            _, predicted = torch.max(outputs[i], 1)
            correct[i] += (predicted == target[:, i]).sum().item()

accs = [100 * c / total for c in correct]
results['Linear'] = accs
print(f"Linear Results: Lead 1: {accs[0]:.1f}% | Lead 2: {accs[1]:.1f}% | Lead 3: {accs[2]:.1f}%")

### 7.2 MLP Model

In [None]:
mlp_model = MLPModel(SEQ_LEN)
pl_mlp = ENSOPredictor(mlp_model)

early_stop = EarlyStopping(monitor='val_loss', patience=10, mode='min', verbose=False)

trainer_mlp = pl.Trainer(
    max_epochs=200,
    callbacks=[early_stop],
    enable_progress_bar=True,
    logger=False,
    enable_checkpointing=False
)

trainer_mlp.fit(pl_mlp, train_loader, val_loader)
trained_models['MLP'] = mlp_model

# Evaluate
mlp_model.eval()
correct = [0, 0, 0]
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = mlp_model(data)
        total += data.size(0)
        for i in range(3):
            _, predicted = torch.max(outputs[i], 1)
            correct[i] += (predicted == target[:, i]).sum().item()

accs = [100 * c / total for c in correct]
results['MLP'] = accs
print(f"MLP Results: Lead 1: {accs[0]:.1f}% | Lead 2: {accs[1]:.1f}% | Lead 3: {accs[2]:.1f}%")

### 7.3 CNN 1D Model

![cnn1d](imgs/Conv1D.gif)

In [None]:
cnn_model = CNN1DModel(SEQ_LEN)
pl_cnn = ENSOPredictor(cnn_model)

early_stop = EarlyStopping(monitor='val_loss', patience=10, mode='min', verbose=False)

trainer_cnn = pl.Trainer(
    max_epochs=200,
    callbacks=[early_stop],
    enable_progress_bar=True,
    logger=False,
    enable_checkpointing=False
)

trainer_cnn.fit(pl_cnn, train_loader, val_loader)
trained_models['CNN'] = cnn_model

# Evaluate
cnn_model.eval()
correct = [0, 0, 0]
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = cnn_model(data)
        total += data.size(0)
        for i in range(3):
            _, predicted = torch.max(outputs[i], 1)
            correct[i] += (predicted == target[:, i]).sum().item()

accs = [100 * c / total for c in correct]
results['CNN'] = accs
print(f"CNN Results: Lead 1: {accs[0]:.1f}% | Lead 2: {accs[1]:.1f}% | Lead 3: {accs[2]:.1f}%")

## 8. Results Comparison

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharey=True)
colors = ['gray', 'blue', 'green']
lead_names = ['1-month lead', '2-month lead', '3-month lead']

for i, ax in enumerate(axes):
    model_names = list(results.keys())
    accs = [results[name][i] for name in model_names]
    bars = ax.bar(model_names, accs, color=colors)
    ax.axhline(33.3, color='red', linestyle='--', alpha=0.5, label='Random (33%)')
    ax.set_title(lead_names[i])
    ax.set_ylabel('Accuracy (%)' if i == 0 else '')
    ax.set_ylim(0, 100)
    ax.legend()
    for bar, acc in zip(bars, accs):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{acc:.1f}%', ha='center', va='bottom', fontsize=10)

plt.suptitle('ENSO Phase Prediction: Model Comparison', fontsize=14)
plt.tight_layout()

In [None]:
# Confusion matrices for the best model
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

best_name = max(results, key=lambda k: sum(results[k]))
best_model = trained_models[best_name]
print(f"Best model: {best_name}")

best_model.eval()
all_preds = [[], [], []]
all_targets = [[], [], []]

with torch.no_grad():
    for data, target in test_loader:
        outputs = best_model(data)
        for i in range(3):
            _, predicted = torch.max(outputs[i], 1)
            all_preds[i].extend(predicted.numpy())
            all_targets[i].extend(target[:, i].numpy())

fig, axes = plt.subplots(1, 3, figsize=(16, 4))
for i, ax in enumerate(axes):
    cm = confusion_matrix(all_targets[i], all_preds[i], labels=[0, 1, 2])
    disp = ConfusionMatrixDisplay(cm, display_labels=phase_names)
    disp.plot(ax=ax, cmap='Blues', colorbar=False)
    ax.set_title(lead_names[i])

plt.suptitle(f'Confusion Matrices: {best_name}', fontsize=14)
plt.tight_layout()

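The bar charts compare against uniform random guessing (33%), which is a weak reference: ENSO phases are usually imbalanced (see the class counts printed in Section 4) and persist from month to month. Below is a minimal NumPy sketch of two stronger no-skill baselines, using the same windows as `ONIDataset`: *persistence* (the phase of the last input month continues) and *climatology* (always predict the most frequent training class). The helper name `baseline_accuracies` is hypothetical.

```python
import numpy as np

def baseline_accuracies(train_labels, test_labels, seq_len=12, n_leads=3):
    """Per-lead accuracy (%) of persistence and climatology forecasts.

    Windows match ONIDataset: start i uses months i .. i+seq_len-1 as input
    and months i+seq_len .. i+seq_len+n_leads-1 as targets.
    """
    train_labels = np.asarray(train_labels)
    test_labels = np.asarray(test_labels)
    starts = np.arange(len(test_labels) - seq_len - n_leads)
    last_phase = test_labels[starts + seq_len - 1]   # phase of the last input month
    majority = np.bincount(train_labels).argmax()    # most frequent training class
    persistence, climatology = [], []
    for lead in range(n_leads):
        target = test_labels[starts + seq_len + lead]
        persistence.append(100 * np.mean(last_phase == target))
        climatology.append(100 * np.mean(target == majority))
    return persistence, climatology

# Toy labels: 5 Neutral, 5 El Niño, 5 Neutral, 5 La Niña months
toy = np.array([0] * 5 + [1] * 5 + [0] * 5 + [2] * 5)
pers, clim = baseline_accuracies(toy, toy)
print(np.round(pers, 1), np.round(clim, 1))
```

In the notebook, `baseline_accuracies(labels[:train_end], labels[test_start:])` scores exactly the windows in `test_loader`, so its numbers can be placed next to the model accuracies.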
## 9. Live Forecast

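Let's use the last 12 months of our ONI time series to predict the ENSO phase for the next three months. Because the ONI is a centered 3-month mean, its last value is centered on the second-to-last month of SST data, so the first forecast month is already partly observed.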

In [None]:
# Use the last 12 months of ONI as input
last_12 = oni_series[-SEQ_LEN:]
last_date = oni_series.index[-1]

print(f"Input: ONI from {oni_series.index[-SEQ_LEN].strftime('%Y-%m')} to {last_date.strftime('%Y-%m')}")
print(f"Last 12 ONI values: {np.round(last_12, 2)}")

# Predict
best_model.eval()
with torch.no_grad():
    inp = torch.tensor(last_12.values, dtype=torch.float32).unsqueeze(0)
    outputs = best_model(inp)

# Display forecast
print(f"\n{'='*50}")
print(f"ENSO Forecast (from {best_name} model)")
print(f"{'='*50}")

for i in range(N_LEADS):
    probs = torch.softmax(outputs[i], dim=1)[0]
    pred_class = outputs[i].argmax().item()
    forecast_month = last_date + pd.DateOffset(months=i+1)
    print(f"\n  {forecast_month.strftime('%B %Y')}:")
    for j, name in enumerate(phase_names):
        marker = " ◀" if j == pred_class else ""
        print(f"    {name}: {probs[j]:.1%}{marker}")

## 10. Discussion

### Key Takeaways
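- **ONI as a feature**: The ONI condenses the ENSO signal into a single number per month, computed from gridded SST data
- **Sequence -> Classification**: Framing ENSO prediction as classification avoids predicting exact ONI values, at the cost of treating values just either side of a threshold (e.g. 0.49 °C vs 0.51 °C) as entirely different classes
- **Lead time decay**: Accuracy decreases with longer lead times. Part of the lead-1 skill is built in: because the ONI is a centered 3-month mean, the last input value and the lead-1 target share two of their three monthly SST anomalies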
- **Real-world impact**: ENSO forecasts inform agriculture, disaster preparedness, and water resource management worldwide

### Next Steps
- Use the **full 2D SST maps** instead of just the ONI (this requires a 2D CNN: potentially more informative, but more complex)

## 11. From 1D to 2D: SST as a predictor

We now move from the simplified ONI index to the full **Sea Surface Temperature anomaly (SSTA) maps** as input.

We will compare three approaches:
1.  **2D CNN Classifier**: Takes the sequence of 12 SST maps and directly predicts the ENSO class (0, 1, 2) at each lead.
2.  **U-Net Classifier**: Passes the same 12 maps through a U-Net encoder-decoder and predicts the ENSO class from the pooled decoder features.
3.  **U-Net Map Predictor**: Takes the sequence of 12 SST maps and predicts the **next month's SST anomaly map**. Averaging the predicted map over the Niño 3.4 box gives an estimate from which the ENSO phase can be derived.

In [None]:
# SST Dataset for 2D Maps
class SSTDataset(Dataset):
    def __init__(self, sst_data, labels, seq_len=12, n_leads=3):
        self.seq_len = seq_len
        self.n_leads = n_leads
        # Fill NaNs with 0.0 (Land)
        self.sst_data = torch.tensor(np.nan_to_num(sst_data), dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        # Same number of windows as ONIDataset, so 1D and 2D models are scored on the same samples
        n_months = min(len(self.sst_data), len(self.labels))
        return max(0, n_months - self.seq_len - self.n_leads)

    def __getitem__(self, idx):
        # Input: (Seq_Len, Lat, Lon)
        x = self.sst_data[idx : idx + self.seq_len]

        # Targets for leads 1..n_leads
        target_start = idx + self.seq_len
        target_end = target_start + self.n_leads
        y = self.labels[target_start : target_end]        # ENSO classes: (n_leads,)
        y_map = self.sst_data[target_start : target_end]  # SSTA maps: (n_leads, Lat, Lon)

        return x, y, y_map

# The centered running mean dropped the first and last month of the ONI, so take the
# SSTA maps on the ONI time axis; index k of sst_values and labels is then the same month
sst_values = ds.ssta.sel(time=oni.time).values
train_dataset_2d = SSTDataset(sst_values[:train_end], labels[:train_end], SEQ_LEN, N_LEADS)
val_dataset_2d = SSTDataset(sst_values[val_start:val_end], labels[val_start:val_end], SEQ_LEN, N_LEADS)
test_dataset_2d = SSTDataset(sst_values[test_start:], labels[test_start:], SEQ_LEN, N_LEADS)

train_loader_2d = DataLoader(train_dataset_2d, batch_size=16, shuffle=True)
val_loader_2d = DataLoader(val_dataset_2d, batch_size=16)
test_loader_2d = DataLoader(test_dataset_2d, batch_size=16)

print(f"2D Sample Shape: {train_dataset_2d[0][0].shape}")

### 11.1 2D CNN Classifier

![cnn](imgs/cnn2d.jpg)

Convolutional Neural Networks (CNNs) are deep learning models specialized for processing grid-like data (e.g., images). Unlike standard networks that flatten input immediately, CNNs preserve spatial structure to detect patterns like edges and shapes.

#### **Key Components**

1. **Convolutional Layers:** Slide small filters (kernels) across the image to extract features. Because the same kernel is applied at every location, a pattern can be detected wherever it appears.
2. **Pooling Layers:** Downsample the image (e.g., Max Pooling) to reduce dimensions and computational cost.
3. **Fully Connected Layers:** Standard dense layers at the end that use the extracted features to make a final prediction.

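In `CNN2DModel` below, the 12 input months are stacked as channels. Each of the three conv + pool stages keeps the spatial size (3×3 kernel, `padding=1`) and then halves it with flooring (`MaxPool2d(2)`). The model infers the flattened size with a dummy forward pass, but the number can also be worked out by hand. The sketch assumes a hypothetical 89 × 180 grid (a 2° global grid; check `len(ds.lat)` and `len(ds.lon)` for the actual file). On that grid the first dense layer, `nn.Linear(self.flat_dim, 128)`, alone holds 15488 × 128 + 128 = 1,982,592 parameters, far more than any of the 1D models.

```python
def cnn2d_flat_dim(lat_dim, lon_dim, channels=(16, 32, 64)):
    """Flattened feature size after len(channels) conv(3x3, pad=1) + MaxPool2d(2) stages."""
    h, w = lat_dim, lon_dim
    for _ in channels:
        # Conv2d(kernel_size=3, padding=1) keeps (h, w); MaxPool2d(2) floors each by 2
        h, w = h // 2, w // 2
    return channels[-1] * h * w, (h, w)

flat, (h, w) = cnn2d_flat_dim(89, 180)
print(h, w, flat)  # 11 22 15488
```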
![conv](imgs/same_padding_no_strides.gif)

In [None]:
class CNN2DModel(nn.Module):
    def __init__(self, seq_len, lat_dim, lon_dim, n_classes=3, n_leads=3, dense_hidden=128):
        super().__init__()
        # Input: (Batch, Seq_Len, Lat, Lon) -> Treat Seq_Len as Channels
        self.features = nn.Sequential(
            nn.Conv2d(seq_len, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )
        
        with torch.no_grad():
            dummy = torch.zeros(1, seq_len, lat_dim, lon_dim)
            out = self.features(dummy)
            self.flat_dim = out.shape[1]
            
        # Enhanced Fully Connected Block
        self.fc_block = nn.Sequential(
            nn.Linear(self.flat_dim, dense_hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(dense_hidden, dense_hidden // 2),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        
        # Heads take input from the last FC layer
        self.heads = nn.ModuleList([nn.Linear(dense_hidden // 2, n_classes) for _ in range(n_leads)])
        
    def forward(self, x):
        x = self.features(x)
        x = self.fc_block(x)
        return tuple(head(x) for head in self.heads)

lat_dim = len(ds.lat)
lon_dim = len(ds.lon)

cnn2d = CNN2DModel(SEQ_LEN, lat_dim, lon_dim)

class ENSO2DClassifier(ENSOPredictor):
    def _compute_loss(self, batch):
        x, y, _ = batch # Ignore y_map
        outputs = self(x)
        loss = sum(self.criterion(outputs[i], y[:, i]) for i in range(len(outputs)))
        return loss

pl_cnn2d = ENSO2DClassifier(cnn2d)

early_stop_2d = EarlyStopping(monitor='val_loss', patience=10, mode='min', verbose=False)
# Validation runs every 2 epochs, so patience=10 allows up to 20 epochs without improvement
trainer_2d = pl.Trainer(max_epochs=200, callbacks=[early_stop_2d], enable_progress_bar=True, check_val_every_n_epoch=2)
trainer_2d.fit(pl_cnn2d, train_loader_2d, val_loader_2d)

### 11.2 U-Net Classifier (Map -> Class)

![unet](imgs/unet.jpg)

In [None]:
class UNetClassifier(nn.Module):
    def __init__(self, in_channels, n_classes=3, n_leads=3, dense_hidden=64):
        super().__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1), nn.BatchNorm2d(out_c), nn.ReLU(),
                nn.Conv2d(out_c, out_c, 3, padding=1), nn.BatchNorm2d(out_c), nn.ReLU()
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = conv_block(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        
        # Bottleneck
        self.bottleneck = conv_block(64, 128)
        
        # Decoder
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = conv_block(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec1 = conv_block(64, 32)
        
        # Classification Head
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        self.fc_block = nn.Sequential(
            nn.Linear(32, dense_hidden),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        
        self.heads = nn.ModuleList([nn.Linear(dense_hidden, n_classes) for _ in range(n_leads)])

    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        
        b = self.bottleneck(p2)
        
        d2 = self.up2(b)
        if d2.shape != e2.shape: d2 = nn.functional.interpolate(d2, size=e2.shape[2:])
        d2 = torch.cat((d2, e2), dim=1)
        d2 = self.dec2(d2)
        
        d1 = self.up1(d2)
        if d1.shape != e1.shape: d1 = nn.functional.interpolate(d1, size=e1.shape[2:])
        d1 = torch.cat((d1, e1), dim=1)
        d1 = self.dec1(d1)
        
        # Classification Branch
        out = self.global_pool(d1)
        out = out.view(out.size(0), -1)
        out = self.fc_block(out)
        
        return tuple(head(out) for head in self.heads)

# This U-Net ends in classification heads, so it reuses the cross-entropy wrapper
unet = UNetClassifier(in_channels=SEQ_LEN)
pl_unet = ENSO2DClassifier(unet) 

early_stop_unet = EarlyStopping(monitor='val_loss', patience=10, mode='min', verbose=False)
trainer_unet = pl.Trainer(max_epochs=200, callbacks=[early_stop_unet], enable_progress_bar=True, check_val_every_n_epoch=2)
trainer_unet.fit(pl_unet, train_loader_2d, val_loader_2d)


In [None]:
# 11.3 Results Comparison (1D vs 2D Models)
models_2d = {'CNN2D': cnn2d, 'UNet': unet}
results_2d = {}

print("Evaluating 2D Models...")
for name, model in models_2d.items():
    model.eval()
    correct = [0, 0, 0]
    total = 0
    with torch.no_grad():
        for data, target, _ in test_loader_2d: # Ignore y_map
            outputs = model(data)
            total += data.size(0)
            for i in range(3):
                _, predicted = torch.max(outputs[i], 1)
                correct[i] += (predicted == target[:, i]).sum().item()
    
    accs = [100 * c / total for c in correct]
    results_2d[name] = accs
    print(f"{name} Results: Lead 1: {accs[0]:.1f}% | Lead 2: {accs[1]:.1f}% | Lead 3: {accs[2]:.1f}%")

# Plot Comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True)
lead_names = ['Lead 1 Month', 'Lead 2 Months', 'Lead 3 Months']
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'] # Blue, Orange, Green, Red, Purple

# Combine 1D results if available
all_results = results.copy() if 'results' in locals() else {}
all_results.update(results_2d)

model_names = list(all_results.keys())

for i, ax in enumerate(axes):
    accuracies = [all_results[m][i] for m in model_names]
    bars = ax.bar(model_names, accuracies, color=colors[:len(model_names)])
    ax.set_title(lead_names[i])
    ax.set_ylim(0, 100)
    ax.axhline(33.3, color='grey', linestyle='--', label='Random Chance')
    ax.legend()
    
    # Add labels
    for bar, acc in zip(bars, accuracies):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{acc:.1f}%', ha='center', va='bottom')

plt.suptitle("ENSO Prediction Accuracy: 1D (ONI) vs 2D (SST Maps)", fontsize=16)
plt.tight_layout()

## 12. U-Net for Next-Month SST Map Prediction

In this section, we implement a **U-Net** that predicts the **Sea Surface Temperature anomaly (SSTA)** map for the *next month* (Lead 1) from the past 12 months of anomaly maps.

Although `SSTDataset` returns target maps for all three leads, we train only on the first one, turning the task into **next-step prediction** (one month ahead). This simpler target lets the model focus on the immediate evolution of the SST field, which may improve its skill.

### 12.1 The U-Net Architecture
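Compared with the U-Net classifier, this U-Net is one level deeper: three encoder/pooling stages, channel depth growing from 32 to 256 at the bottleneck, and a final 1×1 convolution that outputs a single map. The extra level enlarges the receptive field, so the network can relate SST features that are farther apart on the grid.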

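The `interpolate` guards in the U-Net `forward` methods are needed because pooling floors odd sizes, and a `ConvTranspose2d(kernel_size=2, stride=2)` exactly doubles the size, so it cannot always restore the skip connection's shape. The sketch below traces one spatial dimension through three levels, using the same hypothetical 89 × 180 grid as an assumption (`unet_decoder_mismatches` is an illustrative helper).

```python
def unet_decoder_mismatches(n, depth=3):
    """Trace one spatial size through `depth` pool/upsample levels.

    Returns (upsampled_size, skip_size) pairs that differ, i.e. where interpolation is needed.
    """
    enc = [n]
    for _ in range(depth):
        enc.append(enc[-1] // 2)      # MaxPool2d(2) floors
    mismatches = []
    size = enc[-1]                    # bottleneck convs (padding=1) keep the size
    for skip in reversed(enc[:-1]):
        up = size * 2                 # ConvTranspose2d(k=2, stride=2) doubles
        if up != skip:
            mismatches.append((up, skip))
        size = skip                   # after interpolate + concat + conv block
    return mismatches

print(unet_decoder_mismatches(89))   # [(88, 89)] at the outermost level
print(unet_decoder_mismatches(180))  # [(44, 45)] at the innermost level
```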
In [None]:
# Create Land Mask (1 for Sea, 0 for Land/NaN)
# We use the first time step of the dataset to identify NaN regions
land_mask_xr = ds['sst'].isel(time=0).notnull()
land_mask = torch.tensor(land_mask_xr.values, dtype=torch.float32)

fig, ax = plt.subplots(figsize=(10, 4))
im = land_mask_xr.plot(
    ax=ax, 
    cmap='gray', 
    add_colorbar=True,
    cbar_kwargs={'label': 'Mask Value'}
)

ax.set_title("Valid Ocean Mask (White=Sea, Black=Land)")

In [None]:
class UNetMapModel(nn.Module):
    def __init__(self, in_channels, out_channels=1): # Output 1 channel (Next Month)
        super().__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1), nn.BatchNorm2d(out_c), nn.ReLU(),
                nn.Conv2d(out_c, out_c, 3, padding=1), nn.BatchNorm2d(out_c), nn.ReLU()
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = conv_block(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = conv_block(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        
        # Bottleneck
        self.bottleneck = conv_block(128, 256)
        
        # Decoder
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec3 = conv_block(256, 128)
        
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = conv_block(128, 64)
        
        self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec1 = conv_block(64, 32)
        
        self.final = nn.Conv2d(32, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        e3 = self.enc3(p2)
        p3 = self.pool3(e3)
        
        b = self.bottleneck(p3)
        
        d3 = self.up3(b)
        if d3.shape != e3.shape: d3 = nn.functional.interpolate(d3, size=e3.shape[2:])
        d3 = torch.cat((d3, e3), dim=1)
        d3 = self.dec3(d3)
        
        d2 = self.up2(d3)
        if d2.shape != e2.shape: d2 = nn.functional.interpolate(d2, size=e2.shape[2:])
        d2 = torch.cat((d2, e2), dim=1)
        d2 = self.dec2(d2)
        
        d1 = self.up1(d2)
        if d1.shape != e1.shape: d1 = nn.functional.interpolate(d1, size=e1.shape[2:])
        d1 = torch.cat((d1, e1), dim=1)
        d1 = self.dec1(d1)
        
        return self.final(d1)

### 12.2 PyTorch Lightning Module
We wrap the U-Net model with PyTorch Lightning to handle training loops, optimization, and logging.

In [None]:
class SSTMapPredictor(pl.LightningModule):
    def __init__(self, model, mask, lr=1e-3):
        super().__init__()
        self.model = model
        self.register_buffer('mask', mask) # Mask moves to GPU with model
        self.lr = lr

    def forward(self, x):
        return self.model(x)
    
    def masked_mse_loss(self, pred, target):
        # Squared Error
        sq_error = (pred - target) ** 2
        # Mask out land pixels
        masked_error = sq_error * self.mask
        # Mean only over valid pixels
        loss = masked_error.sum() / self.mask.sum() / pred.shape[0] # Divide by batch size too
        return loss

    def training_step(self, batch, batch_idx):
        x, _, y_seq = batch
        y_target = y_seq[:, 0, :, :].unsqueeze(1) 
        pred_map = self(x)
        loss = self.masked_mse_loss(pred_map, y_target)
        self.log("train_mse", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _, y_seq = batch
        y_target = y_seq[:, 0, :, :].unsqueeze(1)
        pred_map = self(x)
        loss = self.masked_mse_loss(pred_map, y_target)
        self.log("val_mse", loss, prog_bar=True)
        return loss
        
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

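`masked_mse_loss` divides the masked sum by (number of ocean pixels × batch size), which makes it the plain mean squared error over ocean pixels only. A quick NumPy check of that equivalence on random arrays, with shapes following the Lightning module (predictions `(B, 1, H, W)`, mask `(H, W)`); `masked_mse` is a NumPy re-implementation for illustration:

```python
import numpy as np

def masked_mse(pred, target, mask):
    """Same arithmetic as SSTMapPredictor.masked_mse_loss, in NumPy."""
    sq_error = (pred - target) ** 2 * mask          # mask broadcasts over (B, 1)
    return sq_error.sum() / mask.sum() / pred.shape[0]

rng = np.random.default_rng(0)
pred = rng.normal(size=(4, 1, 5, 6))
target = rng.normal(size=(4, 1, 5, 6))
mask = (rng.random((5, 6)) > 0.3).astype(float)   # 1 = ocean, 0 = land
mask[0, 0] = 1.0                                   # ensure at least one ocean pixel

ocean = mask.astype(bool)
reference = np.mean(((pred - target) ** 2)[:, 0][:, ocean])  # mean over ocean pixels only
print(np.isclose(masked_mse(pred, target, mask), reference))  # True
```

The denominator assumes every sample in the batch shares the same mask, which holds here because a single static land mask is registered as a buffer.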
### 12.3 Training the Model
We initialize the model with `out_channels=1` and train using the standard PyTorch Lightning Trainer with Early Stopping.

In [None]:
unet_map = UNetMapModel(in_channels=SEQ_LEN, out_channels=1) 
pl_unet_map = SSTMapPredictor(unet_map, mask=land_mask)

early_stop_map = EarlyStopping(monitor='val_mse', patience=10, mode='min', verbose=False)
trainer_map = pl.Trainer(max_epochs=200, callbacks=[early_stop_map], enable_progress_bar=True, check_val_every_n_epoch=2)
trainer_map.fit(pl_unet_map, train_loader_2d, val_loader_2d)

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature

pl_unet_map.eval()
test_sample_idx = 0

with torch.no_grad():
    x, _, y_seq = test_dataset_2d[test_sample_idx]
    # x: (12, Lat, Lon), y_seq: (3, Lat, Lon)
    pred_map = pl_unet_map(x.unsqueeze(0)).squeeze(0).squeeze(0) # (Lat, Lon)
    target_map = y_seq[0] # Next Month (Lead 1)

# Setup Projection
proj = ccrs.PlateCarree(central_longitude=180)
fig, axes = plt.subplots(1, 3, figsize=(18, 5), subplot_kw={'projection': proj})

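def plot_sst(ax, data_tensor, title):
    # Hide land pixels (mask == 0): the model output there is not constrained by the loss
    values = np.where(land_mask.numpy() > 0, data_tensor.numpy(), np.nan)
    da = xr.DataArray(values, coords={'lat': ds.lat, 'lon': ds.lon}, dims=('lat', 'lon'))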
    im = da.plot(ax=ax, transform=ccrs.PlateCarree(), 
                 cmap='RdBu_r', vmin=-2.5, vmax=2.5, add_colorbar=False)
    ax.add_feature(cfeature.COASTLINE, linewidth=0.8)
    ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.6)
    ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.5)
    ax.set_title(title, fontsize=12)
    return im

# 1. Input (Last Month)
plot_sst(axes[0], x[-1], "Input: Last Month (T-0)")

# 2. Target (Next Month)
plot_sst(axes[1], target_map, "Target: Next Month (Lead 1)")

# 3. Prediction
im = plot_sst(axes[2], pred_map, "Prediction: Next Month (Lead 1)")

# Colorbar
cbar_ax = fig.add_axes([0.92, 0.2, 0.02, 0.6])
cbar = fig.colorbar(im, cax=cbar_ax, label='SST Anomaly (°C)')
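
To turn a predicted map into an ENSO phase estimate, one option is to average it over the Niño 3.4 box with the same cosine-latitude weights used for the ONI, then combine it with observed months into a 3-month mean before applying the ±0.5 °C thresholds. For example, the two most recent observed months plus the predicted next month give a mean centered on the most recent observed month. A minimal NumPy sketch, assuming a `(lat, lon)` array with 0-360° longitudes; the helper names are hypothetical, and the result is only an ONI-like estimate, not an official ONI value.

```python
import numpy as np

def nino34_mean(ssta_map, lat, lon):
    """Cosine-latitude-weighted mean of a (lat, lon) SSTA map over 5S-5N, 190-240E."""
    lat, lon = np.asarray(lat), np.asarray(lon)
    lat_sel = (lat >= -5) & (lat <= 5)
    lon_sel = (lon >= 190) & (lon <= 240)
    box = np.asarray(ssta_map, dtype=float)[np.ix_(lat_sel, lon_sel)]
    weights = np.broadcast_to(np.cos(np.deg2rad(lat[lat_sel]))[:, None], box.shape)
    valid = np.isfinite(box)  # skip NaN (land) pixels, if any
    return float(np.sum(box[valid] * weights[valid]) / np.sum(weights[valid]))

def phase_from_monthly_nino34(values, threshold=0.5):
    """ENSO class (0 Neutral, 1 El Niño, 2 La Niña) from the mean of monthly Niño 3.4 values."""
    m = float(np.mean(values))
    return 1 if m >= threshold else 2 if m <= -threshold else 0

# Toy check on a hypothetical 2-degree grid: +1 °C everywhere inside the box
lat = np.arange(-88, 89, 2)
lon = np.arange(0, 360, 2)
warm = np.zeros((lat.size, lon.size))
warm[np.ix_((lat >= -5) & (lat <= 5), (lon >= 190) & (lon <= 240))] = 1.0
print(nino34_mean(warm, lat, lon))                 # 1.0
print(phase_from_monthly_nino34([0.3, 0.6, 0.9]))  # 1 (El Niño)
```

With the notebook's objects, the predicted value might be obtained as `nino34_mean(pred_map.numpy(), ds.lat.values, ds.lon.values)`.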