In [1]:
# %%
# FourCastNet Sensitivity Analysis Script - FIXED VERSION
# This script performs sensitivity analysis to determine how RH at a target location
# depends on initial conditions across North America through multiple timesteps
import sys
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import TwoSlopeNorm

"""
List of GFS model variables:
----------
u10m   : Zonal Wind Component at 10m [m/s]
         East-west wind speed 10m above ground.
v10m   : Meridional Wind Component at 10m [m/s]
         North-south wind speed 10m above ground.
t2m    : 2m Air Temperature [K]
         Air temperature 2m above ground surface.
sp     : Surface Pressure [hPa]
         Atmospheric pressure at surface level.
msl    : Mean Sea Level Pressure [hPa]
         Pressure reduced to sea level.
t850   : Temperature at 850hPa [K]
         Air temperature at 850hPa (~1,500m altitude).
u1000  : Zonal Wind Component at 1000hPa [m/s]
         East-west wind speed at 1000hPa (~110m altitude).
v1000  : Meridional Wind Component at 1000hPa [m/s]
         North-south wind speed at 1000hPa.
z1000  : Geopotential Height at 1000hPa [m]
         Height of 1000hPa pressure surface above sea level.
u850   : Zonal Wind Component at 850hPa [m/s]
         East-west wind speed at 850hPa.
v850   : Meridional Wind Component at 850hPa [m/s]
         North-south wind speed at 850hPa.
z850   : Geopotential Height at 850hPa [m]
         Height of 850hPa pressure surface above sea level.
u500   : Zonal Wind Component at 500hPa [m/s]
         East-west wind speed at 500hPa (~5,600m altitude).
v500   : Meridional Wind Component at 500hPa [m/s]
         North-south wind speed at 500hPa.
z500   : Geopotential Height at 500hPa [m]
         Height of 500hPa pressure surface above sea level.
t500   : Temperature at 500hPa [K]
         Air temperature at 500hPa.
z50    : Geopotential Height at 50hPa [m]
         Height of 50hPa pressure surface (~19,300m altitude).
r500   : Relative Humidity at 500hPa [%]
         Relative humidity at 500hPa.
r850   : Relative Humidity at 850hPa [%]
         Relative humidity at 850hPa.
tcwv   : Total Column Water Vapor [kg/m²]
         Integrated water vapor content above the surface.
u100m  : Zonal Wind Component at 100m [m/s]
         East-west wind speed 100m above ground.
v100m  : Meridional Wind Component at 100m [m/s]
         North-south wind speed 100m above ground.
u250   : Zonal Wind Component at 250hPa [m/s]
         East-west wind speed at 250hPa (~10,400m altitude).
v250   : Meridional Wind Component at 250hPa [m/s]
         North-south wind speed at 250hPa.
z250   : Geopotential Height at 250hPa [m]
         Height of 250hPa pressure surface above sea level.
t250   : Temperature at 250hPa [K]
         Air temperature at 250hPa.

Notes:
------
- Wind components (u/v) are in meters per second (m/s): u = east-west, v = north-south.
- Temperatures are in Kelvin (K); can be converted to °C or °F if needed.
- Geopotential height (z) gives the elevation of constant-pressure surfaces.
- Pressure is in hectopascals (hPa); altitude estimates are approximate.
- Moisture variables include relative humidity (%) and total column water vapor (kg/m²).
"""

def is_colab():
    """Return True when the ``google.colab`` package can be imported, else False."""
    try:
        # The import itself is the probe; the module object is not used afterwards.
        import google.colab
    except ImportError:
        return False
    else:
        return True

# On Colab the runtime starts without earth2studio, so install the
# dependencies before the imports below are attempted.
if is_colab():
    print("Running in Google Colab - installing dependencies...")
    import subprocess
    # Pass the command as an argument list (no shell): the "[fcn]" extra then
    # reaches uv verbatim instead of going through shell glob handling.
    subprocess.run(
        [
            "uv", "pip", "install",
            "earth2studio[fcn]", "torch", "numpy", "xarray", "netcdf4", "loguru", "tqdm",
            "-q",
        ],
        check=True,  # raise CalledProcessError if the install fails
    )

import torch
from torch.utils.checkpoint import checkpoint
import numpy as np
from collections import OrderedDict
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from typing import List, Optional
from earth2studio.models.px import FCN
from earth2studio.data import GFS
import tqdm
import types
import warnings
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')


# Suppress DEBUG messages from earth2studio.data.gfs module
# The library uses loguru, not standard logging
from loguru import logger
logger.remove()  # Remove default handler
# Re-attach a real sink at WARNING level. The previous sink was a no-op
# callable, which discarded every record (warnings and errors included)
# instead of only filtering out DEBUG/INFO.
logger.add(sys.stderr, level="WARNING")  # Only show WARNING and above
print("Done with imports.")

