# Graphcast HRES model
Notebook to load and initialise the GraphCast operational model and create forecasts initialised from HRES data at full resolution (0.25°) on 13 pressure levels.

## Imports and supporting functions

In [2]:
# imports

import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

import cartopy.crs as ccrs
from google.cloud import storage
import gcsfs
import zarr
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray

import dask.array as da

from typing import Tuple, Dict


def parse_file_parts(file_name):
  """Parses an underscore-separated `key-value` file name into a dict.

  Each `_`-separated part is split on its first `-` only, so values may
  themselves contain dashes. Later duplicate keys overwrite earlier ones.
  A part without any `-` raises ValueError.
  """
  parsed = {}
  for part in file_name.split("_"):
    key, value = part.split("-", 1)
    parsed[key] = value
  return parsed

In [3]:
# import nvidia_smi

# nvidia_smi.nvmlInit()

# handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
# # card id 0 hardcoded here, there is also a call to get all available card ids, so we could iterate

# info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)

# print("Total memory:", info.total/10**9, "GB")
# print("Free memory:", info.free/10**9, "GB")
# print("Used memory:", info.used/10**9, "GB")

# nvidia_smi.nvmlShutdown()

# !nvidia-smi

In [4]:
# @title Connect to Google Cloud Storage (anonymous access)

# An anonymous client (no credentials) is used: the public "dm_graphcast"
# bucket is read below for the model checkpoint ("params/...") and the
# normalization statistics ("stats/...").
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")

In [5]:
# @title Plotting functions
# Edited to allow saving of a plot to MP4

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.DataArray:
  """Extracts a single variable from `data`, ready for plotting.

  Args:
    data: dataset containing `variable`.
    variable: name of the data variable to extract.
    level: if given and the variable has a "level" coordinate, selects that level.
    max_steps: if given and smaller than the number of "time" steps, keeps only
      the first `max_steps` time steps.

  Returns:
    The selected `xarray.DataArray` (the first batch element if a "batch"
    dimension is present).
  """
  data = data[variable]
  if "batch" in data.dims:
    data = data.isel(batch=0)
  if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
    # A basic slice selects the same leading steps as an index range, without fancy indexing.
    data = data.isel(time=slice(0, max_steps))
  if level is not None and "level" in data.coords:
    data = data.sel(level=level)
  return data

