> Adapted from DeepMind Technologies Limited code under Apache 2.0.
>
> Specifically, it is based on the GenCast mini demo notebook, available [here](https://github.com/google-deepmind/graphcast/blob/main/gencast_mini_demo.ipynb).
>
> This code is distributed under the same license.

# Distilled GenCast inference and evaluation

## Initialization

In [None]:
import dataclasses
import xarray
import numpy as np
import gcsfs
import jax
import io
import copy
from IPython import display

import gencast_distillation.config as config
from inference_helpers import init_and_compile
from plotting_helpers import select, scale, plot_data
from eval_helpers import evaluate
from graphcast import checkpoint
from graphcast import gencast
from graphcast import data_utils
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import normalization
from graphcast import nan_cleaning
from graphcast import xarray_tree

## Plotting functions

The plotting helpers (`select`, `scale`, `plot_data`) are imported from `plotting_helpers`.

## Load data and initialize model

In [4]:
# GCS bucket holding this project's artifacts; the teacher weights are read from it below.
bucket_base = "gs://gencast-distillation-bucket"
fs = gcsfs.GCSFileSystem()

In [5]:
weights_path = f"{bucket_base}/gencast_weights/gencast_params_GenCast 1p0deg Mini _2019.npz"
# Buffer the whole remote checkpoint in memory before handing it to checkpoint.load.
with fs.open(weights_path, 'rb') as remote_file:
    model_weights = io.BytesIO(remote_file.read())

teacher_ckpt = checkpoint.load(model_weights, gencast.CheckPoint)
params = teacher_ckpt.params
state = {}

# Configs stored alongside the teacher weights.
task_config = teacher_ckpt.task_config
sampler_config = teacher_ckpt.sampler_config
noise_config = teacher_ckpt.noise_config
noise_encoder_config = teacher_ckpt.noise_encoder_config
denoiser_architecture_config = teacher_ckpt.denoiser_architecture_config

# Show every config, separated by "===" lines. The sampler config is where
# changes are expected for the distilled model.
print("\n===\n".join(str(cfg) for cfg in (
    task_config,
    sampler_config,
    noise_config,
    noise_encoder_config,
    denoiser_architecture_config,
)))

In [None]:
# Dump the teacher's raw parameter tree as loaded from the checkpoint (output may be long).
print(teacher_ckpt.params)


In [1]:
import pickle

# NOTE(review): unpickling can execute arbitrary code; only load checkpoint
# files that come from a trusted source (e.g. produced by this project).
with open("../student_ckpt.pkl", "rb") as ckpt_file:
    student_params = pickle.load(ckpt_file)

print(student_params)

# NOTE(review): none of the later cells shown here use `student_params`; the
# sketch below for building a student checkpoint is still commented out.
# student_ckpt = copy.deepcopy(teacher_ckpt)
# # student_ckpt.params = student_params
# student_ckpt.sampler_config.num_noise_levels = 10

In [None]:
# Example ERA5 data exposed by the project config module.
era5_data =  config.example_data

In [None]:
# Training split: a single 12h target step (one autoregressive step, "1AR").
train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    era5_data, target_lead_times=slice("12h", "12h"),
    **dataclasses.asdict(task_config))

# Evaluation split: every lead time the data supports. The first two time
# steps are consumed as inputs ("all but 2 input frames"), and `* 12` assumes
# 12h spacing between time steps. `.sizes` maps dim name -> length; indexing
# `Dataset.dims` by name is deprecated in recent xarray releases.
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    era5_data,
    target_lead_times=slice("12h", f"{(era5_data.sizes['time'] - 2) * 12}h"),
    **dataclasses.asdict(task_config))

In [None]:
# Normalization statistics; passed to `init_and_compile` as `norm_data` below.
norm_data = config.normalization_data 

## Run inference

In [None]:
# # TODO: move to helper

# def construct_wrapped_gencast():
#   """Constructs and wraps the GenCast Predictor."""
#   predictor = gencast.GenCast(
#       sampler_config=sampler_config,
#       task_config=task_config,
#       denoiser_architecture_config=denoiser_architecture_config,
#       noise_config=noise_config,
#       noise_encoder_config=noise_encoder_config,
#   )

#   predictor = normalization.InputsAndResiduals(
#       predictor,
#       diffs_stddev_by_level=norm_data["diffs_stddev_by_level"],
#       mean_by_level=norm_data["mean_by_level"],
#       stddev_by_level=norm_data["stddev_by_level"],
#   )

#   predictor = nan_cleaning.NaNCleaner(
#       predictor=predictor,
#       reintroduce_nans=True,
#       fill_value=norm_data["min_by_level"],
#       var_to_clean='sea_surface_temperature',
#   )

#   return predictor


# @hk.transform_with_state
# def run_forward(inputs, targets_template, forcings):
#   predictor = construct_wrapped_gencast()
#   return predictor(inputs, targets_template=targets_template, forcings=forcings)


# @hk.transform_with_state
# def loss_fn(inputs, targets, forcings):
#   predictor = construct_wrapped_gencast()
#   loss, diagnostics = predictor.loss(inputs, targets, forcings)
#   return xarray_tree.map_structure(
#       lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
#       (loss, diagnostics),
#   )


# if params is None:
#   init_jitted = jax.jit(loss_fn.init)
#   params, state = init_jitted(
#       rng=jax.random.PRNGKey(0),
#       inputs=train_inputs,
#       targets=train_targets,
#       forcings=train_forcings,
#   )

# run_forward_jitted = jax.jit(
#     lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0]
# )
# # We also produce a pmapped version for running in parallel.
# run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample")