Running in Google Colab - installing dependencies...
Done with imports.


  self.fs = None


In [3]:
import torch
import numpy as np
import datetime as dt
from collections import OrderedDict
from dataclasses import dataclass, field
import tqdm
import types
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

def patch_model_for_gradients(model):
    """
    Swap ``model._forward`` for the function stored in its ``__wrapped__`` attribute.

    The intent is to drop a decorator (e.g. an inference-mode wrapper) that would
    otherwise prevent gradient computation. Only one level of wrapping is removed.

    Returns:
        True if ``model._forward`` exposed ``__wrapped__`` and was rebound,
        False if the model was left untouched.
    """
    decorated = model._forward

    # Nothing to undo when the method carries no reference to an inner function.
    if not hasattr(decorated, '__wrapped__'):
        return False

    # __wrapped__ is a plain function, so bind it to this instance before
    # installing it as the replacement method.
    inner = decorated.__wrapped__
    model._forward = types.MethodType(inner, model)
    return True

def model_forward_wrapper(current_state, coords, model, target_index):
    """
    Run one model step and pull out the target variable's field.

    Calls ``model(current_state, coords)``. If the call returns a tuple it is
    unpacked as ``(output, <ignored>)``; otherwise the return value is the output.

    Returns:
        (output, output[0, 0, target_index])
    """
    prediction = model(current_state, coords)
    if isinstance(prediction, tuple):
        # Keep the first element; the second is not needed here.
        prediction, _ = prediction
    return prediction, prediction[0, 0, target_index]

@dataclass
class ExperimentConfig:
    """Tunable settings for the sensitivity-analysis run in this notebook."""

    # Torch device string; chosen once, when the class body is evaluated.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Length of the "batch" coordinate handed to the model.
    batch_size: int = 1
    # Number of model steps rolled out per ensemble member.
    num_timesteps: int = 60
    # Value (in hours) of the "lead_time" coordinate of the initial state.
    lead_time_hours: int = 0
    # Variable names requested from the data source; the position of a name in
    # this list is used as its channel index.
    fcn_variables: list = field(default_factory=lambda: [
        "u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000",
        "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500",
        "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250"
    ])
    # Variable whose value at the target point is differentiated.
    target_variable: str = 'r500'
    # Number of latitude / longitude points in the model grid coordinates.
    lat_points: int = 720
    lon_points: int = 1440
    # Grid spacing; presumably degrees (not read by the code in this cell).
    resolution: float = 0.25
    # Target location; a negative longitude is mapped into [0, 360) by the script.
    target_lat: float = 33.4482
    target_lon: float = -112.0777
    perturb_std: float = 0.1  # standard deviation for Gaussian perturbations
    # Initialization time of the first ensemble member. The script reads
    # cfg.start_date, which previously did not exist on this class.
    start_date: dt.datetime = dt.datetime(2024, 7, 15)

cfg = ExperimentConfig()

# Ensemble configuration: one member per consecutive daily initialization.
# Use cfg.start_date when the config provides one; otherwise fall back to the
# first initialization date this analysis has always used.
ensemble_size = 5
base_start_date = getattr(cfg, "start_date", dt.datetime(2024, 7, 15))
start_dates = [base_start_date + dt.timedelta(days=i) for i in range(ensemble_size)]

# Print coordinates as magnitude + hemisphere letter, so a negative longitude
# reads "112.0777°W" rather than "-112.0777°W".
lat_hemisphere = "N" if cfg.target_lat >= 0 else "S"
lon_hemisphere = "E" if cfg.target_lon >= 0 else "W"

print(f"\n=== SENSITIVITY ANALYSIS ===")
print(f"Configuration: {cfg.num_timesteps} timesteps ({cfg.num_timesteps * 6 / 24:.1f} days) starting from {start_dates[0]}")
print(f"Target location: {abs(cfg.target_lat):.4f}°{lat_hemisphere}, {abs(cfg.target_lon):.4f}°{lon_hemisphere}")

# Load model
print("Loading model...")
device = torch.device(cfg.device)
package = FCN.load_default_package()
model = FCN.load_model(package).to(device).eval()

if patch_model_for_gradients(model):
    print("✓ Model patched for gradients")
else:
    print("⚠ Could not patch model; gradients may fail")

# Channel index of the target variable along the variable axis
target_index = cfg.fcn_variables.index(cfg.target_variable)

coords_template = OrderedDict({
    "batch": np.empty(cfg.batch_size),
    "lead_time": np.array([np.timedelta64(cfg.lead_time_hours, "h")]),
    "variable": np.array(cfg.fcn_variables),
    "lat": np.linspace(90, -90, cfg.lat_points, endpoint=False),
    "lon": np.linspace(0, 360, cfg.lon_points, endpoint=False)
})

