# GraphCast

This colab lets you run several versions of GraphCast.

The model weights, normalization statistics, and example inputs are available in a [Google Cloud Storage bucket](https://console.cloud.google.com/storage/browser/dm_graphcast).

A Colab runtime with TPU/GPU acceleration will substantially speed up generating predictions and computing the loss/gradients. If you're using a CPU-only runtime, you can switch using the menu "Runtime > Change runtime type".

> <p><small><small>Copyright 2023 DeepMind Technologies Limited.</small></small></p>
> <p><small><small>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at <a href="http://www.apache.org/licenses/LICENSE-2.0">http://www.apache.org/licenses/LICENSE-2.0</a>.</small></small></p>
> <p><small><small>Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.</small></small></p>

# Installation and Initialization


In [None]:
# @title Pip install graphcast and dependencies
# Installs an edited version of GraphCast. Main additions: it saves the updated mesh nodes from initialisation and the latent features from each MLP.
%pip install --upgrade https://github.com/ktempestuous/GC_testing/raw/refs/heads/main/Archive_3.zip # version 3
!pip install cartopy

In [None]:
# @title Workaround for cartopy crashes

# Workaround for cartopy crashes due to the shapely installed by default in
# google colab kernel (https://github.com/anitagraser/movingpandas/issues/81):
!pip uninstall -y shapely
!pip install shapely --no-binary shapely

In [None]:
# @title Imports

import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

import cartopy.crs as ccrs
import cartopy.feature as cfeature
from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray
import inspect
import xarray as xr

def parse_file_parts(file_name):
  return dict(part.split("-", 1) for part in file_name.split("_"))


In [None]:
# @title Check that the edited GraphCast version from GitHub is in use (prints the source of __call__)
print(inspect.getsource(graphcast.GraphCast.__call__))

In [None]:
# @title Authenticate with Google Cloud Storage

gcs_client = storage.Client.create_anonymous_client() # This creates a Google Cloud Storage client that does not require authentication.
gcs_bucket = gcs_client.get_bucket("dm_graphcast") # accessing public GCS data without logging in
dir_prefix = "graphcast/"

# Load the Data and initialize the model

## Load the model params

Choose one of the two ways of getting model params:
- **random**: You'll get random predictions, but you can change the model architecture, which may run faster or fit on your device.
- **checkpoint**: You'll get sensible predictions, but are limited to the model architecture that it was trained with, which may not fit on your device. In particular, generating gradients uses a lot of memory, so you'll need at least 25GB of RAM (TPUv4 or A100).

Checkpoints vary across a few axes:
- The mesh size specifies the internal graph representation of the earth. Smaller meshes will run faster but will have worse outputs. The mesh size does not affect the number of parameters of the model.
- The resolution and number of pressure levels must match the data. Lower resolution and fewer levels will run a bit faster. Data resolution only affects the encoder/decoder.
- All our models predict precipitation. However, ERA5 includes precipitation, while HRES does not. Our models marked as "ERA5" take precipitation as input and expect ERA5 data as input, while models marked "ERA5-HRES" do not take precipitation as input and are specifically trained to take HRES-fc0 as input (see the data section below).
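
To give a feel for how the mesh size scales, here is a minimal sketch that counts the elements of a single refinement level. It assumes the mesh is an icosahedron whose triangles are each split into four, `mesh_size` times; the helper name `icosahedral_mesh_counts` is hypothetical and not part of the `graphcast` package. The edge count is for one refinement level only, so a mesh that also keeps edges from coarser levels would have more.

```python
def icosahedral_mesh_counts(splits: int) -> dict:
    """Node, edge and face counts of an icosahedron refined `splits` times.

    Assumption: every refinement splits each triangle into four.
    """
    factor = 4 ** splits
    return {
        "nodes": 10 * factor + 2,
        "edges": 30 * factor,
        "faces": 20 * factor,
    }


for mesh_size in (4, 5, 6):
    print(mesh_size, icosahedral_mesh_counts(mesh_size))
```

Under this assumption a mesh size of 4 gives 2562 nodes, which matches the node count quoted in the activation analysis later in this notebook.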

We provide three pre-trained models.
1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017,

2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from 1979 to 2015, useful when memory and compute are limited,

3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on HRES data from 2016 to 2021. This model can be initialized from HRES data (does not require precipitation inputs).


In [None]:
# @title Choose the model

params_file_options = [
    name for blob in gcs_bucket.list_blobs(prefix=dir_prefix+"params/")
    if (name := blob.name.removeprefix(dir_prefix+"params/"))]  # Drop empty string.

random_mesh_size = widgets.IntSlider(
    value=4, min=4, max=6, description="Mesh size:")
random_gnn_msg_steps = widgets.IntSlider(
    value=4, min=1, max=32, description="GNN message steps:")
random_latent_size = widgets.Dropdown(
    options=[int(2**i) for i in range(4, 10)], value=32,description="Latent size:")
random_levels = widgets.Dropdown(
    options=[13, 37], value=13, description="Pressure levels:")


params_file = widgets.Dropdown(
    options=params_file_options,
    description="Params file:",
    layout={"width": "max-content"})

source_tab = widgets.Tab([
    widgets.VBox([
        random_mesh_size,
        random_gnn_msg_steps,
        random_latent_size,
        random_levels,
    ]),
    params_file,
])
source_tab.set_title(0, "Random")
source_tab.set_title(1, "Checkpoint")
widgets.VBox([
    source_tab,
    widgets.Label(value="Run the next cell to load the model. Rerunning this cell clears your selection.")
])


In [None]:
# @title Load the model

source = source_tab.get_title(source_tab.selected_index) # gets previous selection

if source == "Random":
  params = None  # No pretrained parameters; filled in below by random initialization.
  state = {} # empty model state
  model_config = graphcast.ModelConfig(
      resolution=0,
      mesh_size=random_mesh_size.value,
      latent_size=random_latent_size.value,
      gnn_msg_steps=random_gnn_msg_steps.value,
      hidden_layers=1,
      radius_query_fraction_edge_length=0.6)
  task_config = graphcast.TaskConfig(
      input_variables=graphcast.TASK.input_variables,
      target_variables=graphcast.TASK.target_variables,
      forcing_variables=graphcast.TASK.forcing_variables,
      pressure_levels=graphcast.PRESSURE_LEVELS[random_levels.value],
      input_duration=graphcast.TASK.input_duration,
  )
else:
  assert source == "Checkpoint"
  with gcs_bucket.blob(f"{dir_prefix}params/{params_file.value}").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)
  params = ckpt.params # ckpt is an instance of the class graphcast.CheckPoint
  state = {}

  model_config = ckpt.model_config
  task_config = ckpt.task_config
  print("Model description:\n", ckpt.description, "\n")
  print("Model license:\n", ckpt.license, "\n")

print(model_config)
print(task_config)

## Load the example data

Several example datasets are available, varying across a few axes:
- **Source**: fake, era5, hres
- **Resolution**: 0.25deg, 1deg, 6deg
- **Levels**: 13, 37
- **Steps**: How many timesteps are included

Not all combinations are available.
- Higher resolution is only available for fewer steps due to the memory requirements of loading them.
- HRES is only available in 0.25 deg, with 13 pressure levels.

The data resolution must match the model that is loaded.
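
The matching rule can be sketched in a few lines. The example below repeats the `parse_file_parts` helper from the Imports cell and adds a simplified, hypothetical `matches` check that only compares resolution and number of levels (the notebook's own check also looks at the data source). The file name is illustrative and is not guaranteed to exist in the bucket.

```python
def parse_file_parts(file_name: str) -> dict:
    """Split a 'key-value_key-value' file name into a dict."""
    return dict(part.split("-", 1) for part in file_name.split("_"))


def matches(file_name: str, resolution: float, num_levels: int) -> bool:
    """Hypothetical simplified check; resolution 0 means 'any resolution'."""
    parts = parse_file_parts(file_name.removesuffix(".nc"))
    return (resolution in (0, float(parts["res"]))
            and num_levels == int(parts["levels"]))


# Illustrative file name following the key-value convention.
example = "source-era5_date-2022-01-01_res-1.0_levels-13_steps-04.nc"

print(parse_file_parts(example.removesuffix(".nc")))
print(matches(example, resolution=1.0, num_levels=13))
print(matches(example, resolution=0.25, num_levels=13))
```

Splitting on the first `-` only is what keeps a value such as `2022-01-01` intact.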

Some transformations were done from the base datasets:
- We accumulated precipitation over 6 hours instead of the default 1 hour.
- For HRES data, each time step corresponds to the HRES forecast at lead time 0, essentially providing an "initialisation" from HRES. See HRES-fc0 in the GraphCast paper for further description. Note that a 6h accumulation of precipitation is not available from HRES, so our model taking HRES inputs does not depend on precipitation. However, because our models predict precipitation, we include the ERA5 precipitation in the example data so it can serve as an illustrative example of ground truth.
- We include ERA5 `toa_incident_solar_radiation` in the data. Our model uses the radiation at -6h, 0h and +6h as a forcing term for each 1-step prediction. If the radiation is missing from the data (e.g. in an operational setting), it will be computed using a custom implementation that produces values similar to those in ERA5.
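
As an illustration of the first transformation, the sketch below sums hourly precipitation totals into 6-hour totals. It is not the preprocessing used to build the example files; the helper name `accumulate`, the made-up values, and the choice of non-overlapping windows starting at the first sample are all assumptions.

```python
def accumulate(hourly, window=6):
    """Sum consecutive, non-overlapping `window`-hour blocks of hourly totals."""
    if len(hourly) % window:
        raise ValueError("series length must be a multiple of the window")
    return [sum(hourly[i:i + window]) for i in range(0, len(hourly), window)]


# Made-up hourly precipitation totals covering 12 hours.
hourly_mm = [0, 0, 1, 2, 0, 0, 3, 0, 0, 0, 0, 1]
print(accumulate(hourly_mm))  # [3, 4]
```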

In [None]:
# @title Get and filter the list of available example datasets

dataset_file_options = [
    name for blob in gcs_bucket.list_blobs(prefix=dir_prefix+"dataset/")
    if (name := blob.name.removeprefix(dir_prefix+"dataset/"))]  # Drop empty string.

def data_valid_for_model(
    file_name: str, model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):
  file_parts = parse_file_parts(file_name.removesuffix(".nc"))
  return (
      model_config.resolution in (0, float(file_parts["res"])) and
      len(task_config.pressure_levels) == int(file_parts["levels"]) and
      (
          ("total_precipitation_6hr" in task_config.input_variables and
           file_parts["source"] in ("era5", "fake")) or
          ("total_precipitation_6hr" not in task_config.input_variables and
           file_parts["source"] in ("hres", "fake"))
      )
  )


dataset_file = widgets.Dropdown(
    options=[
        (", ".join([f"{k}: {v}" for k, v in parse_file_parts(option.removesuffix(".nc")).items()]), option)
        for option in dataset_file_options
        if data_valid_for_model(option, model_config, task_config)
    ],
    description="Dataset file:",
    layout={"width": "max-content"})
widgets.VBox([
    dataset_file,
    widgets.Label(value="Run the next cell to load the dataset. Rerunning this cell clears your selection and refilters the datasets that match your model.")
])

In [None]:
# @title Load weather data

if not data_valid_for_model(dataset_file.value, model_config, task_config):
  raise ValueError(
      "Invalid dataset file, rerun the cell above and choose a valid dataset file.")

with gcs_bucket.blob(f"{dir_prefix}dataset/{dataset_file.value}").open("rb") as f:
  example_batch = xarray.load_dataset(f).compute()

assert example_batch.sizes["time"] >= 3  # 2 for input, >=1 for targets

print(", ".join([f"{k}: {v}" for k, v in parse_file_parts(dataset_file.value.removesuffix(".nc")).items()]))

example_batch

In [None]:
# @title Choose training and eval data to extract
train_steps = widgets.IntSlider(
    value=1, min=1, max=example_batch.sizes["time"]-2, description="Train steps")
eval_steps = widgets.IntSlider(
    value=example_batch.sizes["time"]-2, min=1, max=example_batch.sizes["time"]-2, description="Eval steps")

widgets.VBox([
    train_steps,
    eval_steps,
    widgets.Label(value="Run the next cell to extract the data. Rerunning this cell clears your selection.")
])

In [None]:
# @title Extract training and eval data

train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{train_steps.value*6}h"),
    **dataclasses.asdict(task_config))

eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{eval_steps.value*6}h"),
    **dataclasses.asdict(task_config))

print("All Examples:  ", example_batch.dims.mapping)
print("Train Inputs:  ", train_inputs.dims.mapping)
print("Train Targets: ", train_targets.dims.mapping)
print("Train Forcings:", train_forcings.dims.mapping)
print("Eval Inputs:   ", eval_inputs.dims.mapping)
print("Eval Targets:  ", eval_targets.dims.mapping)
print("Eval Forcings: ", eval_forcings.dims.mapping)


In [None]:
# @title Load normalization data

with gcs_bucket.blob(dir_prefix+"stats/diffs_stddev_by_level.nc").open("rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute() # Standard deviation of differences between consecutive timesteps. Often used for loss weighting or forecast uncertainty calibration.
  # For example, if a variable has small changes between timesteps, it might be given more weight in the loss function.
with gcs_bucket.blob(dir_prefix+"stats/mean_by_level.nc").open("rb") as f:
  mean_by_level = xarray.load_dataset(f).compute() # Mean values of each weather variable at each pressure level. Used to normalize input and output variables before feeding them to the model.
with gcs_bucket.blob(dir_prefix+"stats/stddev_by_level.nc").open("rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute() # Standard deviation of each variable at each pressure level. Used together with the mean for normalization.
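
The role of the three statistics can be sketched with scalars. This is one plausible reading of how inputs and residual targets are normalized, not a copy of the code in `normalization.py`; the function names and the numbers are made up for illustration.

```python
def normalize(x, mean, std):
    """Standardize an input value with per-variable mean and stddev."""
    return (x - mean) / std


def to_residual_target(target, last_input, diff_std):
    """Express a target as a normalized change from the latest input."""
    return (target - last_input) / diff_std


def from_residual_prediction(pred, last_input, diff_std):
    """Invert to_residual_target: turn a normalized change back into a value."""
    return last_input + pred * diff_std


# Made-up statistics for a temperature-like variable (kelvin).
mean, std, diff_std = 280.0, 20.0, 2.0
last_input, target = 290.0, 291.0

print(normalize(last_input, mean, std))                        # 0.5
residual = to_residual_target(target, last_input, diff_std)
print(residual)                                                # 0.5
print(from_residual_prediction(residual, last_input, diff_std))  # 291.0
```

Scaling the change by the standard deviation of differences keeps the quantity the network predicts at roughly unit scale even when a variable changes little between time steps.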

# Run the model

Note that the cells below may take a while (possibly minutes) to run the first time you execute them, because the first run includes compilation time. Subsequent runs will be significantly faster.

This uses a Python loop to iterate over prediction steps, where the 1-step prediction is jitted. This has lower memory requirements than the training steps below, and should enable making predictions with the small GraphCast model on 1 deg resolution data for 4 steps.
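
The looping pattern can be sketched without any model. In the toy example below, `one_step` is a hypothetical stand-in for the jitted 1-step predictor, and the loop feeds each prediction back in as the newest of two input states. The dynamics are invented purely so the code runs.

```python
def one_step(prev, curr):
    """Toy stand-in for the 1-step predictor (invented dynamics)."""
    return curr + 0.5 * (curr - prev)


def rollout(initial_pair, num_steps):
    """Autoregressive rollout: each prediction becomes the next input."""
    prev, curr = initial_pair
    predictions = []
    for _ in range(num_steps):
        nxt = one_step(prev, curr)
        predictions.append(nxt)
        prev, curr = curr, nxt  # slide the two-state input window forward
    return predictions


print(rollout((0.0, 1.0), 4))  # [1.5, 1.75, 1.875, 1.9375]
```

Only one step's worth of intermediate values is alive at a time, which is why a loop like this needs less memory than differentiating through the whole trajectory.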

In [None]:
# @title Create wrapper around main Graphcast class
# Returning the unwrapped core predictor as well gives access to attributes saved on it during initialisation, e.g. self._latent_mesh_nodes

def construct_wrapped_graphcast(model_config, task_config, keep_unwrapped_function=False): # option to use unwrapped version of GC class to extract other variables
    core_predictor = graphcast.GraphCast(model_config, task_config)

    predictor = casting.Bfloat16Cast(core_predictor)
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level
    )
    predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)

    if keep_unwrapped_function:
        return predictor, core_predictor  # Return both
    else:
        return predictor


In [None]:
@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    predictor, core_predictor = construct_wrapped_graphcast(
        model_config, task_config, keep_unwrapped_function=True
    )

    prediction = predictor(inputs, targets_template=targets_template, forcings=forcings)

    return prediction


In [None]:
# Haiku-transformed functions are pure, so model_config, task_config, params and state must be passed on every call.
# These helpers bind them once with functools.partial so they don't have to be repeated for every function.

def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

def with_params(fn):
  return functools.partial(fn, params=params, state=state)
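
A small sketch of what this binding does, using a hypothetical `apply` function and placeholder strings in place of the real configs and parameters. It also shows that `functools.partial` captures the values at binding time, so the helpers should be applied only after `params` and `state` hold their final values.

```python
import functools


def apply(params, state, model_config, task_config, inputs):
    """Hypothetical stand-in for a transformed `.apply` function."""
    return f"{model_config}/{task_config}/{params}/{state}/{inputs}"


model_config, task_config = "cfg", "task"
params, state = "p", "s"

# Bind the configs first, then the parameters and state.
bound = functools.partial(
    functools.partial(apply, model_config=model_config, task_config=task_config),
    params=params, state=state)

print(bound(inputs="x"))  # cfg/task/p/s/x

# Rebinding the name afterwards does not change what `bound` captured.
params = "new"
print(bound(inputs="x"))  # cfg/task/p/s/x
```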


In [None]:
# Initialise the model. This runs one forward pass, which builds the graphs and the GNN structure, and returns the params and state used to run the model.
params, state = run_forward.init(
           jax.random.PRNGKey(0),
           model_config=model_config, # could have used with_configs here but have typed it explicitly...
           task_config=task_config,
           inputs=train_inputs,
           targets_template=train_targets,
           forcings=train_forcings
        )
# prints latent_mesh_nodes and updated_latent_mesh_nodes

In [None]:
# Bind configs, params and state so that only rng and data need to be passed from here on.
run_forward = with_params(with_configs(run_forward.apply))

In [None]:
(output), new_state = run_forward(
    rng=jax.random.PRNGKey(1),
    inputs=train_inputs,
    targets_template=train_targets,
    forcings=train_forcings
)

In [None]:
@hk.transform_with_state
def get_latents(model_config, task_config, inputs, targets_template, forcings):

    # Retrieve latent graphs from saved state
    latent_graphs_m = hk.get_state("latent_graphs_m", shape=None, dtype=None, init=lambda *_: None)
    return latent_graphs_m

In [None]:
# Apply function after run_forward ran:
latents, _ = get_latents.apply(
    params, state, rng=jax.random.PRNGKey(2),
    model_config=model_config,
    task_config=task_config,
    inputs=train_inputs,
    targets_template=train_targets,
    forcings=train_forcings
)

In [None]:
# @title Analysing extracted latent features:
len(latents)

In [None]:
latent_graph = latents[0]  # Inspect the first extracted latent graph
node_features = latent_graph.nodes["mesh_nodes"].features

In [None]:
edge_features = latent_graph.edges

In [None]:
node_features.shape

In [None]:
# Get the only key in the dict (or the one you want if there are many)
key = list(edge_features.keys())[0]

In [None]:
# Get the EdgeSet object
edge_set = edge_features[key]
print(edge_set)

In [None]:
# Get the features array
edge_features_ = edge_set.features
print(edge_features_.shape)

In [None]:
for i, latent in enumerate(latents):
    features = latent.nodes["mesh_nodes"].features
    print(f"Step {i}: {features.shape}")

In [None]:
# calculate activation of each node:
activations = []

for i, latent in enumerate(latents):
    features = latent.nodes["mesh_nodes"].features
    features_squeeze = features.squeeze(axis=1)  # shape (2562, latent_dim)
    l2_norms = np.linalg.norm(features_squeeze, axis=1)  # shape (2562,)
    activations.append(l2_norms)  # store per-step activations
    print(f"Step {i}: {l2_norms.shape}")

In [None]:
# Next: plot activations on the grid, with placements according to mesh graph.
# Now: load mesh grid (will have to have saved graphs in google drive)
from google.colab import drive
import pickle
drive.mount('/content/drive/')

with open("/content/drive/MyDrive/mesh_typedgraph.pkl", "rb") as f:
    mesh_graph = pickle.load(f)


In [None]:
def cartesian_to_latlon(x, y, z):
    lat = np.degrees(np.arcsin(x))        # latitude from x
    lon = np.degrees(np.arctan2(z, y))    # longitude from y and z
    lon = (lon + 360) % 360               # wrap longitude to [0, 360)
    return lat, lon

# Get mesh node positions (assume same for all time steps)
mesh_features = mesh_graph.nodes["mesh_nodes"].features  # shape: (2562, 3)
x, y, z = mesh_features[:, 0], mesh_features[:, 1], mesh_features[:, 2]
latitudes, longitudes = cartesian_to_latlon(x, y, z)

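# Plot using Cartopy (up to 5 steps, or fewer if fewer latents were extracted)
n_plots = min(5, len(activations))
fig = plt.figure(figsize=(10, 4 * max(n_plots, 1)))

for i in range(n_plots):
    ax = plt.subplot(n_plots, 1, i+1, projection=ccrs.PlateCarree())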

    # Add borders and coastlines
    ax.add_feature(cfeature.BORDERS, edgecolor='black')
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5)

    # Scatterplot of activations on mesh
    sc = ax.scatter(
        longitudes, latitudes,
        c=activations[i],
        cmap='Greens',
        s=50,
        alpha=0.7,
        transform=ccrs.PlateCarree()
    )

    # Colorbar
    cbar = plt.colorbar(sc, ax=ax, orientation='vertical', pad=0.02, shrink=0.7)
    cbar.set_label('Activation Magnitude')

    # Gridlines and labels
    gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    gl.top_labels = False
    gl.right_labels = False

    ax.set_title(f"Step {i} - Node Activation Heatmap")

plt.tight_layout()
plt.savefig("node_activation_v1.pdf")

In [None]:
# Now plot the real data...
len(output.data_vars)

In [None]:
# create new train_inputs_plot xarray
# Get list of first 11 variable names
first_11_vars = list(train_inputs.data_vars)[:11]

# Create a new dataset with only those variables
train_inputs_subset = train_inputs[first_11_vars].copy()

train_inputs_subset.data_vars

# Assume train_inputs is already loaded

variables = list(train_inputs_subset.data_vars)
n_vars = len(variables)
ncols = 4
nrows = int(np.ceil(n_vars / ncols))

# Use PlateCarree projection (standard lat-lon)
projection = ccrs.PlateCarree()

fig, axes = plt.subplots(nrows, ncols,
                         figsize=(5 * ncols, 3.5 * nrows),
                         subplot_kw={'projection': projection},
                         constrained_layout=True)

axes = axes.flatten()

# Extract lat/lon
lat = train_inputs.coords["lat"].values
lon = train_inputs.coords["lon"].values
lon_grid, lat_grid = np.meshgrid(lon, lat)

for i, var_name in enumerate(variables):
    ax = axes[i]

    var_in = train_inputs_subset[var_name]
    var_out = output[var_name]
    var_sel_input = var_in.isel(batch=0, time=1)
    var_sel_output = var_out.isel(batch=0, time=0)

    if "level" in var_in.dims:
        level_values = var_in.coords["level"].values
        level_index = int(np.argmin(np.abs(level_values - 500)))
        var_sel_input = var_sel_input.isel(level=level_index)
        var_sel_output = var_sel_output.isel(level=level_index)

    data_2d_input = var_sel_input.values
    data_2d_output = var_sel_output.values

    # Compute symmetric vmin/vmax
    diff = data_2d_output - data_2d_input
    max_abs = np.nanmax(np.abs(diff))

    # Plot data
    im = ax.pcolormesh(lon_grid, lat_grid, diff,
                       shading='auto', cmap='coolwarm', vmin=-max_abs, vmax=+max_abs,
                       transform=ccrs.PlateCarree())

    # Add coastlines and land
    ax.coastlines(resolution='110m', linewidth=1)
   # ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax.add_feature(cfeature.LAND, facecolor='none', edgecolor='black', linewidth=0.3)

    ax.set_title(var_name)
    ax.set_xticks([-180, -90, 0, 90, 180], crs=projection)
    ax.set_yticks([-90, -45, 0, 45, 90], crs=projection)
    ax.gridlines(draw_labels=False, linewidth=0.2)
    plt.colorbar(im, ax=ax, shrink=0.7, orientation='horizontal')

# Turn off unused subplots
for j in range(n_vars, len(axes)):
    axes[j].axis('off')

plt.suptitle("2D Maps of differences between prediction and input data", fontsize=16)
plt.savefig("diff_v1.pdf")
