# NCI GenCast Demo

This notebook demonstrates running the GenCast example.

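GenCast is provided here at both 0.25deg and 1deg resolution, with 13 pressure levels. The refinement of the icosahedral mesh depends on the checkpoint you load (for example, the 0.25deg model uses a 6 times refined mesh; the model description printed after loading gives the exact value). The models are trained on ERA5 data from 1979 to 2018, and can be causally evaluated on 2019 and later years.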


# Installation and Initialization

In [None]:
# @title Imports

import dataclasses
import datetime
import math
import os
from typing import Optional

from google.cloud import storage
import haiku as hk
from IPython import display
from IPython.display import HTML
import ipywidgets as widgets
import jax
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
import numpy as np
import xarray

from graphcast import checkpoint
from graphcast import data_utils
from graphcast import denoiser
from graphcast import gencast
from graphcast import nan_cleaning
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree

In [2]:
# @title Plotting functions

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.DataArray:
  """Extracts one variable from `data`, trimmed down for plotting.

  Args:
    data: Dataset containing `variable`.
    variable: Name of the data variable to extract.
    level: Level to select. Ignored when None or when the variable has no
      "level" coordinate.
    max_steps: Maximum number of leading "time" steps to keep. Ignored when
      None, when there is no "time" dimension, or when the variable already
      has no more than `max_steps` steps.

  Returns:
    The selected variable as a DataArray. If a "batch" dimension is present,
    only its first element is kept.
  """
  selected = data[variable]
  if "batch" in selected.dims:
    selected = selected.isel(batch=0)
  has_extra_steps = (
      max_steps is not None
      and "time" in selected.sizes
      and max_steps < selected.sizes["time"])
  if has_extra_steps:
    selected = selected.isel(time=range(0, max_steps))
  if level is not None and "level" in selected.coords:
    selected = selected.sel(level=level)
  return selected

def scale(
    data: xarray.DataArray,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.DataArray, matplotlib.colors.Normalize, str]:
  """Chooses a colour normalisation and colormap for plotting `data`.

  Args:
    data: Values to plot; NaNs are ignored when computing the range.
    center: If given, the colour range is made symmetric around this value
      and a diverging colormap is used.
    robust: If True, the range spans the 2nd to 98th percentile instead of
      the full min/max.

  Returns:
    A `(data, norm, cmap_name)` tuple, where `data` is returned unchanged.
  """
  lower_pct, upper_pct = (2, 98) if robust else (0, 100)
  # One call computes both percentiles, so the array is only processed once.
  vmin, vmax = np.nanpercentile(data, [lower_pct, upper_pct])
  if center is not None:
    # Widen the range so that `center` sits exactly in the middle.
    diff = max(vmax - center, center - vmin)
    vmin = center - diff
    vmax = center + diff
  cmap = "RdBu_r" if center is not None else "viridis"
  return data, matplotlib.colors.Normalize(vmin, vmax), cmap

def plot_data(
    data: dict[str, tuple[xarray.DataArray, matplotlib.colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> HTML:
  """Draws one panel per entry of `data` and animates them over time.

  Args:
    data: Maps a panel title to a `(array, norm, cmap)` tuple, as produced by
      `scale`. All arrays must have the same number of "time" steps (arrays
      without a "time" dimension count as one step).
    fig_title: Title shown above all panels.
    plot_size: Height of each panel; each panel is twice as wide as this.
    robust: If True, colorbars are drawn with extensions at both ends.
    cols: Maximum number of panels per row.

  Returns:
    An `HTML` object embedding the animation as JavaScript/HTML.
  """

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  # The loop variable is named `panel_data` so it does not shadow this function.
  for i, (title, (panel_data, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    im = ax.imshow(
        panel_data.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    """Redraws every panel (and the title) for animation frame `frame`."""
    if "time" in first_data.dims:
      # Assumes the "time" coordinate holds nanosecond timedeltas, hence the
      # division by 1000 to get microseconds -- TODO confirm.
      td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      figure.suptitle(f"{fig_title}, {td}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (panel_data, _, _) in zip(images, data.values()):
      im.set_data(panel_data.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  # Close the figure so only the returned animation is displayed.
  plt.close(figure.number)
  return HTML(ani.to_jshtml())


# Load the Data and initialize the model

In [3]:
# Root directory of the GenCast files on NCI; the cells below read from its
# params/, NCI_dataset/ and stats/ subdirectories.
# Testing copy: "/g/data/wb00/admin/testing/NCI_Earth_Models_Dataset/GenCast/"
dir_prefix = "/g/data/wb00/NCI_Earth_Models_Dataset/GenCast/"

## Load the model params

Choose which model params to load from the checkpoint files listed in the dropdown below:
- **checkpoint**: You'll get sensible predictions, but you are limited to the model architecture that the checkpoint was trained with, which may not fit on your device.


In [4]:
# @title Choose the model

source = "Checkpoint"

# Offer every regular file under params/ as a checkpoint choice.
# os.listdir() returns entries in arbitrary order, so sort them to keep the
# menu and its default selection stable between runs.
params_dir = os.path.join(dir_prefix, "params")
params_file_options = sorted(
    item for item in os.listdir(params_dir)
    if os.path.isfile(os.path.join(params_dir, item)))
if not params_file_options:
    raise FileNotFoundError(f"No model params files found in {params_dir}")

params_file = widgets.Dropdown(
    options=params_file_options,
    value=params_file_options[0],
    description="Params file:",
    layout={"width": "max-content"})
params_file

Dropdown(description='Params file:', layout=Layout(width='max-content'), options=('GenCast 0p25deg <2019.npz',…

In [5]:
# Load the selected checkpoint and unpack its contents.

checkpoint_path = f"{dir_prefix}params/{params_file.value}"
with open(checkpoint_path, "rb") as ckpt_file:
    ckpt = checkpoint.load(ckpt_file, gencast.CheckPoint)

params, state = ckpt.params, {}

task_config = ckpt.task_config
sampler_config = ckpt.sampler_config
noise_config = ckpt.noise_config
noise_encoder_config = ckpt.noise_encoder_config
denoiser_architecture_config = ckpt.denoiser_architecture_config

# Override the attention settings stored in the checkpoint.
# NOTE(review): presumably required for the hardware this notebook runs on --
# confirm against the GenCast documentation.
sparse_transformer_config = denoiser_architecture_config.sparse_transformer_config
sparse_transformer_config.attention_type = "triblockdiag_mha"
sparse_transformer_config.mask_type = "full"


print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")

Model description:
 
        GenCast model at 0.25deg resolution with 13 pressure levels and a 6 times
        refined icosahedral mesh. This model is trained on ERA5 data from 1979
        to 2018 (inclusive), and can be causally evaluated on 2019
        and later years. This model was described in the paper
        `GenCast: Diffusion-based ensemble forecasting for medium-range weather`
        (https://arxiv.org/abs/2312.15796).
         

Model license:
 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.
 



## Load the example data

Example ERA5 datasets are available at 0.25 degree and 1 degree resolution.

Example HRES-fc0 datasets are available at 0.25 degree resolution.

The data resolution must match the model that is loaded. 

In [6]:
# @title Get and filter the list of available example datasets

# Keep only NetCDF files: the parsing below expects names of the form
# "key-value_key-value.nc" and would fail on anything else in the directory.
# Sorted because os.listdir() returns entries in arbitrary order.
dataset_file_options = sorted(
    item for item in os.listdir(os.path.join(dir_prefix, "NCI_dataset"))
    if item.endswith(".nc"))


def parse_file_parts(file_name):
  """Splits a "key-value_key-value" style name into a `{key: value}` dict.

  Parts are separated by "_". Within a part only the first "-" separates the
  key from the value, so values may themselves contain "-" (e.g. dates).
  Raises ValueError if a part contains no "-".
  """
  parts = {}
  for part in file_name.split("_"):
    key, value = part.split("-", 1)
    parts[key] = value
  return parts


def data_valid_for_model(file_name: str, params_file_name: str) -> bool:
  """Checks that a dataset's source and resolution match a params file.

  The earlier `source == "Random"` branch was removed: it read a
  `random_resolution` widget that this notebook never defines, so it could
  only ever raise NameError.

  Args:
    file_name: Dataset file name, e.g. "source-era5_..._res-0.25_....nc".
    params_file_name: Name of the selected model params file.

  Returns:
    True if the dataset resolution (with "." written as "p") appears in the
    params file name and the sources agree: "era5" data goes with params files
    whose name does not contain "Operational", any other source goes with
    those whose name does.
  """
  data_file_parts = parse_file_parts(file_name.removesuffix(".nc"))
  data_res = data_file_parts["res"].replace(".", "p")
  res_matches = data_res in params_file_name.lower()
  is_operational_model = "Operational" in params_file_name
  is_era5_data = data_file_parts["source"] == "era5"
  # Exactly one of the two must hold: ERA5 data with a non-operational model,
  # or non-ERA5 data with an operational model.
  source_matches = is_operational_model != is_era5_data
  return res_matches and source_matches

def _dataset_label(file_name):
  """Builds the "key: value, key: value" dropdown label for a dataset file."""
  parts = parse_file_parts(file_name.removesuffix(".nc"))
  return ", ".join(f"{k}: {v}" for k, v in parts.items())


# (label, value) pairs for every dataset compatible with the selected params.
matching_dataset_options = [
    (_dataset_label(option), option)
    for option in dataset_file_options
    if data_valid_for_model(option, params_file.value)
]

dataset_file = widgets.Dropdown(
    options=matching_dataset_options,
    description="Dataset file:",
    layout={"width": "max-content"})
rerun_hint = widgets.Label(value="Run the next cell to load the dataset. Rerunning this cell clears your selection and refilters the datasets that match your model.")
widgets.VBox([dataset_file, rerun_hint])


VBox(children=(Dropdown(description='Dataset file:', layout=Layout(width='max-content'), options=(('source: er…

In [7]:
# @title Load weather data
    
with open(dir_prefix + f"NCI_dataset/{dataset_file.value}", 'rb') as f:
    example_batch = xarray.load_dataset(f).compute()

# Later cells use two input frames and at least one target frame.
# `.sizes` is used because looking up lengths via `Dataset.dims[...]` is
# deprecated in xarray and emits a warning.
assert example_batch.sizes["time"] >= 3, (
    "The example dataset needs at least 3 time steps.")

print(", ".join([f"{k}: {v}" for k, v in parse_file_parts(dataset_file.value.removesuffix(".nc")).items()]))

example_batch

  example_batch = xarray.load_dataset(f).compute()


source: era5, date: 2024-01-01, res: 0.25, levels: 13, steps: 04


  assert example_batch.dims["time"] >= 3


In [8]:
# @title Choose data to plot

# Number of time steps available to plot. `.sizes` replaces the deprecated
# `Dataset.dims[...]` lookup, which emits a warning.
num_example_steps = example_batch.sizes["time"]

plot_example_variable = widgets.Dropdown(
    options=example_batch.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_example_level = widgets.Dropdown(
    options=example_batch.coords["level"].values,
    value=500,
    description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
plot_example_max_steps = widgets.IntSlider(
    min=1, max=num_example_steps, value=num_example_steps,
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])

  min=1, max=example_batch.dims["time"], value=example_batch.dims["time"],


VBox(children=(Dropdown(description='Variable', index=3, options=('geopotential_at_surface', 'land_sea_mask', …

In [9]:
# @title Plot example data

plot_size = 7

# Pull out the chosen variable/level/steps, then pick its colour scaling.
example_variable_name = plot_example_variable.value
example_to_plot = select(
    example_batch,
    example_variable_name,
    plot_example_level.value,
    plot_example_max_steps.value)
data = {" ": scale(example_to_plot, robust=plot_example_robust.value)}

# Only mention the level in the title for variables that have one.
fig_title = example_variable_name
if "level" in example_batch[example_variable_name].coords:
  fig_title = f"{fig_title} at {plot_example_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust.value)


In [10]:
# @title Extract training and eval data

# `.sizes` replaces the deprecated `Dataset.dims[...]` lookup (which emits a
# warning) and the `.dims.mapping` attribute used for printing.
num_example_frames = example_batch.sizes["time"]

train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("12h", "12h"), # Only 1AR training.
    **dataclasses.asdict(task_config))

eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("12h", f"{(num_example_frames-2)*12}h"), # All but 2 input frames.
    **dataclasses.asdict(task_config))

print("All Examples:  ", dict(example_batch.sizes))
print("Train Inputs:  ", dict(train_inputs.sizes))
print("Train Targets: ", dict(train_targets.sizes))
print("Train Forcings:", dict(train_forcings.sizes))
print("Eval Inputs:   ", dict(eval_inputs.sizes))
print("Eval Targets:  ", dict(eval_targets.sizes))
print("Eval Forcings: ", dict(eval_forcings.sizes))


  example_batch, target_lead_times=slice("12h", f"{(example_batch.dims['time']-2)*12}h"), # All but 2 input frames.


All Examples:   {'lat': 721, 'time': 6, 'lon': 1440, 'batch': 1, 'level': 13}
Train Inputs:   {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Train Targets:  {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440, 'level': 13}
Train Forcings: {'batch': 1, 'time': 1, 'lon': 1440}
Eval Inputs:    {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Eval Targets:   {'batch': 1, 'time': 4, 'lat': 721, 'lon': 1440, 'level': 13}
Eval Forcings:  {'batch': 1, 'time': 4, 'lon': 1440}


In [11]:
def _load_stats(file_name):
  """Reads one statistics NetCDF file from the stats/ directory into memory."""
  with open(dir_prefix + "stats/" + file_name, "rb") as stats_file:
    return xarray.load_dataset(stats_file).compute()


diffs_stddev_by_level = _load_stats("diffs_stddev_by_level.nc")
mean_by_level = _load_stats("mean_by_level.nc")
stddev_by_level = _load_stats("stddev_by_level.nc")
min_by_level = _load_stats("min_by_level.nc")


In [12]:
# @title Build jitted functions, and possibly initialize random weights


def construct_wrapped_gencast():
  """Builds the GenCast predictor wrapped with normalization and NaN cleaning.

  The configs and statistics are read from the notebook's module-level
  variables. The wrappers are applied in this order (innermost first):
  `gencast.GenCast`, `normalization.InputsAndResiduals`, then
  `nan_cleaning.NaNCleaner` for 'sea_surface_temperature'.
  """
  base_predictor = gencast.GenCast(
      sampler_config=sampler_config,
      task_config=task_config,
      denoiser_architecture_config=denoiser_architecture_config,
      noise_config=noise_config,
      noise_encoder_config=noise_encoder_config,
  )

  normalized_predictor = normalization.InputsAndResiduals(
      base_predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level,
  )

  return nan_cleaning.NaNCleaner(
      predictor=normalized_predictor,
      reintroduce_nans=True,
      fill_value=min_by_level,
      var_to_clean='sea_surface_temperature',
  )


@hk.transform_with_state
def run_forward(inputs, targets_template, forcings):
  """Calls a freshly constructed wrapped predictor on the given inputs."""
  return construct_wrapped_gencast()(
      inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(inputs, targets, forcings):
  """Returns the predictor's `(loss, diagnostics)`, each leaf averaged to a JAX value."""

  def _mean_as_jax(x):
    # Take the mean of the leaf and unwrap it from its xarray container.
    return xarray_jax.unwrap_data(x.mean(), require_jax=True)

  loss, diagnostics = construct_wrapped_gencast().loss(inputs, targets, forcings)
  return xarray_tree.map_structure(_mean_as_jax, (loss, diagnostics))


def grads_fn(params, state, inputs, targets, forcings):
  """Computes the loss and its gradients with respect to `params`.

  The loss is always evaluated with the fixed key `jax.random.PRNGKey(0)`.

  Returns:
    A `(loss, diagnostics, next_state, grads)` tuple.
  """

  def _loss_with_aux(params, state, i, t, f):
    # Put the scalar loss first so value_and_grad differentiates it; the rest
    # travels along as auxiliary output.
    (loss_value, diag), new_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), i, t, f
    )
    return loss_value, (diag, new_state)

  value_and_grad_fn = jax.value_and_grad(_loss_with_aux, has_aux=True)
  (loss, (diagnostics, next_state)), grads = value_and_grad_fn(
      params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads


# Initialize weights only when none were loaded above.
if params is None:
  init_jitted = jax.jit(loss_fn.init)
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs,
      targets=train_targets,
      forcings=train_forcings,
  )


def _loss_with_current_params(rng, i, t, f):
  """Applies `loss_fn` with the notebook's `params`/`state`; drops the new state."""
  return loss_fn.apply(params, state, rng, i, t, f)[0]


def _forward_with_current_params(rng, i, t, f):
  """Applies `run_forward` with the notebook's `params`/`state`; drops the new state."""
  return run_forward.apply(params, state, rng, i, t, f)[0]


loss_fn_jitted = jax.jit(_loss_with_current_params)
grads_fn_jitted = jax.jit(grads_fn)
run_forward_jitted = jax.jit(_forward_with_current_params)
# We also produce a pmapped version for running in parallel.
run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample")

# Run the model

The `chunked_prediction_generator_multiple_runs` iterates over forecast steps, where the 1 step forecast is jitted and samples are pmapped across the chips.
This allows us to make efficient use of all devices and parallelise generating an ensemble across them. We then combine the chunks at the end to form our final forecast.

Note that the `Autoregressive rollout` cell will take longer than the standard inference time to run when executed for the first time, as this will include code compilation time. This cost does not increase with the number of devices, it is a fixed-cost one time operation whose result can be reused across any number of devices.

In [13]:
# The number of ensemble members should be a multiple of the number of devices,
# so report how many local devices JAX can see.
num_local_devices = len(jax.local_devices())
print(f"Number of local devices {num_local_devices}")

Number of local devices 1


In [14]:
# @title Autoregressive rollout (loop in python)

# `dict(ds.sizes)` prints the same dimension-to-length mapping as before
# without relying on the deprecated mapping behaviour of `Dataset.dims`.
print("Inputs:  ", dict(eval_inputs.sizes))
print("Targets: ", dict(eval_targets.sizes))
print("Forcings:", dict(eval_forcings.sizes))

num_ensemble_members = 8 # @param int
rng = jax.random.PRNGKey(0)
# We fold in the ensemble member index, so the first N members should always
# match across different runs that take the same inputs, regardless of the
# total ensemble size.
rngs = np.stack(
    [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)

chunks = []
for chunk in rollout.chunked_prediction_generator_multiple_runs(
    # Use pmapped version to parallelise across devices.
    predictor_fn=run_forward_pmap,
    rngs=rngs,
    inputs=eval_inputs,
    # Same structure as the targets, but with every value replaced by NaN.
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk = 1,
    num_samples = num_ensemble_members,
    pmap_devices=jax.local_devices()
    ):
    chunks.append(chunk)
predictions = xarray.combine_by_coords(chunks)

Inputs:   {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Targets:  {'batch': 1, 'time': 4, 'lat': 721, 'lon': 1440, 'level': 13}
Forcings: {'batch': 1, 'time': 4, 'lon': 1440}


  num_target_steps = targets_template.dims["time"]
  self._set_arrayXarray(i, j, x)
  self._set_arrayXarray(i, j, x)
  num_inputs = prev_inputs.dims["time"]
  num_target_steps = targets_template.dims["time"]


In [15]:
# @title Choose predictions to plot

# Number of predicted time steps. `.sizes` replaces the deprecated
# `Dataset.dims[...]` lookup, which emits a warning.
num_prediction_steps = predictions.sizes["time"]

plot_pred_variable = widgets.Dropdown(
    options=predictions.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_pred_level = widgets.Dropdown(
    options=predictions.coords["level"].values,
    value=500,
    description="Level")
plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
plot_pred_max_steps = widgets.IntSlider(
    min=1,
    max=num_prediction_steps,
    value=num_prediction_steps,
    description="Max steps")
plot_pred_samples = widgets.IntSlider(
    min=1,
    max=num_ensemble_members,
    value=num_ensemble_members,
    description="Samples")

widgets.VBox([
    plot_pred_variable,
    plot_pred_level,
    plot_pred_robust,
    plot_pred_max_steps,
    plot_pred_samples,
    widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
])

  max=predictions.dims["time"],
  value=predictions.dims["time"],


VBox(children=(Dropdown(description='Variable', index=2, options=('10m_u_component_of_wind', '10m_v_component_…

In [None]:
# @title Plot prediction samples and diffs

plot_size = 5
# `.sizes` replaces the deprecated `Dataset.dims[...]` lookup.
plot_max_steps = min(predictions.sizes["time"], plot_pred_max_steps.value)

fig_title = plot_pred_variable.value
if "level" in predictions[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

# The targets are the same for every sample, so select them once.
targets_to_plot = select(
    eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)

for sample_idx in range(plot_pred_samples.value):
  sample_to_plot = select(
      predictions.isel(sample=sample_idx),
      plot_pred_variable.value, plot_pred_level.value, plot_max_steps)
  data = {
      "Targets": scale(targets_to_plot, robust=plot_pred_robust.value),
      "Predictions": scale(sample_to_plot, robust=plot_pred_robust.value),
      "Diff": scale(targets_to_plot - sample_to_plot,
                    robust=plot_pred_robust.value, center=0),
  }
  display.display(plot_data(data, fig_title + f", Sample {sample_idx}", plot_size, plot_pred_robust.value))


In [17]:
# @title Plot ensemble mean and CRPS

def crps(targets, predictions, bias_corrected = True):
  """Computes an ensemble CRPS: mean |target - member| minus half the mean pairwise |member - member|.

  Args:
    targets: Values the ensemble is compared against.
    predictions: Ensemble with a "sample" dimension of size at least 2.
    bias_corrected: If True, the pairwise term is divided by n * (n - 1)
      instead of n * n, where n is the number of samples.

  Returns:
    The CRPS with the "sample" dimension reduced away.

  Raises:
    ValueError: If `predictions` has fewer than 2 samples.
  """
  num_samps = predictions.sizes.get("sample", 1)
  if num_samps < 2:
    raise ValueError(
        "predictions must have dim 'sample' with size at least 2.")

  # Rename the sample dim on a copy so the subtraction pairs every member
  # with every other member.
  other_members = predictions.rename({"sample": "sample2"})
  pair_count = num_samps * ((num_samps - 1) if bias_corrected else num_samps)
  pairwise_abs_diff = np.abs(predictions - other_members)
  mean_abs_diff = pairwise_abs_diff.sum(
      dim=["sample", "sample2"], skipna=False) / pair_count

  abs_err = np.abs(targets - predictions)
  mean_abs_err = abs_err.sum(dim="sample", skipna=False) / num_samps
  return mean_abs_err - 0.5 * mean_abs_diff


plot_size = 5
# `.sizes` replaces the deprecated `Dataset.dims[...]` lookup, which emits a
# warning.
plot_max_steps = min(predictions.sizes["time"], plot_pred_max_steps.value)

fig_title = plot_pred_variable.value
if "level" in predictions[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

# Select the targets once and reuse them for both the target and CRPS panels.
targets_to_plot = select(
    eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)
ensemble_to_plot = select(
    predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)
ensemble_mean_to_plot = select(
    predictions.mean(dim=["sample"]),
    plot_pred_variable.value, plot_pred_level.value, plot_max_steps)

data = {
    "Targets": scale(targets_to_plot, robust=plot_pred_robust.value),
    "Ensemble Mean": scale(ensemble_mean_to_plot, robust=plot_pred_robust.value),
    "Ensemble CRPS": scale(crps(targets_to_plot, ensemble_to_plot),
                           robust=plot_pred_robust.value, center=0),
}
display.display(plot_data(data, fig_title, plot_size, plot_pred_robust.value))

  plot_max_steps = min(predictions.dims["time"], plot_pred_max_steps.value)
