In [None]:
# Install torch_geometric
!pip install torch_geometric


In [None]:
# Import libraries
import torch
import xarray as xr
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import pickle
from torch_geometric.data import Data

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
import logging
from tqdm.notebook import tqdm  # Use notebook-specific tqdm for single-line updates
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is a notebook-wide global used by the model and the train/eval loops.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mount Google Drive
# (Colab-only step: the dataset files and all saved outputs below are addressed
# through paths under /content/drive.)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load the two NetCDF files (named "accum" and "instant") from Google Drive.
nc_file1 = "/content/drive/MyDrive/ERA5_datasets/data_accum.nc"
nc_file2 = "/content/drive/MyDrive/ERA5_datasets/data_instant.nc"
ds1 = xr.open_dataset(nc_file1)
ds2 = xr.open_dataset(nc_file2)

# Combine both files into one dataset. Per the xarray docs, compat='override'
# skips the equality check on conflicting variables and keeps the one from the
# first dataset (ds1).
ds = xr.merge([ds1, ds2], compat='override')
# Print dataset details
print(ds)

# Select the variables that become input channels (order defines channel order).
variables = ['tp', 't2m', 'd2m', 'u10', 'v10', 'msl', 'tcwv', 'cape', 'slhf', 'sshf']
data_arrays = [ds[var].values for var in variables]

# Stack variables along a new trailing "channel" axis.
# NOTE(review): the [time, lat, lon] layout of each variable is assumed here and
# by the dataset class below - confirm against the printed dataset.
data = np.stack(data_arrays, axis=-1)  # one channel per entry of `variables`

# Chronological 80% / 10% / 10% split boundaries (train / validation / test),
# expressed as positions along the `valid_time` axis.
hours = ds.valid_time.values
total_hours = len(hours)
train_end = int(0.8 * total_hours)
val_end = train_end + int(0.1 * total_hours)
# Fit the per-channel scaler on the training time steps only, so validation and
# test statistics do not leak into the normalisation.
scaler = StandardScaler()
train_data_reshaped = data[:train_end].reshape(-1, data.shape[-1])
scaler.fit(train_data_reshaped)
# Normalise the full array and rebind `data` to the result.
# NOTE: this is not an in-place operation - `transform` returns a new array, so
# the raw and the normalised arrays briefly coexist in memory.
data = scaler.transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)

# Keep the raw (un-normalised) `tp` field; labels are derived from it by
# thresholding, so it must stay in its original units.
tp = ds['tp'].values

# Define dataset class
class ERA5Dataset(Dataset):
    """Sliding-window dataset yielding (input sequence, per-cell rain label) pairs.

    For a start position ``t`` taken from ``indices``:

    * the input is ``data[t : t + input_window]``;
    * the label is 1 for every grid cell whose ``tp`` summed over the following
      ``target_window`` steps exceeds ``threshold``, else 0.

    Args:
        data: Array indexed by time on axis 0; axes 1 and 2 are recorded as the
            grid size. (Presumably normalised features with a trailing channel
            axis - the class itself does not depend on that.)
        tp: Array indexed by time on axis 0 used to build the labels; it is
            summed over the time axis and compared with ``threshold``.
        indices: Sequence of window start positions that make up this split.
        input_window: Number of time steps in each input sequence.
        target_window: Number of time steps summed to form the label.
        threshold: Value the summed ``tp`` must exceed for a "rain" label.
    """

    def __init__(self, data, tp, indices, input_window=12, target_window=6, threshold=0.001):
        self.data = data
        self.tp = tp
        self.indices = indices
        self.input_window = input_window
        self.target_window = target_window
        self.threshold = threshold
        self.grid_size = (data.shape[1], data.shape[2])

    def __len__(self):
        """Return the number of windows in this split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(inputs, labels)`` for the ``idx``-th window of the split.

        Returns:
            A float32 tensor holding the input window and an int64 tensor of
            0/1 labels with the spatial shape of ``tp``.

        Raises:
            ValueError: If the window would run past the end of ``data``/``tp``.
                Slicing would otherwise silently return a shorter sequence and
                fail later, far from the cause, when batches are collated.
        """
        t = self.indices[idx]
        input_end = t + self.input_window
        target_end = input_end + self.target_window

        input_data = self.data[t:input_end]
        future_tp = self.tp[input_end:target_end]
        if len(input_data) != self.input_window or len(future_tp) != self.target_window:
            raise ValueError(
                f"Window starting at time index {t} (dataset index {idx}) needs "
                f"{self.input_window + self.target_window} steps but runs past the end of the data"
            )

        # A boolean comparison cast to int64 can only contain 0 and 1, so no
        # further validation or clipping of the labels is required.
        rain_label = (future_tp.sum(axis=0) > self.threshold).astype(np.int64)
        return (torch.tensor(input_data, dtype=torch.float32),
                torch.tensor(rain_label, dtype=torch.long))