def scale(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
  """Chooses a color normalization and colormap for `data`.

  Args:
    data: values to be plotted (NaNs are ignored when computing the range).
    center: if given, the range is made symmetric around this value and a
      diverging colormap ("RdBu_r") is used; otherwise "viridis".
    robust: if True, use the 2nd/98th percentiles instead of min/max.

  Returns:
    `(data, norm, cmap_name)`, with `data` passed through unchanged.
  """
  low_pct, high_pct = (2, 98) if robust else (0, 100)
  vmin = np.nanpercentile(data, low_pct)
  vmax = np.nanpercentile(data, high_pct)

  if center is None:
    cmap_name = "viridis"
  else:
    half_range = max(vmax - center, center - vmin)
    vmin, vmax = center - half_range, center + half_range
    cmap_name = "RdBu_r"

  return data, matplotlib.colors.Normalize(vmin, vmax), cmap_name

def plot_data(
    data: Dict[str, Tuple[xarray.Dataset, matplotlib.colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4,
    save_path: Optional[str] = None,  # File path to also save the animation to (ffmpeg writer); None skips saving.
    loop_count: int = 10    # Number of times the sequence of time steps is repeated in the animation.
) -> HTML:
  """Animates one or more gridded fields side by side.

  Args:
    data: maps a panel title to a `(values, norm, cmap)` triple, as returned by
      `scale`. All panels must have the same number of "time" steps (1 when a
      panel has no "time" dimension).
    fig_title: figure title; when the data has a "time" dimension, the lead
      time of the current frame is appended.
    plot_size: height of one row of panels in inches (each column is twice as wide).
    robust: if True, colorbars are drawn with extensions at both ends.
    cols: maximum number of panels per row.
    save_path: if given, the animation is also saved to this path with the
      ffmpeg writer at 4 frames per second.
    loop_count: how many times the time steps are looped in the animation.

  Returns:
    An `IPython.display.HTML` object with the animation rendered as JS/HTML.

  Raises:
    ValueError: if `data` is empty.
  """
  if not data:
    raise ValueError("plot_data needs at least one entry in `data`.")

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols, plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  # `panel_data` rather than `plot_data` so the loop variable does not shadow this function.
  for i, (title, (panel_data, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    im = ax.imshow(
        panel_data.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    # Frames cover `loop_count` repetitions of the time steps.
    actual_frame = frame % max_steps
    if "time" in first_data.dims:
      # Assumes "time" holds timedelta64[ns] lead times, for which `.item()`
      # yields integer nanoseconds (divided by 1000 -> microseconds).
      td = datetime.timedelta(microseconds=first_data["time"][actual_frame].item() / 1000)
      figure.suptitle(f"{fig_title}, {td}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (panel_data, _, _) in zip(images, data.values()):
      im.set_data(panel_data.isel(time=actual_frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=range(loop_count * max_steps), interval=250)

  if save_path:
    ani.save(save_path, writer='ffmpeg', fps=4)

  plt.close(figure.number)
  return HTML(ani.to_jshtml())

## Load the model
Load the parameters and configuration of the `GraphCast_operational` model, which is initialised from HRES data on 13 pressure levels.

In [6]:
# Choosing the right model and parameters
source = "Checkpoint"
params_file_value = (
    "GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - "
    "pressure levels 13 - mesh 2to6 - precipitation output only.npz")

assert source == "Checkpoint"
# Stream the checkpoint straight from the public bucket and unpack its parts.
with gcs_bucket.blob("params/" + params_file_value).open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)
params, state = ckpt.params, {}

model_config, task_config = ckpt.model_config, ckpt.task_config
print("Model description:\n", ckpt.description, "\n")
# print("Model license:\n", ckpt.license, "\n")

model_config

Model description:
 
GraphCast model at 0.25deg resolution, with 13 pressure levels. This model is
trained on ERA5 data from 1979 to 2017, and fine-tuned on HRES-fc0 data from
2016 to 2021 and can be causally evaluated on 2022 and later years. This model
does not take `total_precipitation_6hr` as inputs and can make predictions in an
operational setting (i.e., initialised from HRES-fc0).
 



ModelConfig(resolution=0.25, mesh_size=6, latent_size=512, gnn_msg_steps=16, hidden_layers=1, radius_query_fraction_edge_length=0.5999912857713345, mesh2grid_edge_normalization_factor=0.6180338738074472)

## Load data
Load HRES t=0 data from WeatherBench 2 and transform it to match the structure of the GraphCast demo sample datasets.

In [7]:
# Open the public WeatherBench 2 HRES t=0 Zarr store anonymously. A failure is
# reported and then re-raised: every later step needs `hres_6h`, so carrying on
# would only surface a misleading NameError further down.
try:
    # Create a GCSFileSystem object for accessing the GCS bucket
    fs = gcsfs.GCSFileSystem(anon=True)  # Use anon=True for public data
    print("GCSFileSystem object created successfully")

    # Use GCSFileSystem to access the Zarr store
    store = gcsfs.GCSMap(root='gs://weatherbench2/datasets/hres_t0/2016-2022-6h-1440x721.zarr', gcs=fs, check=False)
    print("GCSMap object created successfully")
    print(store)

    # Open the dataset with xarray (consolidated Zarr metadata); the data stays lazy.
    hres_6h = xarray.open_zarr(store, consolidated=True)
    print("Dataset opened successfully")

except Exception as e:
    print(f"An error occurred: {e}")
    raise

# Rename coordinates to the `lat`/`lon` names used by the GraphCast data utilities
hres_6h = hres_6h.rename({'latitude': 'lat', 'longitude': 'lon'})

# Keep the absolute timestamps as a 'datetime' coordinate ('time' is turned into
# lead times further down).
hres_6h_var_sliced = hres_6h.assign_coords(datetime=hres_6h['time'])

# Temporary all-zero 'batch' coordinate along 'time'. It is dropped again below;
# the real size-1 batch dimension comes from the reshapes that follow.
batch_dim = da.zeros(hres_6h_var_sliced.sizes['time'], dtype=int)
hres_6h_var_sliced = hres_6h_var_sliced.assign_coords(batch=('time', batch_dim))

# Give 'datetime' a leading size-1 batch axis, i.e. dims ('batch', 'time')
reshaped_datetime = hres_6h_var_sliced['datetime'].values.reshape((1, hres_6h_var_sliced.sizes['time']))
hres_6h_var_sliced = hres_6h_var_sliced.assign_coords(datetime=(['batch', 'time'], reshaped_datetime))

# Convert 'time' to timedelta64[ns] offsets from the first timestamp (time[0] == 0)
start_time = hres_6h_var_sliced['time'].values[0]
time_deltas = hres_6h_var_sliced['time'].values - start_time
hres_6h_var_sliced = hres_6h_var_sliced.assign_coords(time=('time', time_deltas))

# Prepend a size-1 'batch' axis to every data variable by reshaping its (lazy)
# array. The reshape only labels axes correctly when 'time' is the leading axis,
# so fail loudly instead of silently mislabelling dimensions.
reshaped_vars = {}
for var in hres_6h_var_sliced.data_vars:
    var_dims = hres_6h_var_sliced[var].dims
    if not var_dims or var_dims[0] != 'time':
        raise ValueError(
            f"Expected 'time' to be the leading dimension of {var!r}, got {var_dims}")
    reshaped_vars[var] = (['batch', 'time'] + [dim for dim in var_dims if dim != 'time'],
                          hres_6h_var_sliced[var].data.reshape((1, len(hres_6h_var_sliced['time'])) + hres_6h_var_sliced[var].shape[1:]))

# Create the new dataset with reshaped variables
ds_new = xarray.Dataset(reshaped_vars,
                    coords={'lon': hres_6h_var_sliced['lon'], 'lat': hres_6h_var_sliced['lat'], 'level': hres_6h_var_sliced['level'], 'time': hres_6h_var_sliced['time'],
                            'datetime': (['batch', 'time'], hres_6h_var_sliced['datetime'].values.reshape((1, len(hres_6h_var_sliced['time']))))})

# Carry over any remaining coordinates (this includes the temporary 'batch' one)
for coord in hres_6h_var_sliced.coords:
    if coord not in ds_new.coords and coord != 'datetime':
        ds_new = ds_new.assign_coords({coord: hres_6h_var_sliced[coord]})


# open the land_sea_mask and geopotential_at_surface datasets (local files)
ds_land_sea_mask = xarray.open_dataset('ds_land_sea_mask.nc')
ds_geopotential_at_surface = xarray.open_dataset('ds_geopotential_at_surface.nc')

# Add land_sea_mask and geopotential_at_surface to the new dataset
ds_new['land_sea_mask'] = ds_land_sea_mask['land_sea_mask']
ds_new['geopotential_at_surface'] = ds_geopotential_at_surface['geopotential_at_surface']

# Drop the temporary 'batch' coordinate; the 'batch' dimension itself remains
ds_new = ds_new.drop_vars('batch')

# Keep only the variables required by the model
ds_new = ds_new[[
    'geopotential_at_surface',
    'land_sea_mask',
    '2m_temperature',
    'mean_sea_level_pressure',
    '10m_v_component_of_wind',
    '10m_u_component_of_wind',
    'total_precipitation_6hr',
    # 'toa_incident_solar_radiation',  # TODO(fix): currently excluded
    'temperature',
    'geopotential',
    'u_component_of_wind',
    'v_component_of_wind',
    'vertical_velocity',
    'specific_humidity']]

# Verify the updated dataset
hres_gc_shaped = ds_new
hres_gc_shaped

# ds_new['temperature'].sel(time='0',lat=1, lon=10, level=850,method='Nearest').values

GCSFileSystem object created successfully
GCSMap object created successfully
<fsspec.mapping.FSMap object at 0x7f5a95d5ef50>
Dataset opened successfully


Unnamed: 0,Array,Chunk
Bytes,39.71 GiB,3.96 MiB
Shape,"(1, 10268, 721, 1440)","(1, 1, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 39.71 GiB 3.96 MiB Shape (1, 10268, 721, 1440) (1, 1, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",1  1  1440  721  10268,

Unnamed: 0,Array,Chunk
Bytes,39.71 GiB,3.96 MiB
Shape,"(1, 10268, 721, 1440)","(1, 1, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,39.71 GiB,3.96 MiB
Shape,"(1, 10268, 721, 1440)","(1, 1, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 39.71 GiB 3.96 MiB Shape (1, 10268, 721, 1440) (1, 1, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",1  1  1440  721  10268,

Unnamed: 0,Array,Chunk
Bytes,39.71 GiB,3.96 MiB
Shape,"(1, 10268, 721, 1440)","(1, 1, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,39.71 GiB,3.96 MiB
Shape,"(1, 10268, 721, 1440)","(1, 1, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 39.71 GiB 3.96 MiB Shape (1, 10268, 721, 1440) (1, 1, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",1  1  1440  721  10268,

Unnamed: 0,Array,Chunk
Bytes,39.71 GiB,3.96 MiB
Shape,"(1, 10268, 721, 1440)","(1, 1, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,39.71 GiB,3.96 MiB
Shape,"(1, 10268, 721, 1440)","(1, 1, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 39.71 GiB 3.96 MiB Shape (1, 10268, 721, 1440) (1, 1, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",1  1  1440  721  10268,

Unnamed: 0,Array,Chunk
Bytes,39.71 GiB,3.96 MiB
Shape,"(1, 10268, 721, 1440)","(1, 1, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,39.71 GiB,3.96 MiB
Shape,"(1, 10268, 721, 1440)","(1, 1, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 39.71 GiB 3.96 MiB Shape (1, 10268, 721, 1440) (1, 1, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",1  1  1440  721  10268,

Unnamed: 0,Array,Chunk
Bytes,39.71 GiB,3.96 MiB
Shape,"(1, 10268, 721, 1440)","(1, 1, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 516.28 GiB 51.49 MiB Shape (1, 10268, 13, 721, 1440) (1, 1, 13, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",10268  1  1440  721  13,

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 516.28 GiB 51.49 MiB Shape (1, 10268, 13, 721, 1440) (1, 1, 13, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",10268  1  1440  721  13,

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 516.28 GiB 51.49 MiB Shape (1, 10268, 13, 721, 1440) (1, 1, 13, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",10268  1  1440  721  13,

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 516.28 GiB 51.49 MiB Shape (1, 10268, 13, 721, 1440) (1, 1, 13, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",10268  1  1440  721  13,

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 516.28 GiB 51.49 MiB Shape (1, 10268, 13, 721, 1440) (1, 1, 13, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",10268  1  1440  721  13,

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 516.28 GiB 51.49 MiB Shape (1, 10268, 13, 721, 1440) (1, 1, 13, 721, 1440) Dask graph 10268 chunks in 3 graph layers Data type float32 numpy.ndarray",10268  1  1440  721  13,

Unnamed: 0,Array,Chunk
Bytes,516.28 GiB,51.49 MiB
Shape,"(1, 10268, 13, 721, 1440)","(1, 1, 13, 721, 1440)"
Dask graph,10268 chunks in 3 graph layers,10268 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


Slicing for timesteps

In [8]:
def _reference_start(ds, default_reference='2016-01-01T00:00:00.000000000'):
    """Returns the absolute datetime that `time == 0` corresponds to in `ds`.

    When `ds` carries a 'datetime' coordinate aligned with 'time' (as built in
    the data-loading cell, where time = datetime - first datetime), the
    reference is recovered as datetime[0] - time[0]. This remains valid after
    `ds` has been sliced. Otherwise `default_reference` is used.
    """
    time_values = np.asarray(ds['time'].values)
    if 'datetime' in ds.coords and time_values.size > 0:
        first_datetime = np.asarray(ds['datetime'].values).ravel()[0]
        return first_datetime - time_values[0]
    return np.datetime64(default_reference)


def datetime_slicer(ds, target_start_date, target_end_date):
    """Slices the dataset along 'time' to an absolute date range (inclusive).

    Args:
        ds: dataset whose 'time' coordinate holds timedelta64[ns] offsets.
            If it also has a 'datetime' coordinate, that coordinate defines
            what the offsets are relative to. Otherwise 2016-01-01T00 is assumed.
        target_start_date: first datetime to keep (anything np.datetime64 accepts).
        target_end_date: last datetime to keep (inclusive).

    Returns:
        ds_slice: `ds` restricted to the requested time range; the other
        dimensions are unchanged.
    """
    reference = _reference_start(ds)
    timedelta_start = np.datetime64(target_start_date) - reference
    timedelta_end = np.datetime64(target_end_date) - reference
    return ds.sel(time=slice(timedelta_start, timedelta_end))


def datetime_steps_slicer(ds, target_start_date, num_timesteps,
                          timestep=np.timedelta64(6, 'h')):
    """Slices the dataset to `num_timesteps` steps starting at a given datetime.

    Args:
        ds: dataset whose 'time' coordinate holds timedelta64[ns] offsets
            (see `datetime_slicer` for how the reference datetime is found).
        target_start_date: first datetime to keep (anything np.datetime64 accepts).
        num_timesteps: number of timesteps to keep, including the first one.
        timestep: spacing between consecutive timesteps (6 hours by default).

    Returns:
        ds_slice: `ds` restricted to the time range
        [start, start + (num_timesteps - 1) * timestep]. This contains
        `num_timesteps` steps when the data lies on the `timestep` grid.

    Raises:
        ValueError: if `num_timesteps` is smaller than 1.
    """
    if num_timesteps < 1:
        raise ValueError(f"num_timesteps must be >= 1, got {num_timesteps}")
    start = np.datetime64(target_start_date)
    # Both ends are inclusive, so n steps span n - 1 intervals.
    end = start + (num_timesteps - 1) * np.timedelta64(timestep)
    return datetime_slicer(ds, start, end)



#### Selecting data for predictions and loading into memory

In [9]:
# Four consecutive time steps starting at 2022-01-01T00. With this task config
# two of them are model inputs, leaving at most two as targets.
timesteps = 4
hres_gc_shaped_sliced = datetime_steps_slicer(
    hres_gc_shaped,
    target_start_date='2022-01-01',
    num_timesteps=timesteps,
)
# Materialise the lazily-loaded slice into memory.
example_batch = hres_gc_shaped_sliced.compute()
example_batch


In [10]:
# # @title Choose data to plot

# plot_example_variable = widgets.Dropdown(
#     options=example_batch.data_vars.keys(),
#     value="u_component_of_wind",
#     description="Variable")
# plot_example_level = widgets.Dropdown(
#     options=example_batch.coords["level"].values,
#     value=1000,
#     description="Level")
# plot_example_robust = widgets.Checkbox(value=True, description="Robust")
# plot_example_max_steps = widgets.IntSlider(
#     min=1, max=example_batch.dims["time"], value=example_batch.dims["time"],
#     description="Max steps")

# widgets.VBox([
#     plot_example_variable,
#     plot_example_level,
#     plot_example_robust,
#     plot_example_max_steps,
#     widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
# ])

In [11]:
# # @title Plot example data

# plot_size = 5

# data = {
#     " ": scale(select(example_batch, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),
#               robust=plot_example_robust.value),
# }
# fig_title = plot_example_variable.value
# if "level" in example_batch[plot_example_variable.value
# ].coords:
#   fig_title += f" at {plot_example_level.value} hPa"

# plot_data(data, fig_title, plot_size, plot_example_robust.value)

## Running the model

In [12]:
# @title Choose training and eval data to extract
# Both sliders are capped at time - 2, since the first two time steps of the
# batch are model inputs (see the printed input sizes below).
train_steps = widgets.IntSlider(
    value=1,
    min=1,
    max=example_batch.sizes["time"] - 2,
    description="Train steps",
)
eval_steps = widgets.IntSlider(
    value=example_batch.sizes["time"] - 2,
    min=1,
    max=example_batch.sizes["time"] - 2,
    description="Eval steps",
)

widgets.VBox(children=[
    train_steps,
    eval_steps,
    widgets.Label(value="Run the next cell to extract the data. Rerunning this cell clears your selection."),
])

VBox(children=(IntSlider(value=1, description='Train steps', max=2, min=1), IntSlider(value=2, description='Ev…

In [21]:
# Inspect the number of evaluation (target) steps currently selected on the slider.
eval_steps.value

2

: 

In [13]:
# @title Extract training and eval data
# Target lead times assume 6-hourly steps: N selected steps -> 6h ... (6*N)h.

train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{train_steps.value*6}h"),
    **dataclasses.asdict(task_config))

eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{eval_steps.value*6}h"),
    **dataclasses.asdict(task_config))

# Print each dataset's {dimension name: size} mapping. `Dataset.sizes` is the
# public API for this. `.dims.mapping` relies on the internals of the Frozen
# wrapper, and mapping-style use of `Dataset.dims` is deprecated by xarray.
print("All Examples:  ", dict(example_batch.sizes))
print("Train Inputs:  ", dict(train_inputs.sizes))
print("Train Targets: ", dict(train_targets.sizes))
print("Train Forcings:", dict(train_forcings.sizes))
print("Eval Inputs:   ", dict(eval_inputs.sizes))
print("Eval Targets:  ", dict(eval_targets.sizes))
print("Eval Forcings: ", dict(eval_forcings.sizes))


2024-08-14 10:28:23.545157: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


All Examples:   {'lat': 721, 'lon': 1440, 'batch': 1, 'time': 4, 'level': 13}
Train Inputs:   {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Train Targets:  {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440, 'level': 13}
Train Forcings: {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440}
Eval Inputs:    {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Eval Targets:   {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Eval Forcings:  {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440}


In [14]:
# @title Load normalization data

# `xarray.load_dataset` already reads the whole file into memory and closes it,
# so no extra `.compute()` is needed on the result.
with gcs_bucket.blob("stats/diffs_stddev_by_level.nc").open("rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f)
with gcs_bucket.blob("stats/mean_by_level.nc").open("rb") as f:
  mean_by_level = xarray.load_dataset(f)
with gcs_bucket.blob("stats/stddev_by_level.nc").open("rb") as f:
  stddev_by_level = xarray.load_dataset(f)

In [15]:
# @title Build jitted functions, and possibly initialize random weights

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor.

  Wrappers, from innermost to outermost:
    1. `graphcast.GraphCast`: the one-step predictor.
    2. `casting.Bfloat16Cast`: float32 <-> BFloat16 conversion of its inputs/outputs.
    3. `normalization.InputsAndResiduals`: normalization using the notebook-level
       `diffs_stddev_by_level`, `mean_by_level` and `stddev_by_level`. Because it
       wraps the cast, the BFloat16 cast happens after normalization.
    4. `autoregressive.Predictor`: rolls the one-step model out into trajectories
       (with gradient checkpointing).
  """
  one_step = graphcast.GraphCast(model_config, task_config)
  bf16_one_step = casting.Bfloat16Cast(one_step)
  normalized_one_step = normalization.InputsAndResiduals(
      bf16_one_step,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)
  return autoregressive.Predictor(normalized_one_step, gradient_checkpointing=True)


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  """Runs the wrapped GraphCast predictor forward to produce predictions."""
  return construct_wrapped_graphcast(model_config, task_config)(
      inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
  """Computes the predictor loss and diagnostics, each reduced to a mean JAX scalar."""
  predictor = construct_wrapped_graphcast(model_config, task_config)
  loss, diagnostics = predictor.loss(inputs, targets, forcings)

  def _mean_as_jax(x):
    return xarray_jax.unwrap_data(x.mean(), require_jax=True)

  return xarray_tree.map_structure(_mean_as_jax, (loss, diagnostics))

def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
  """Returns `(loss, diagnostics, next_state, grads)`, with grads taken w.r.t. `params`."""
  def _loss_with_aux(p, s, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        p, s, jax.random.PRNGKey(0), model_config, task_config, i, t, f)
    return loss, (diagnostics, next_state)

  value_and_grad = jax.value_and_grad(_loss_with_aux, has_aux=True)
  (loss, (diagnostics, next_state)), grads = value_and_grad(
      params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  """Binds the notebook-level `model_config` and `task_config` as keyword arguments of `fn`."""
  bound_fn = functools.partial(
      fn,
      model_config=model_config,
      task_config=task_config,
  )
  return bound_fn

# Always bind params and state, so the usages below are simpler.
def with_params(fn):
  """Binds the notebook-level `params` and `state` as keyword arguments of `fn`.

  The values are captured when this is called, not when `fn` is later invoked.
  """
  bound_fn = functools.partial(
      fn,
      params=params,
      state=state,
  )
  return bound_fn

# Our models aren't stateful, so the state is always empty; just return the
# predictions. This is required by our rollout code, and generally simpler.
def drop_state(fn):
  """Wraps `fn` (keyword-only call) so that only the first element of its result is returned."""
  def _first_only(**kwargs):
    return fn(**kwargs)[0]
  return _first_only

# Jitted initializer for the model parameters and state.
init_jitted = jax.jit(with_configs(run_forward.init))

# `params` normally comes from the checkpoint loaded earlier. Only if it is
# missing are weights initialised here (with a fixed PRNG key) from the
# training example's shapes.
if params is None:
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs,
      targets_template=train_targets,
      forcings=train_forcings)

# NOTE: `with_params` captures the *current* `params`/`state` via
# functools.partial. These wrappers must therefore be created after the optional
# initialisation above; re-run this cell if `params` changes.
loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))

In [19]:
# Inspect the evaluation targets; their structure also serves as the rollout template below.
eval_targets

In [20]:
# @title Autoregressive rollout (loop in python)

# A resolution of 0 is also accepted by this check; otherwise the model grid
# spacing must equal 360 degrees / number of longitudes.
assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")

# `dict(ds.sizes)` prints the same {dimension: size} mapping as the old
# `ds.dims.mapping`, without relying on deprecated mapping-style `Dataset.dims`.
print("Inputs:  ", dict(eval_inputs.sizes))
print("Targets: ", dict(eval_targets.sizes))
print("Forcings:", dict(eval_forcings.sizes))

# Multiplying by NaN keeps the variables/coordinates of `eval_targets` as the
# output template while blanking every value, so no true target values are passed in.
predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)
predictions

Inputs:   {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Targets:  {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Forcings: {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440}


  num_target_steps = targets_template.dims["time"]
  scan_length = targets_template.dims['time']
  num_inputs = inputs.dims['time']
  num_inputs = prev_inputs.dims["time"]


### Plotting the predictions

In [None]:
# # @title Choose predictions to plot

# plot_pred_variable = widgets.Dropdown(
#     options=predictions.data_vars.keys(),
#     value="2m_temperature",
#     description="Variable")
# plot_pred_level = widgets.Dropdown(
#     options=predictions.coords["level"].values,
#     value=500,
#     description="Level")
# plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
# plot_pred_max_steps = widgets.IntSlider(
#     min=1,
#     max=predictions.dims["time"],
#     value=predictions.dims["time"],
#     description="Max steps")

# widgets.VBox([
#     plot_pred_variable,
#     plot_pred_level,
#     plot_pred_robust,
#     plot_pred_max_steps,
#     widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
# ])

In [None]:
# # @title Plot predictions

# plot_size = 5
# plot_max_steps = min(predictions.dims["time"], plot_pred_max_steps.value)

# data = {
#     "Targets": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
#     "Predictions": scale(select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
#     "Diff": scale((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
#                         select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
#                        robust=plot_pred_robust.value, center=0),
# }
# fig_title = plot_pred_variable.value
# if "level" in predictions[plot_pred_variable.value].coords:
#   fig_title += f" at {plot_pred_level.value} hPa"

# plot_data(data, fig_title, plot_size, plot_pred_robust.value)
