# Part 2: ENSO Phase Prediction

**Goal**: Predict the ENSO phase (Neutral / El Niño / La Niña) 1–3 months ahead using the ONI computed from ERSSTv6 data.

## 1. What is ENSO?

![enso](imgs/enso.png)

The **El Niño–Southern Oscillation (ENSO)** is a coupled ocean–atmosphere phenomenon in the tropical Pacific, characterized by irregular warming and cooling of sea surface temperatures (SSTs) that recur roughly every 2 to 7 years.

- **El Niño**: Warmer-than-average SSTs in the central/eastern equatorial Pacific -> typically associated with drought in Australia and Southeast Asia, heavy rain and flooding in western South America, and milder winters over northern North America.
- **La Niña**: Cooler-than-average SSTs in the same region -> broadly the opposite impacts.
- **Neutral**: Near-average conditions.

ENSO is one of the most important drivers of global climate variability. Predicting it months in advance is a major challenge.

### The Oceanic Niño Index (ONI)
The **ONI** is the standard metric for ENSO monitoring. It is computed as:
1. Average SST anomalies over the **Niño 3.4 region** (5°S–5°N, 170°W–120°W)
2. Apply a **3-month centered running mean** to smooth out noise

ENSO phase is then defined by thresholds on ONI:
- **El Niño**: ONI ≥ +0.5°C
- **La Niña**: ONI ≤ −0.5°C
- **Neutral**: −0.5°C < ONI < +0.5°C
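These thresholds classify each month on its own. NOAA's operational episode definition is stricter: a warm or cold episode is only declared when the ONI stays beyond ±0.5°C for at least five consecutive overlapping 3-month seasons. The helper below is a minimal NumPy sketch of that persistence rule; the name `episode_mask` and the toy values are hypothetical and not used elsewhere in this notebook.

```python
import numpy as np

def episode_mask(oni, threshold=0.5, min_run=5, warm=True):
    """Hypothetical helper: flag months inside runs where the ONI stays beyond
    the threshold for at least `min_run` consecutive values (NOAA-style rule)."""
    oni = np.asarray(oni, dtype=float)
    beyond = oni >= threshold if warm else oni <= -threshold
    mask = np.zeros(len(oni), dtype=bool)
    start = None
    # Walk the boolean series; a trailing False closes any run still open at the end
    for i, flag in enumerate(np.append(beyond, False)):
        if flag and start is None:
            start = i
        elif not flag and start is not None:
            if i - start >= min_run:
                mask[start:i] = True
            start = None
    return mask

oni = [0.2, 0.6, 0.7, 0.9, 1.1, 0.8, 0.3, 0.6, 0.7, 0.1]
# Months 1-5 form a 5-value warm run (an episode); months 7-8 are too short to count
print(episode_mask(oni).astype(int))
```

Under this rule a single month above +0.5°C is not enough; the monthly labels used in this notebook are therefore somewhat noisier than official episode classifications.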

### Our Task
Given the **last 12 months of ONI values**, predict the **ENSO class** for each of the **next 3 months**. This is a **sequence -> classification** problem.

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import xarray as xr
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

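# Informational only: Lightning's Trainer selects the accelerator itself (accelerator="auto" by default)
device = torch.device(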
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

## 2. Load ERSSTv6 Data

We load the combined dataset that was prepared in Part 2A.

In [None]:
ds = xr.open_dataset("../data/processed/ersstv6_combined.nc")
print(f"Time: {str(ds.time.values[0])[:7]} to {str(ds.time.values[-1])[:7]} ({len(ds.time)} months)")
print(f"Grid: {len(ds.lat)} lat × {len(ds.lon)} lon")
print(f"Variables: {list(ds.data_vars)}")

## 3. Compute the Oceanic Niño Index (ONI)

The ONI is computed in two steps:
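1. **Spatial average**: Area-weighted mean of the SST anomaly (SSTA) over the Niño 3.4 box (5°S–5°N, 170°W–120°W), weighting each grid row by cos(latitude) because grid cells shrink toward the poles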
2. **Temporal smoothing**: 3-month centered running mean
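Both steps can be checked by hand on a toy array. This NumPy sketch (values invented for illustration) reproduces what `weighted(...).mean()` and `rolling(time=3, center=True).mean()` compute, without xarray.

```python
import numpy as np

# Toy box: 2 latitudes x 2 longitudes of SST anomalies (°C).
# The latitudes are far apart so the cos(lat) weights differ visibly.
lat = np.array([0.0, 60.0])
ssta = np.array([[3.0, 3.0],    # row at 0°:  weight cos(0°)  = 1.0
                 [0.0, 0.0]])   # row at 60°: weight cos(60°) = 0.5

# Broadcast the per-latitude weight across longitude, then take a weighted mean
w = np.cos(np.deg2rad(lat))[:, None] * np.ones_like(ssta)
area_mean = (w * ssta).sum() / w.sum()   # ≈ (3*1 + 3*1) / (1 + 1 + 0.5 + 0.5) = 2.0
print(area_mean)

# 3-month centered running mean: value t averages months t-1, t, t+1.
# mode="valid" drops the first and last month, like rolling(center=True) + dropna.
monthly = np.array([0.0, 1.0, 2.0, 3.0, 6.0])
oni_like = np.convolve(monthly, np.ones(3) / 3, mode="valid")
print(oni_like)   # centered on months 1, 2, 3
```

An unweighted mean of the same box would give 1.5; the cosine weights pull the result toward the equatorial row, which covers more area.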

In [None]:
# Niño 3.4 region: 5°S-5°N, 170°W-120°W
# In 0-360° longitude: 170°W = 190°, 120°W = 240°
# (slice() assumes lat and lon are stored in ascending order; otherwise the selection is empty)
nino34 = ds.ssta.sel(lat=slice(-5, 5), lon=slice(190, 240))

# Step 1: Area-weighted spatial average
weights = np.cos(np.deg2rad(nino34.lat))
nino34_index = nino34.weighted(weights).mean(dim=['lat', 'lon'])

# Step 2: 3-month centered running mean → ONI
oni = nino34_index.rolling(time=3, center=True).mean().dropna('time')

# Convert to pandas
oni_series = oni.to_series()
oni_series.name = 'ONI'

print(f"ONI time series: {len(oni_series)} months")
print(f"Range: {oni_series.index[0].strftime('%Y-%m')} to {oni_series.index[-1].strftime('%Y-%m')}")
print(f"Mean: {oni_series.mean():.3f}, Std: {oni_series.std():.3f}")

In [None]:
fig, ax = plt.subplots(figsize=(15, 4))

ax.fill_between(oni_series.index, oni_series.values, 0.5,
                where=oni_series.values >= 0.5, color='red', alpha=0.4, label='El Niño')
ax.fill_between(oni_series.index, oni_series.values, -0.5,
                where=oni_series.values <= -0.5, color='blue', alpha=0.4, label='La Niña')
ax.plot(oni_series.index, oni_series.values, color='black', linewidth=0.8)
ax.axhline(0.5, color='red', linestyle='--', alpha=0.5)
ax.axhline(-0.5, color='blue', linestyle='--', alpha=0.5)
ax.axhline(0, color='gray', linewidth=0.5)

ax.set_title('Oceanic Niño Index (ONI) — Computed from ERSSTv6')
ax.set_ylabel('ONI (°C)')
ax.set_xlabel('Year')
ax.legend(loc='upper left')
plt.tight_layout()
plt.show()

## 4. Label Each Month by ENSO Phase

Using the standard ONI thresholds:
- **0 = Neutral**: −0.5 < ONI < 0.5
- **1 = El Niño**: ONI ≥ 0.5
- **2 = La Niña**: ONI ≤ −0.5
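The same mapping can be written without a Python-level loop. This is a minimal sketch using `np.select`; the name `label_phase` is hypothetical, and it mirrors `oni_to_label`, including the convention that exactly ±0.5 counts as El Niño / La Niña.

```python
import numpy as np

def label_phase(oni):
    """Hypothetical vectorized twin of oni_to_label: 0 = Neutral, 1 = El Niño, 2 = La Niña."""
    oni = np.asarray(oni)
    # np.select checks the conditions in order; values matching neither get the default
    return np.select([oni >= 0.5, oni <= -0.5], [1, 2], default=0)

print(label_phase([0.5, -0.5, 0.49, -0.49, 1.2, -2.0]))
```

On a full ONI series, `label_phase(oni_series.values)` should match `oni_series.apply(oni_to_label).values`.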

In [None]:
# Label each month
def oni_to_label(val):
    if val >= 0.5:
        return 1  # El Niño
    elif val <= -0.5:
        return 2  # La Niña
    else:
        return 0  # Neutral

labels = oni_series.apply(oni_to_label).values

phase_names = ['Neutral', 'El Niño', 'La Niña']
for i, name in enumerate(phase_names):
    count = (labels == i).sum()
    print(f"  {name}: {count} months ({100*count/len(labels):.1f}%)")

## 5. Prepare Sequences & DataLoaders

We create a sliding window dataset:
- **Input**: 12 consecutive months of ONI values
- **Target**: ENSO class at lead 1, 2, and 3 months ahead

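As in Part 1, we split **by time** with a chronological hold-out rather than a random split (a random split would leak information between neighboring, highly autocorrelated months): first 80% for training, last 20% for testing.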
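The window arithmetic is easy to get off by one: with `T` months there are `T - SEQ_LEN - N_LEADS + 1` complete windows. The NumPy sketch below uses a toy series of month indices (not real ONI values) and `sliding_window_view` to show that count, plus a chronological split in which each side builds its own windows, so no month lands in both sets. Splitting on window indices instead would let the first test windows reuse months from the last training windows.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

SEQ_LEN, N_LEADS = 12, 3
T = 100                                  # toy series length (months)
series = np.arange(T, dtype=np.float32)  # stand-in for ONI: value = month index

# Each window spans SEQ_LEN input months followed by N_LEADS target months
windows = sliding_window_view(series, SEQ_LEN + N_LEADS)
X, Y = windows[:, :SEQ_LEN], windows[:, SEQ_LEN:]
print(X.shape, Y.shape)                  # T - SEQ_LEN - N_LEADS + 1 = 86 windows

# Chronological split on the time axis, then window each side separately
split = int(0.8 * T)                     # month 80
train_w = sliding_window_view(series[:split], SEQ_LEN + N_LEADS)
test_w = sliding_window_view(series[split:], SEQ_LEN + N_LEADS)
print(train_w.shape[0], test_w.shape[0], train_w.max(), test_w.min())
```

The price of a clean split is that the first `SEQ_LEN + N_LEADS - 1` months after the split point only serve as inputs or targets for test windows, never for training.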

In [None]:
SEQ_LEN = 12  # 12 months of ONI history
N_LEADS = 3   # predict 1, 2, 3 months ahead

class ONIDataset(Dataset):
    def __init__(self, oni_values, labels, seq_len=12, n_leads=3):
        self.seq_len = seq_len
        self.n_leads = n_leads
        self.samples = []
        self.targets = []
        
        # +1 so the last window's final target is the last month in the series
        for i in range(len(oni_values) - seq_len - n_leads + 1):
            x = oni_values[i : i + seq_len]
            y = [labels[i + seq_len + lead] for lead in range(n_leads)]
            self.samples.append(x)
            self.targets.append(y)
        
        self.samples = torch.tensor(np.array(self.samples), dtype=torch.float32)
        self.targets = torch.tensor(np.array(self.targets), dtype=torch.long)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.targets[idx]

# Prepare data
oni_values = oni_series.values.astype(np.float32)

# 80/20 chronological split on the time axis. Each window (12 input months +
# 3 target months) is built entirely on one side of the split, so no month's
# ONI value or label is shared between training and test samples.
split_idx = int(0.8 * len(oni_values))
train_dataset = ONIDataset(oni_values[:split_idx], labels[:split_idx], SEQ_LEN, N_LEADS)
test_dataset = ONIDataset(oni_values[split_idx:], labels[split_idx:], SEQ_LEN, N_LEADS)

BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples:  {len(test_dataset)}")
print(f"\nSample input shape: {train_dataset[0][0].shape}  (12 months of ONI)")
print(f"Sample target: {train_dataset[0][1]}  (ENSO class at leads 1,2,3)")

## 6. Models

We compare three architectures on the 1D ONI sequence:

1. **Linear**: Simplest baseline — weighted sum of 12 ONI values → 3×3 outputs
2. **MLP**: Feedforward network that can learn non-linear patterns
3. **LSTM**: Recurrent network that explicitly models temporal order
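The three models differ a lot in size. Counting parameters from the layer shapes defined in the next cell (and PyTorch's convention that `nn.LSTM` stores input weights, recurrent weights and two bias vectors for each of its 4 gates) gives a sense of the spread; it is worth comparing these numbers with the number of training samples printed above.

```python
# Parameter counts implied by the three architectures (12 inputs, 3 classes, 3 leads)
SEQ_LEN, N_CLASSES, N_LEADS, HIDDEN = 12, 3, 3, 64

def linear_params(n_in, n_out):
    return n_in * n_out + n_out            # weight matrix + bias

def heads(n_in):
    return N_LEADS * linear_params(n_in, N_CLASSES)   # one classifier head per lead

linear = heads(SEQ_LEN)
mlp = linear_params(SEQ_LEN, HIDDEN) + linear_params(HIDDEN, 32) + heads(32)
# nn.LSTM(input_size=1): 4 gates x (input weights + recurrent weights + 2 biases)
lstm = 4 * (HIDDEN * 1 + HIDDEN * HIDDEN + 2 * HIDDEN) + heads(HIDDEN)

print(linear, mlp, lstm)
```

In the notebook, `sum(p.numel() for p in model.parameters())` on each instantiated model should reproduce these counts.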

In [None]:
N_CLASSES = 3

class LinearModel(nn.Module):
    def __init__(self, seq_len, n_classes=3, n_leads=3):
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(seq_len, n_classes) for _ in range(n_leads)])
    
    def forward(self, x):
        return tuple(head(x) for head in self.heads)


class MLPModel(nn.Module):
    def __init__(self, seq_len, hidden=64, n_classes=3, n_leads=3):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(seq_len, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 32),
            nn.ReLU(),
        )
        self.heads = nn.ModuleList([nn.Linear(32, n_classes) for _ in range(n_leads)])
    
    def forward(self, x):
        features = self.backbone(x)
        return tuple(head(features) for head in self.heads)


class LSTMModel(nn.Module):
    def __init__(self, hidden=64, n_classes=3, n_leads=3):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden, batch_first=True)
        self.heads = nn.ModuleList([nn.Linear(hidden, n_classes) for _ in range(n_leads)])
    
    def forward(self, x):
        x = x.unsqueeze(-1)
        out, _ = self.lstm(x)
        last = out[:, -1, :]
        return tuple(head(last) for head in self.heads)

## 7. Training with PyTorch Lightning

The total loss is the sum of Cross-Entropy losses across all three lead times. We use Early Stopping to prevent overfitting.
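Because the three per-lead losses are simply added, an uninformative model gives a useful reference value: with equal logits each head contributes ln 3 ≈ 1.10, so the summed loss starts near 3 ln 3 ≈ 3.30. This NumPy sketch re-implements the batch-mean cross-entropy that `nn.CrossEntropyLoss` computes by default (the helper name is hypothetical) and checks that value. A loss that starts far above it, or a validation loss that never drops below it, usually points to a bug.

```python
import numpy as np

def cross_entropy(logits, target):
    """Hypothetical helper: mean cross-entropy over a batch, from raw logits."""
    logits = logits - logits.max(axis=1, keepdims=True)        # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(target)), target].mean()

rng = np.random.default_rng(0)
batch, n_classes, n_leads = 32, 3, 3
targets = rng.integers(0, n_classes, size=(batch, n_leads))

# Uninformative model: all-zero logits for every lead
outputs = [np.zeros((batch, n_classes)) for _ in range(n_leads)]
total = sum(cross_entropy(outputs[i], targets[:, i]) for i in range(n_leads))
print(total, n_leads * np.log(n_classes))   # both ≈ 3.296
```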

In [None]:
class ENSOPredictor(pl.LightningModule):
    def __init__(self, model, learning_rate=0.001):
        super().__init__()
        self.model = model
        self.lr = learning_rate
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def _compute_loss(self, batch):
        x, y = batch
        outputs = self(x)
        loss = sum(self.criterion(outputs[i], y[:, i]) for i in range(len(outputs)))
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

In [None]:
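# Train all models
from torch.utils.data import Subset

pl.seed_everything(42)  # reproducible weight init and shuffling

# Early stopping needs data the optimizer never sees. Hold out the last ~10% of the
# training windows (chronologically) for validation, and drop the windows in between
# that would share months with them. The test set stays untouched.
n_val = max(1, int(0.1 * len(train_dataset)))
gap = SEQ_LEN + N_LEADS - 1
n_fit = len(train_dataset) - n_val - gap
fit_loader = DataLoader(Subset(train_dataset, range(n_fit)), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(Subset(train_dataset, range(n_fit + gap, len(train_dataset))), batch_size=BATCH_SIZE)

model_configs = {
    'Linear': LinearModel(SEQ_LEN),
    'MLP': MLPModel(SEQ_LEN),
    'LSTM': LSTMModel(),
}

trained_models = {}
results = {}

for name, model_arch in model_configs.items():
    print(f"\n{'='*50}")
    print(f"Training {name}")
    print(f"{'='*50}")

    pl_model = ENSOPredictor(model_arch, learning_rate=0.001)

    # Stop once the *validation* loss has not improved for 10 epochs
    early_stop = EarlyStopping(monitor='val_loss', patience=10, mode='min', verbose=False)

    trainer = pl.Trainer(
        max_epochs=100,
        callbacks=[early_stop],
        enable_progress_bar=True,
        logger=False,
        enable_checkpointing=False
    )

    trainer.fit(pl_model, fit_loader, val_loader)
    model_arch.cpu()  # ensure CPU: Lightning may have trained on GPU/MPS, test batches are on CPU
    trained_models[name] = model_arch

    # Evaluate on test set: accuracy per lead time
    model_arch.eval()
    correct = [0] * N_LEADS
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            outputs = model_arch(data)
            total += data.size(0)
            for i in range(N_LEADS):
                predicted = outputs[i].argmax(dim=1)
                correct[i] += (predicted == target[:, i]).sum().item()

    accs = [100 * c / total for c in correct]
    results[name] = accs
    print(f"  Lead 1: {accs[0]:.1f}%  |  Lead 2: {accs[1]:.1f}%  |  Lead 3: {accs[2]:.1f}%")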

## 8. Results Comparison

In [None]:
# Bar chart comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharey=True)
colors = ['gray', 'blue', 'green']
lead_names = ['1-month lead', '2-month lead', '3-month lead']

for i, ax in enumerate(axes):
    model_names = list(results.keys())
    accs = [results[name][i] for name in model_names]
    bars = ax.bar(model_names, accs, color=colors)
    ax.axhline(33.3, color='red', linestyle='--', alpha=0.5, label='Random (33%)')
    ax.set_title(lead_names[i])
    ax.set_ylabel('Accuracy (%)' if i == 0 else '')
    ax.set_ylim(0, 100)
    ax.legend()
    for bar, acc in zip(bars, accs):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{acc:.1f}%', ha='center', va='bottom', fontsize=10)

plt.suptitle('ENSO Phase Prediction: Model Comparison', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Confusion matrices for the best model
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Note: choosing the "best" model by test accuracy makes its test scores slightly optimistic
best_name = max(results, key=lambda k: sum(results[k]))
best_model = trained_models[best_name]
print(f"Best model: {best_name}")

best_model.eval()
all_preds = [[], [], []]
all_targets = [[], [], []]

with torch.no_grad():
    for data, target in test_loader:
        outputs = best_model(data)
        for i in range(3):
            _, predicted = torch.max(outputs[i], 1)
            all_preds[i].extend(predicted.numpy())
            all_targets[i].extend(target[:, i].numpy())

fig, axes = plt.subplots(1, 3, figsize=(16, 4))
for i, ax in enumerate(axes):
    cm = confusion_matrix(all_targets[i], all_preds[i], labels=[0, 1, 2])
    disp = ConfusionMatrixDisplay(cm, display_labels=phase_names)
    disp.plot(ax=ax, cmap='Blues', colorbar=False)
    ax.set_title(f"Lead {i+1} ({lead_names[i]})")

plt.suptitle(f'Confusion Matrices: {best_name}', fontsize=14)
plt.tight_layout()
plt.show()

## 9. Live Forecast

Let's use the last 12 months of our ONI time series to predict ENSO for the upcoming months.
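One timing detail matters when reading this forecast. Because the ONI is a centered 3-month mean, its last value is centered on the second-to-last month of SST data, so the "1-month lead" target is the final month of the dataset, whose SST anomaly is already observed; months beyond the data only start at lead 2. The toy NumPy example below (made-up anomaly values for a hypothetical Jan to Jun record) shows the shift.

```python
import numpy as np

# Hypothetical SST-anomaly record covering Jan-Jun 2024 (values are made up)
sst_months = np.arange("2024-01", "2024-07", dtype="datetime64[M]")
ssta = np.array([1.8, 1.6, 1.2, 0.9, 0.5, 0.3])

# A centered 3-month mean needs one month on each side,
# so it exists only for the interior months (Feb through May)
oni = np.convolve(ssta, np.ones(3) / 3, mode="valid")
oni_months = sst_months[1:-1]

print("last SST month:         ", sst_months[-1])
print("last ONI month:         ", oni_months[-1])
print("'lead-1' forecast month:", oni_months[-1] + 1)
```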

In [None]:
# Use the last 12 months of ONI as input
last_12 = oni_values[-SEQ_LEN:]
last_date = oni_series.index[-1]

print(f"Input: ONI from {oni_series.index[-SEQ_LEN].strftime('%Y-%m')} to {last_date.strftime('%Y-%m')}")
print(f"Last 12 ONI values: {np.round(last_12, 2)}")

# Predict
best_model.eval()
with torch.no_grad():
    inp = torch.tensor(last_12, dtype=torch.float32).unsqueeze(0)
    outputs = best_model(inp)

# Display forecast
print(f"\n{'='*50}")
print(f"ENSO Forecast (from {best_name} model)")
print(f"{'='*50}")

for i in range(N_LEADS):
    probs = torch.softmax(outputs[i], dim=1)[0]
    pred_class = outputs[i].argmax().item()
    forecast_month = last_date + pd.DateOffset(months=i+1)
    print(f"\n  {forecast_month.strftime('%B %Y')}:")
    for j, name in enumerate(phase_names):
        marker = " ◀" if j == pred_class else ""
        print(f"    {name}: {probs[j]:.1%}{marker}")

## 10. Discussion

### Key Takeaways
- **ONI as a feature**: The ONI time series captures the essential ENSO signal — a single number per month computed from gridded SST data
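- **Sequence → Classification**: Framing ENSO prediction as classification sidesteps predicting the exact ONI value; the trade-off is that magnitude is discarded (an ONI of 0.49 and one of −0.49 are both labeled Neutral)
- **Lead time decay**: Accuracy typically decreases as lead time grows, reflecting the limited predictability of ENSO; skill also depends on the season, with forecasts made through boreal spring being notably less reliable (the "spring predictability barrier")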
- **Real-world impact**: ENSO forecasts inform agriculture, disaster preparedness, and water resource management worldwide
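The bar chart compares against uniform random guessing (33%), a weak bar because ENSO phases persist for months. A standard, tougher reference is **persistence**: predict that the phase of the last input month simply continues. A minimal NumPy sketch, assuming the window layout from Section 5 (the function name is hypothetical):

```python
import numpy as np

def persistence_accuracy(labels, seq_len, n_leads):
    """Hypothetical baseline: accuracy of repeating the last input month's phase, per lead."""
    labels = np.asarray(labels)
    n_windows = len(labels) - seq_len - n_leads + 1
    # Label of the last input month for every window i: index i + seq_len - 1
    last = labels[seq_len - 1 : seq_len - 1 + n_windows]
    accs = []
    for lead in range(1, n_leads + 1):
        target = labels[seq_len - 1 + lead : seq_len - 1 + lead + n_windows]
        accs.append(float((last == target).mean()))
    return accs

# Toy check: alternating phases are always wrong at lead 1 and always right at lead 2
print(persistence_accuracy([0, 1] * 5, seq_len=2, n_leads=2))
```

Applied to the test-period slice of `labels` with `SEQ_LEN` and `N_LEADS`, it gives per-lead accuracies directly comparable to the bars in Section 8; a learned model is only adding value where it beats them.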

### Next Steps
- Use the **full 2D SST maps** instead of just ONI (needs a 2D CNN — more powerful but more complex)
- Try **multi-month input sequences of SST maps** (3D: time × lat × lon)
- Use **WeatherBench2** for a more sophisticated global weather prediction approach