# Create indices for each split.
# Every sample starting at t consumes `input_window` input steps followed by
# `target_window` label steps, so the last valid start in a split is
# split_end - (input_window + target_window). This keeps each sample's inputs
# and labels inside its own split.
input_window = 12
target_window = 6
window_span = input_window + target_window
train_indices = list(range(0, train_end - window_span + 1))
val_indices = list(range(train_end, val_end - window_span + 1))
test_indices = list(range(val_end, total_hours - window_span + 1))

# Initialize datasets
train_dataset = ERA5Dataset(data, tp, train_indices, input_window, target_window)
val_dataset = ERA5Dataset(data, tp, val_indices, input_window, target_window)
test_dataset = ERA5Dataset(data, tp, test_indices, input_window, target_window)

# Create data loaders (only the training loader is shuffled)
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Compute rain percentage (sample-based to save memory).
# The sample is spread evenly over the whole training period: consecutive start
# hours share most of their label window, so taking only the first 100 indices
# would describe a few days of weather rather than the training set.
rain_sample_size = 100
sample_stride = max(1, len(train_indices) // rain_sample_size)
sampled_train_indices = train_indices[::sample_stride][:rain_sample_size]

rain_count = 0
total_nodes = 0
for idx in sampled_train_indices:
    future_tp = tp[idx+input_window:idx+input_window+target_window]
    # Use the dataset's own threshold so this estimate matches the labels it produces.
    rain_label = (future_tp.sum(axis=0) > train_dataset.threshold).astype(int)
    rain_count += rain_label.sum()
    total_nodes += rain_label.size
rain_percentage = (rain_count / total_nodes) * 100 if total_nodes > 0 else 0
print(f"Approximate percentage of rain in training dataset: {rain_percentage:.2f}%")

In [None]:
# Updated CNNRainPredictor with third Conv3d layer
class CNNRainPredictor(nn.Module):
    """3D-convolutional classifier producing per-grid-cell class logits.

    The input sequence is processed by three ``Conv3d`` layers (with one
    2x2x2 max-pool after the first), averaged over the time axis, projected to
    ``output_dim`` channels by a 1x1 ``Conv2d`` and bilinearly upsampled back
    to ``grid_size``.

    Args:
        input_channels: Number of channels in the last axis of the input.
        hidden_dim: Number of feature maps in every 3D convolution.
        output_dim: Number of output classes (logit channels).
        grid_size: ``(height, width)`` the output is upsampled to.
    """

    def __init__(self, input_channels, hidden_dim, output_dim, grid_size):
        super().__init__()
        self.grid_size = grid_size

        conv_kwargs = dict(kernel_size=(3, 3, 3), padding=1)
        self.conv1 = nn.Conv3d(input_channels, hidden_dim, **conv_kwargs)
        self.conv2 = nn.Conv3d(hidden_dim, hidden_dim, **conv_kwargs)
        self.conv3 = nn.Conv3d(hidden_dim, hidden_dim, **conv_kwargs)
        # 1x1 convolution acting as a per-cell linear classifier.
        self.fc = nn.Conv2d(hidden_dim, output_dim, kernel_size=1)

        self.dropout = nn.Dropout(0.4)
        # Halves the time, height and width axes.
        self.pool = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        # Restores the spatial resolution lost to pooling.
        self.upsample = nn.Upsample(size=grid_size, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Map ``(batch, T, H, W, C)`` inputs to ``(batch, output_dim, *grid_size)`` logits."""
        # Conv3d expects channels first: (batch, C, T, H, W).
        features = x.permute(0, 4, 1, 2, 3)

        features = self.dropout(F.relu(self.conv1(features)))
        features = self.pool(features)
        for conv in (self.conv2, self.conv3):
            features = self.dropout(F.relu(conv(features)))

        # Collapse the (pooled) time axis before the 2D classification head.
        time_averaged = features.mean(dim=2)
        logits = self.fc(time_averaged)
        return self.upsample(logits)

# Model hyper-parameters.
# The channel count is read from the stacked input array (its last axis has one
# entry per selected variable) so it cannot drift from the variable list.
input_channels = data.shape[-1]
hidden_dim = 128
output_dim = 2  # two classes: no rain / rain
grid_size = (data.shape[1], data.shape[2])
model = CNNRainPredictor(input_channels, hidden_dim, output_dim, grid_size).to(device)


In [None]:

# Define loss functions and training utilities
def soft_f1_loss(y_true, y_pred, epsilon=1e-6):
    """Differentiable ``1 - F1`` computed from class-1 softmax probabilities.

    Args:
        y_true: Tensor of 0/1 labels, shaped like one channel of ``y_pred``.
        y_pred: Logits with the class axis at dim 1 and two trailing spatial axes.
        epsilon: Small constant keeping the divisions finite.

    Returns:
        Scalar tensor; approaches 0 as the soft F1 over all elements approaches 1.
    """
    target = y_true.float()
    # Probability assigned to class index 1 at every location.
    positive_prob = torch.softmax(y_pred, dim=1)[:, 1, :, :]

    # Soft confusion-matrix counts accumulated over the whole batch.
    true_pos = (target * positive_prob).sum()
    false_pos = ((1 - target) * positive_prob).sum()
    false_neg = (target * (1 - positive_prob)).sum()

    precision = true_pos / (true_pos + false_pos + epsilon)
    recall = true_pos / (true_pos + false_neg + epsilon)
    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
    return 1 - f1

def focal_loss(y_pred, y_true, gamma=2.0, alpha=None):
    """Multi-class focal loss: ``alpha_c * (1 - p_t)**gamma * (-log p_t)``, averaged.

    Args:
        y_pred: Logits with the class axis at dim 1.
        y_true: Integer class indices (long), shaped like ``y_pred`` without
            its class axis.
        gamma: Focusing exponent; larger values down-weight well-classified
            elements more strongly.
        alpha: Optional 1-D tensor of per-class weights indexed by class id.

    Returns:
        Scalar tensor: the mean focal loss over all elements.
    """
    # The *unweighted* cross-entropy is exactly -log(p_t), so p_t can be
    # recovered from it. Passing the class weights to cross_entropy would make
    # exp(-ce) equal p_t ** weight and corrupt the focusing term.
    ce_loss = F.cross_entropy(y_pred, y_true, reduction='none')
    p_t = torch.exp(-ce_loss)
    per_element = (1 - p_t) ** gamma * ce_loss

    if alpha is not None:
        # Apply the class weight of each element's true class after focusing.
        class_weight = alpha.to(device=per_element.device, dtype=per_element.dtype)
        per_element = class_weight[y_true] * per_element
    return per_element.mean()

# Unweighted cross-entropy; used as one term of the training loss and as the
# loss reported by evaluate().
criterion_ce = nn.CrossEntropyLoss()
# Send INFO-level records to training_log.log with a timestamp prefix.
# NOTE(review): basicConfig does nothing if the root logger already has
# handlers (some hosted notebook runtimes pre-configure it) - verify that the
# log file is actually written.
logging.basicConfig(filename='training_log.log', level=logging.INFO, format='%(asctime)s - %(message)s')
logger = logging.getLogger()
# Per-epoch histories appended to by the training loop and read by the plots.
train_losses = []
train_accuracies = []
val_f1_scores = []
test_f1_scores = []

# Modified training cell with threshold grid search, fixed class weight, and optimal threshold for test
def train(model, train_loader, optimizer, device, class_weights, alpha=0.5, accum_steps=4):
    """Run one training epoch with gradient accumulation.

    The per-batch loss is
    ``alpha * cross_entropy + (1 - alpha) * soft_f1 + focal`` and gradients are
    accumulated over ``accum_steps`` batches before each (norm-clipped)
    optimizer step.

    Args:
        model: Network mapping an input batch to class logits.
        train_loader: Iterable with a length, yielding ``(inputs, target)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches and class weights are moved to.
        class_weights: 1-D tensor of per-class weights for the focal term.
        alpha: Mixing factor between the cross-entropy and soft-F1 terms.
        accum_steps: Number of batches whose gradients are accumulated per step.

    Returns:
        ``(avg_loss, accuracy)``: mean un-scaled batch loss and the fraction of
        correctly classified elements (argmax over the class axis).

    Raises:
        ValueError: If ``accum_steps`` is smaller than 1, or the model's output
            spatial shape differs from the target's.
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")

    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    num_batches = len(train_loader)
    # When the epoch length is not a multiple of accum_steps the final group is
    # shorter; its gradients are averaged over its real size so those batches
    # are not under-weighted relative to the rest of the epoch.
    trailing_batches = num_batches % accum_steps
    first_trailing_batch = num_batches - trailing_batches

    # Loop-invariant: move the class weights to the device once.
    class_weights = class_weights.to(device)

    optimizer.zero_grad()
    # Tqdm progress bar for training batches
    for i, (inputs, target) in enumerate(tqdm(train_loader, desc="Training Batches", leave=False)):
        inputs, target = inputs.to(device), target.to(device)
        out = model(inputs)
        if out.shape[2:] != target.shape[1:]:
            raise ValueError(f"Output spatial shape {out.shape[2:]} doesn't match target {target.shape[1:]}")

        ce_loss = criterion_ce(out, target)
        f1_loss = soft_f1_loss(target, out)
        focal = focal_loss(out, target, alpha=class_weights)
        loss = alpha * ce_loss + (1 - alpha) * f1_loss + focal

        group_size = accum_steps if i < first_trailing_batch else trailing_batches
        (loss / group_size).backward()

        # Step at the end of every full group and after the very last batch.
        if (i + 1) % accum_steps == 0 or (i + 1) == num_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item()
        preds = out.argmax(dim=1)
        correct += (preds == target).sum().item()
        total += target.numel()

    # Guard the averages so an empty loader does not raise ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
    accuracy = correct / total if total > 0 else 0.0
    return avg_loss, accuracy

def evaluate(model, loader, device, threshold=0.77, dataset_name="Test"):
    """Evaluate ``model`` on ``loader`` and print/return element-level metrics.

    An element is predicted as class 1 when its softmax probability for class 1
    is at least ``threshold``.

    Args:
        model: Network mapping an input batch to class logits.
        loader: DataLoader yielding ``(inputs, target)`` batches.
        device: Device the batches are moved to.
        threshold: Probability cut-off for predicting class 1.
        dataset_name: Label used in the progress bar and printed summary.

    Returns:
        ``(f1, accuracy, precision, recall, avg_loss, all_probs, all_labels)``
        where ``all_probs`` / ``all_labels`` are NumPy arrays concatenated over
        all batches along axis 0.

    Raises:
        ValueError: If the loader yields no batches, or probabilities and
            labels end up with different shapes.
    """
    model.eval()
    prob_batches = []
    label_batches = []
    total_loss = 0.0
    total_samples = 0
    expected_samples = len(loader.dataset)

    with torch.no_grad():
        for inputs, target in tqdm(loader, desc=f"{dataset_name} Batches", leave=False):
            inputs, target = inputs.to(device), target.to(device)
            out = model(inputs)
            batch_len = inputs.size(0)
            # criterion_ce returns a batch mean; weight it by the batch size so
            # the final average is per sample even if the last batch is short.
            total_loss += criterion_ce(out, target).item() * batch_len

            # Extract probabilities for rain class
            probs = torch.softmax(out, dim=1)[:, 1, :, :]
            prob_batches.append(probs.cpu().numpy())
            label_batches.append(target.cpu().numpy())
            total_samples += batch_len

    # Fail with a clear message instead of np.concatenate's generic error.
    if not prob_batches:
        raise ValueError(f"{dataset_name} loader produced no batches; nothing to evaluate")

    # Verify total samples
    if total_samples != expected_samples:
        print(f"Warning: Processed {total_samples} samples, expected {expected_samples}")

    # Concatenate across batches
    all_probs = np.concatenate(prob_batches, axis=0)
    all_labels = np.concatenate(label_batches, axis=0)

    # Check shape consistency
    if all_probs.shape != all_labels.shape:
        raise ValueError(f"Shape mismatch: all_probs {all_probs.shape}, all_labels {all_labels.shape}")

    avg_loss = total_loss / total_samples

    # Thresholding is element-wise, so preds always has all_probs' shape.
    preds = (all_probs >= threshold).astype(np.int64)

    # Flatten once and reuse for every metric.
    flat_labels = all_labels.ravel()
    flat_preds = preds.ravel()

    # zero_division=0 returns the same 0.0 scikit-learn would otherwise produce
    # (with a warning) when no positives are predicted or present.
    f1 = f1_score(flat_labels, flat_preds, average='binary', zero_division=0)
    accuracy = accuracy_score(flat_labels, flat_preds)
    precision = precision_score(flat_labels, flat_preds, zero_division=0)
    recall = recall_score(flat_labels, flat_preds, zero_division=0)
    pred_rain = np.sum(preds)
    true_rain = np.sum(all_labels)

    print(f"{dataset_name} - Threshold: {threshold:.2f}, Loss: {avg_loss:.4f}, F1: {f1:.4f}, "
          f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")
    print(f"{dataset_name} - Predicted Rain Nodes: {pred_rain}, True Rain Nodes: {true_rain}")

    return f1, accuracy, precision, recall, avg_loss, all_probs, all_labels

In [None]:
torch.cuda.empty_cache()

# Training with fixed class weight and threshold
class_weight = [1.0, 8.0]
fixed_threshold = 0.77
num_epochs = 20
best_model_path = '/content/drive/MyDrive/best_CNN_model.pth'
# Start below any attainable F1 so the first epoch always writes a checkpoint.
# With 0.0 a run whose validation F1 never rises above zero would save nothing,
# and the final load below would fail or pick up a stale file from another run.
best_val_f1 = float('-inf')

# A fresh model is created below, so start fresh histories as well; otherwise
# re-running this cell would append to the curves of the previous run.
train_losses = []
train_accuracies = []
val_f1_scores = []
test_f1_scores = []

print(f"Starting Training with Class Weight: {class_weight} and Fixed Threshold: {fixed_threshold:.2f}")
model = CNNRainPredictor(input_channels, hidden_dim, output_dim, grid_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=3e-4)
# mode='max': the learning rate is halved when validation F1 stops increasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
class_weights = torch.tensor(class_weight, dtype=torch.float32).to(device)
patience = 6  # epochs without validation improvement before early stopping
patience_counter = 0

for epoch in tqdm(range(1, num_epochs + 1), desc="Epochs"):
    train_loss, train_acc = train(model, train_loader, optimizer, device, class_weights, alpha=0.5, accum_steps=4)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_f1, val_acc, val_prec, val_rec, val_loss, _, _ = evaluate(model, val_loader, device,
                                                                 threshold=fixed_threshold, dataset_name="Validation")
    val_f1_scores.append(val_f1)
    # Test metrics are tracked per epoch for the F1-vs-epoch plot only; model
    # selection and early stopping use the validation F1 exclusively.
    test_f1, test_acc, test_prec, test_rec, test_loss, _, _ = evaluate(model, test_loader, device,
                                                                      threshold=fixed_threshold, dataset_name="Test")
    test_f1_scores.append(test_f1)
    log_msg = (f'Class Weight: {class_weight}, Threshold: {fixed_threshold:.2f}, Epoch {epoch:02d}, '
               f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val F1: {val_f1:.4f}, Test F1: {test_f1:.4f}')
    logger.info(log_msg)
    print(log_msg)
    scheduler.step(val_f1)

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), best_model_path)
        patience_counter = 0
        print(f"New best model saved with Val F1: {val_f1:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

print(f"\nFinal Evaluation of Best Model with Fixed Threshold: {fixed_threshold:.2f}")
# map_location keeps the load working if the checkpoint was written from a
# different device than the one currently in use.
model.load_state_dict(torch.load(best_model_path, map_location=device))
test_f1, test_acc, test_prec, test_rec, test_loss, all_probs, all_labels = evaluate(
    model, test_loader, device, threshold=fixed_threshold, dataset_name="Test")

In [None]:
!pip install cartopy

In [None]:
from sklearn.metrics import confusion_matrix, roc_curve, precision_recall_curve, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import matplotlib.patches as mpatches

In [None]:
# Verify shapes before confusion matrix
print(f"Confusion Matrix Inputs - all_labels shape: {all_labels.shape}, all_probs shape: {all_probs.shape}")
if all_labels.shape != all_probs.shape:
    raise ValueError(f"Shape mismatch: all_labels {all_labels.shape}, all_probs {all_probs.shape}")

# Compute predictions with explicit thresholding
preds = (all_probs >= fixed_threshold).astype(np.int64)
print(f"Preds shape after thresholding: {preds.shape}")

# Flatten to one entry per grid cell and sample. preds is derived element-wise
# from all_probs, so the flattened lengths necessarily match.
flat_labels = all_labels.flatten()
flat_preds = preds.flatten()
print(f"Flattened shapes - labels: {flat_labels.shape}, preds: {flat_preds.shape}")

# Fix the label set so the matrix is always 2x2, even if one class is absent
# from both the labels and the predictions.
cm = confusion_matrix(flat_labels, flat_preds, labels=[0, 1])

# Row-normalise to percentages of each true class. A true class with no
# samples gets a row of zeros instead of NaNs from a division by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_percent = np.where(row_totals > 0, cm / np.maximum(row_totals, 1), 0.0) * 100

# Pre-format the cell annotations (value plus a percent sign).
annot_labels = np.array([[f"{value:.2f}%" for value in row] for row in cm_percent])

sns.set(font_scale=1.5)
plt.figure(figsize=(8, 6))
sns.heatmap(cm_percent, annot=annot_labels, fmt='', cmap='Blues', cbar=False,
            xticklabels=['No Rain', 'Rain'], yticklabels=['No Rain', 'Rain'],
            annot_kws={'size': 15})
plt.title('Confusion Matrix (CNN_Model)', fontsize=22)
plt.xlabel('Predicted', fontsize = 20)
plt.ylabel('True', fontsize = 20)
plt.savefig('/content/drive/MyDrive/CNN_confusion_matrix_row_normalized.png')
plt.show()

In [None]:
# Both curves score the same per-cell rain probabilities against the flattened
# labels, so flatten the probabilities a single time.
flat_probs = all_probs.flatten()

# ROC curve and its area
fpr, tpr, _ = roc_curve(flat_labels, flat_probs)
roc_auc = roc_auc_score(flat_labels, flat_probs)

fig_roc, ax_roc = plt.subplots(figsize=(8, 6))
ax_roc.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.2f})')
ax_roc.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal for reference
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('ROC Curve')
ax_roc.legend(loc='lower right')
fig_roc.savefig('/content/drive/MyDrive/CNN_roc_curve.png')
plt.show()

# Precision-Recall curve
precision, recall, _ = precision_recall_curve(flat_labels, flat_probs)

fig_pr, ax_pr = plt.subplots(figsize=(8, 6))
ax_pr.plot(recall, precision, label='Precision-Recall Curve')
ax_pr.set_xlabel('Recall', fontsize=20)
ax_pr.set_ylabel('Precision', fontsize=20)
ax_pr.set_title('Precision-Recall Curve (CNN Model)', fontsize=22)
ax_pr.legend(loc='lower left')
fig_pr.savefig('/content/drive/MyDrive/CNN_precision_recall_curve.png')
plt.show()

In [None]:
# F1 score per epoch for the validation and test splits (epochs numbered from 1)
fig_f1, ax_f1 = plt.subplots(figsize=(10, 6))
for f1_history, curve_label in ((val_f1_scores, 'Validation F1'),
                                (test_f1_scores, 'Test F1')):
    epoch_numbers = range(1, len(f1_history) + 1)
    ax_f1.plot(epoch_numbers, f1_history, label=curve_label)
ax_f1.set_xlabel('Epoch')
ax_f1.set_ylabel('F1 Score')
ax_f1.set_title('F1 Score vs. Epoch')
ax_f1.legend()
ax_f1.grid(True)
fig_f1.savefig('/content/drive/MyDrive/CNN_f1_score_vs_epoch.png')
plt.show()

In [None]:
# Spatial Map Comparison for Selected Test Samples
# Coordinate vectors used to turn grid indices into map positions.
lat = ds.latitude.values
lon = ds.longitude.values

# Positions within test_dataset to visualise; out-of-range entries are skipped.
test_sample_indices = [0, 50, 100]

# Inference only: switch dropout off once, before the loop.
model.eval()

for test_sample_idx in test_sample_indices:
    if test_sample_idx >= len(test_dataset):
        print(f"Skipping index {test_sample_idx}: Out of bounds (max {len(test_dataset)-1})")
        continue
    input_data, true_labels = test_dataset[test_sample_idx]  # Unpack input and labels
    input_data = input_data.unsqueeze(0).to(device)  # Add batch dimension
    with torch.no_grad():
        out = model(input_data)
        probs = torch.softmax(out, dim=1)[:, 1, :, :]  # rain probability per grid cell
        preds = (probs >= fixed_threshold).float()
    # Drop only the batch axis; a bare squeeze() would also remove a spatial
    # axis of length 1 and break the (row, col) indexing below.
    preds = preds[0].cpu().numpy()
    true = true_labels.cpu().numpy()

    # Get (row, col) indices where rain occurs
    true_rain_idx = np.where(true == 1)
    pred_rain_idx = np.where(preds == 1)

    # Map indices to coordinates: rows index latitude, columns index longitude
    true_lat = lat[true_rain_idx[0]]
    true_lon = lon[true_rain_idx[1]]
    pred_lat = lat[pred_rain_idx[0]]
    pred_lon = lon[pred_rain_idx[1]]

    # Report how many grid cells are flagged as rain in truth and prediction
    print(f"Sample {test_sample_idx}: True rain points: {len(true_lon)}")
    print(f"Sample {test_sample_idx}: Predicted rain points: {len(pred_lon)}")

    # Spatial Map
    fig = plt.figure(figsize=(12, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())
    # Fixed lon/lat window [west, east, south, north] in degrees.
    ax.set_extent([-125, -65, 24, 50], crs=ccrs.PlateCarree())
    ax.add_feature(cfeature.LAND, facecolor='lightgray')
    ax.add_feature(cfeature.OCEAN)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.STATES, linestyle='--')

    # Scatter plots: semi-transparent markers so overlapping truth/prediction stay visible
    ax.scatter(true_lon, true_lat, c='red', s=50, alpha=0.5, label='True Rain', transform=ccrs.PlateCarree())
    ax.scatter(pred_lon, pred_lat, c='blue', s=50, alpha=0.5, label='Predicted Rain', transform=ccrs.PlateCarree())

    plt.title(f'True vs. Predicted Rain (Test Sample {test_sample_idx}, CNN Model)')
    plt.legend()
    plt.savefig(f'/content/drive/MyDrive/CNN_spatial_prediction_{test_sample_idx}.png')
    plt.show()
    # Release the figure so figures do not pile up in memory across iterations.
    plt.close(fig)