# Nearest grid indices to the target; the longitude grid runs over [0, 360),
# so the target longitude is wrapped into that range first.
lat_array = coords_template['lat']
lon_array = coords_template['lon']
target_lon_360 = cfg.target_lon % 360
lat_idx = np.argmin(np.abs(lat_array - cfg.target_lat))
lon_idx = np.argmin(np.abs(lon_array - target_lon_360))

print(f"Target grid idx: lat={lat_idx}, lon={lon_idx}")

# store gradients for each ensemble member
ensemble_gradients = {var: [] for var in cfg.fcn_variables}

# Seed the perturbation noise so a Restart & Run All reproduces the same ensemble.
PERTURB_SEED = 0
torch.manual_seed(PERTURB_SEED)

# The data source does not depend on the member, so build it once.
data_source = GFS()

# Ensemble loop
for ens_idx, start_date in enumerate(start_dates, 1):
    print(f"\n=== Ensemble member {ens_idx}/{ensemble_size} starting {start_date} ===")

    # Load initial state
    input_data = data_source(time=start_date, variable=cfg.fcn_variables)
    coords = coords_template.copy()

    # Convert to tensor (a leading axis is added with [None])
    input_tensor = torch.from_numpy(input_data.to_numpy()[None]).float().to(device)
    # apply gaussian perturbation (done before requires_grad_ so the in-place
    # add is not recorded by autograd)
    if cfg.perturb_std > 0:
        input_tensor += torch.randn_like(input_tensor) * cfg.perturb_std
    input_tensor.requires_grad_(True)

    # forward: drop the last latitude row so the state matches the lat grid above
    current_state = input_tensor[..., :-1, :]
    target_sum = 0.0
    for _ in tqdm.trange(cfg.num_timesteps):
        # Checkpointing recomputes each step's activations during backward
        # instead of storing them for the whole rollout.
        output, target_contribution = torch.utils.checkpoint.checkpoint(
            model_forward_wrapper,
            current_state,
            coords,
            model,
            target_index,
            use_reentrant=False,
        )
        target_sum += target_contribution
        current_state = output  # keep graph

    # Compute mean target contribution
    target_mean = target_sum / cfg.num_timesteps

    print(f"\n  Computing Sensitivity Analysis ")
    print(f"\n   Computing gradients with respect to initial conditions...")

    # backpropagate to get grads wrt input tensor
    input_tensor.grad = None
    target_mean[lat_idx, lon_idx].backward(inputs=[input_tensor])

    # Get the gradients with respect to initial conditions
    if input_tensor.grad is not None:
        # The gradients are with respect to the original input tensor
        gradients = input_tensor.grad[0, 0, :, :-1, :]  # Remove batch, lead_time dims and south pole
        print(f"   Gradient shape: {gradients.shape}")

        for var_idx, var_name in enumerate(cfg.fcn_variables):
            # Get gradient for this variable
            var_grad = gradients[var_idx]
            ensemble_gradients[var_name].append(var_grad.cpu().detach().numpy())
    else:
        # Make a skipped member visible; otherwise the ensemble average would
        # silently be taken over fewer members.
        print(f"   ⚠ No gradient produced for ensemble member {ens_idx}; member skipped")

    print(f"\n   ✓ Sensitivity analysis complete for ensemble member {ens_idx}!")
    print(f"   Computed gradients for {len(ensemble_gradients)} variables")

# Average gradients over ensemble
avg_gradients = {}
for var_name, grads_list in ensemble_gradients.items():
    if not grads_list:
        # Averaging an empty list would yield NaN; report it and move on.
        print(f"{var_name}: no gradients collected, skipping average")
        continue
    avg_gradients[var_name] = np.mean(grads_list, axis=0)
    print(f"{var_name}: ensemble-avg max_abs_grad={np.abs(avg_gradients[var_name]).max():.2e}, "
          f"mean_abs_grad={np.abs(avg_gradients[var_name]).mean():.2e}")


print("✓ Ensemble sensitivity analysis complete")


Loading model...
✓ Model patched for gradients
Target grid idx: lat=226, lon=992

=== Ensemble member 1/5 starting 2024-07-15 00:00:00 ===


Fetching GFS data: 100%|██████████| 26/26 [00:01<00:00, 24.68it/s]
100%|██████████| 60/60 [00:22<00:00,  2.72it/s]



  Computing Sensitivity Analysis 

   Computing gradients with respect to initial conditions...
   Gradient shape: torch.Size([26, 720, 1440])

   ✓ Sensitivity analysis complete for ensemble member 1!
   Computed gradients for 26 variables