In [None]:
# Ensemble members are spread over the local devices with pmap, so the
# ensemble size chosen in the next cell should be a multiple of this count.
local_device_count = len(jax.local_devices())
print(f"Number of local devices {local_device_count}")

In [None]:
num_ensemble_members = 40
pmap_devices = jax.local_devices()
# Members are distributed across `pmap_devices`, so the ensemble size must be a
# multiple of the device count. Check before the (slow) compile step.
if num_ensemble_members % len(pmap_devices) != 0:
    raise ValueError(
        f"num_ensemble_members ({num_ensemble_members}) must be a multiple of "
        f"the number of local devices ({len(pmap_devices)}).")

# NOTE(review): neither the checkpoint `params` loaded earlier nor
# `student_params` is passed to `init_and_compile`, and the `params` it returns
# replaces the loaded ones. Confirm it restores trained (teacher or student)
# weights rather than a fresh initialisation before trusting these forecasts.
params, state, run_forward_jitted, run_forward_pmap = init_and_compile(
    rng_key=jax.random.PRNGKey(0),
    sampler_config=sampler_config,
    task_config=task_config,
    denoiser_architecture_config=denoiser_architecture_config,
    noise_config=noise_config,
    noise_encoder_config=noise_encoder_config,
    norm_data=norm_data,
    train_inputs=train_inputs,
    train_targets=train_targets,
    train_forcings=train_forcings,
)

# One independent RNG key per ensemble member.
rng = jax.random.PRNGKey(0)
rngs = np.stack(
    [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)

chunks = list(rollout.chunked_prediction_generator_multiple_runs(
    # Use pmapped version to parallelise across devices.
    predictor_fn=run_forward_pmap,
    rngs=rngs,
    inputs=eval_inputs,
    # All-NaN copy of the targets: only its structure is used as a template.
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk=1,
    num_samples=num_ensemble_members,
    pmap_devices=pmap_devices,
))

predictions = xarray.combine_by_coords(chunks)

## Plot predictions

Here we plot the predictions for `2m_temperature`, the air temperature 2 meters above the surface. Any other predicted variable can be substituted, together with the corresponding pressure level where applicable.

In [None]:
plot_size = 5
# `.sizes` maps dim name -> length; indexing `Dataset.dims` by name is
# deprecated in recent xarray releases.
plot_max_steps = predictions.sizes["time"]
level = predictions.coords['level'].values[0]

fig_title = "2m_temperature"
if "level" in predictions["2m_temperature"].coords:
  fig_title += f" at {level} hPa"

# The target field is identical for every sample, so select it only once.
target_field = select(eval_targets, "2m_temperature", level, plot_max_steps)

# Plot targets, one ensemble member, and their difference for the first few members.
for sample_idx in range(min(4, num_ensemble_members)):
  sample_field = select(predictions.isel(sample=sample_idx), "2m_temperature", level, plot_max_steps)
  data = {
      "Targets": scale(target_field, robust=True),
      "Predictions": scale(sample_field, robust=True),
      "Diff": scale(target_field - sample_field, robust=True, center=0),
  }
  display.display(plot_data(data, fig_title + f", Sample {sample_idx}", plot_size, True))

In [None]:
def crps(targets, predictions, bias_corrected=True):
  """Ensemble CRPS of `predictions` against `targets`, reduced over "sample".

  Computed as mean|X - y| - 0.5 * mean|X - X'|, where X and X' range over the
  ensemble members (dimension "sample") and y is the target. With
  `bias_corrected` (the default), the pairwise-spread term divides by
  N * (N - 1) instead of N * N, where N is the ensemble size.

  Args:
    targets: Verifying values, broadcast against every ensemble member.
    predictions: Ensemble forecast with a "sample" dimension of size >= 2.
    bias_corrected: Use the N - 1 denominator for the spread term.

  Returns:
    The CRPS with the "sample" dimension reduced away. NaNs propagate
    (sums use `skipna=False`).

  Raises:
    ValueError: If `predictions` lacks a "sample" dim or has only one member.
  """
  if predictions.sizes.get("sample", 1) < 2:
    raise ValueError(
        "predictions must have dim 'sample' with size at least 2.")
  n = predictions.sizes["sample"]
  pair_count = n * ((n - 1) if bias_corrected else n)
  # Renaming gives an independent copy of the member axis, so the subtraction
  # broadcasts over every (sample, sample2) pair.
  partner = predictions.rename({"sample": "sample2"})
  spread = np.abs(predictions - partner).sum(
      dim=["sample", "sample2"], skipna=False) / pair_count
  skill = np.abs(targets - predictions).sum(dim="sample", skipna=False) / n
  return skill - 0.5 * spread


plot_size = 5
# `.sizes` maps dim name -> length; indexing `Dataset.dims` by name is
# deprecated in recent xarray releases.
plot_max_steps = predictions.sizes["time"]
level = predictions.coords['level'].values[0]

fig_title = "2m_temperature"
if "level" in predictions["2m_temperature"].coords:
  fig_title += f" at {level} hPa"

# The target field feeds both the "Targets" panel and the CRPS; select it once.
target_field = select(eval_targets, "2m_temperature", level, plot_max_steps)
ensemble_field = select(predictions, "2m_temperature", level, plot_max_steps)

data = {
    "Targets": scale(target_field, robust=True),
    "Ensemble Mean": scale(select(predictions.mean(dim=["sample"]), "2m_temperature", level, plot_max_steps), robust=True),
    "Ensemble CRPS": scale(crps(target_field, ensemble_field), robust=True, center=0),
}
display.display(plot_data(data, fig_title, plot_size, True))

In [None]:
# Score the ensemble `predictions` against `eval_targets` with the project's
# `evaluate` helper (see eval_helpers for what it computes).
evaluate(predictions, eval_targets)