=== Ensemble member 2/5 starting 2024-07-16 00:00:00 ===


Fetching GFS data: 100%|██████████| 26/26 [00:01<00:00, 25.62it/s]
100%|██████████| 60/60 [00:21<00:00,  2.78it/s]



  Computing Sensitivity Analysis 

   Computing gradients with respect to initial conditions...
   Gradient shape: torch.Size([26, 720, 1440])

   ✓ Sensitivity analysis complete for ensemble member 2!
   Computed gradients for 26 variables

=== Ensemble member 3/5 starting 2024-07-17 00:00:00 ===


Fetching GFS data: 100%|██████████| 26/26 [00:01<00:00, 23.16it/s]
100%|██████████| 60/60 [00:21<00:00,  2.78it/s]



  Computing Sensitivity Analysis 

   Computing gradients with respect to initial conditions...
   Gradient shape: torch.Size([26, 720, 1440])

   ✓ Sensitivity analysis complete for ensemble member 3!
   Computed gradients for 26 variables

=== Ensemble member 4/5 starting 2024-07-18 00:00:00 ===


Fetching GFS data: 100%|██████████| 26/26 [00:01<00:00, 23.15it/s]
100%|██████████| 60/60 [00:21<00:00,  2.77it/s]



  Computing Sensitivity Analysis 

   Computing gradients with respect to initial conditions...
   Gradient shape: torch.Size([26, 720, 1440])

   ✓ Sensitivity analysis complete for ensemble member 4!
   Computed gradients for 26 variables

=== Ensemble member 5/5 starting 2024-07-19 00:00:00 ===


Fetching GFS data: 100%|██████████| 26/26 [00:01<00:00, 23.15it/s]
100%|██████████| 60/60 [00:21<00:00,  2.77it/s]



  Computing Sensitivity Analysis 

   Computing gradients with respect to initial conditions...
   Gradient shape: torch.Size([26, 720, 1440])

   ✓ Sensitivity analysis complete for ensemble member 5!
   Computed gradients for 26 variables
u10m: ensemble-avg max_abs_grad=1.57e+11, mean_abs_grad=2.49e+07
v10m: ensemble-avg max_abs_grad=2.19e+11, mean_abs_grad=2.88e+07
t2m: ensemble-avg max_abs_grad=1.30e+11, mean_abs_grad=1.58e+07
sp: ensemble-avg max_abs_grad=1.33e+08, mean_abs_grad=2.08e+04
msl: ensemble-avg max_abs_grad=1.34e+09, mean_abs_grad=2.50e+05
t850: ensemble-avg max_abs_grad=1.13e+11, mean_abs_grad=2.09e+07
u1000: ensemble-avg max_abs_grad=2.09e+11, mean_abs_grad=2.41e+07
v1000: ensemble-avg max_abs_grad=2.41e+11, mean_abs_grad=2.85e+07
z1000: ensemble-avg max_abs_grad=3.05e+09, mean_abs_grad=4.61e+05
u850: ensemble-avg max_abs_grad=3.30e+11, mean_abs_grad=2.51e+07
v850: ensemble-avg max_abs_grad=4.58e+11, mean_abs_grad=3.16e+07
z850: ensemble-avg max_abs_grad=3.55e+09, me

In [13]:
# Sanity check: number of stored ensemble members per variable and the shape
# of the first member's gradient array.
for name, member_grads in ensemble_gradients.items():
    n_members = len(member_grads)
    first_shape = member_grads[0].shape
    print(f"{name}: {n_members} members, each shape {first_shape}")


u10m: 5 members, each shape (720, 1440)
v10m: 5 members, each shape (720, 1440)
t2m: 5 members, each shape (720, 1440)
sp: 5 members, each shape (720, 1440)
msl: 5 members, each shape (720, 1440)
t850: 5 members, each shape (720, 1440)
u1000: 5 members, each shape (720, 1440)
v1000: 5 members, each shape (720, 1440)
z1000: 5 members, each shape (720, 1440)
u850: 5 members, each shape (720, 1440)
v850: 5 members, each shape (720, 1440)
z850: 5 members, each shape (720, 1440)
u500: 5 members, each shape (720, 1440)
v500: 5 members, each shape (720, 1440)
z500: 5 members, each shape (720, 1440)
t500: 5 members, each shape (720, 1440)
z50: 5 members, each shape (720, 1440)
r500: 5 members, each shape (720, 1440)
r850: 5 members, each shape (720, 1440)
tcwv: 5 members, each shape (720, 1440)
u100m: 5 members, each shape (720, 1440)
v100m: 5 members, each shape (720, 1440)
u250: 5 members, each shape (720, 1440)
v250: 5 members, each shape (720, 1440)
z250: 5 members, each shape (720, 1440)
