In [None]:
import dataclasses
import xarray
import matplotlib.pyplot as plt
import sys
sys.path.append('graphcast')
from graphcast import graphcast, checkpoint, normalization, autoregressive, casting, data_utils, rollout
import jax
import haiku as hk
import numpy as np
import functools

import pandas as pd

In [None]:
## all
import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

import cartopy.crs as ccrs
from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray


In [None]:
from graphcast_functional import *

In [49]:
# Display the forcings extracted for evaluation.
# NOTE(review): `eval_forcings` is only assigned in a later cell (the
# `data_utils.extract_inputs_targets_forcings` call); on a fresh kernel this cell
# depends on that cell (or the `graphcast_functional` star-import) having run first.
eval_forcings

In [None]:
# Inspect the example batch restricted to the target lead times.
# NOTE(review): `target_lead_times` is not assigned in any visible cell (presumably
# provided by the `graphcast_functional` star-import), and `example_batch` is only
# loaded in a later cell — this may fail on a fresh kernel; confirm.
example_batch.sel({"time": target_lead_times})

In [48]:

# Per-level normalization statistics consumed by normalization.InputsAndResiduals.
src_diffs_stddev_by_level = "data/stats/diffs_stddev_by_level.nc"
src_mean_by_level = "data/stats/mean_by_level.nc"
src_stddev_by_level = "data/stats/stddev_by_level.nc"


def _load_netcdf(path):
    """Opens the netCDF file at `path` and returns its contents as an in-memory xarray Dataset."""
    with open(path, "rb") as fh:
        return xarray.load_dataset(fh).compute()


diffs_stddev_by_level = _load_netcdf(src_diffs_stddev_by_level)
mean_by_level = _load_netcdf(src_mean_by_level)
stddev_by_level = _load_netcdf(src_stddev_by_level)

# Pretrained GraphCast checkpoint: parameters plus the model/task configs it was trained with.
src = "data/params/GraphCast_ERA5_1979-2017_Resolution-0.25_PressureLevels-37_Mesh-2to6_PrecipitationInputOutput.npz"
with open(src, "rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

params = ckpt.params
state = {}  # The model is not stateful, so the Haiku state is empty.

model_config = ckpt.model_config
task_config = ckpt.task_config
print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")

# Example ERA5 batch used as model input and evaluation target.
example_batch_src = "data/datasets/source-era5_date-2022-01-01_res-0.25_levels-37_steps-01.nc"
example_batch = _load_netcdf(example_batch_src)

eval_steps = 1  # Number of 6-hour autoregressive steps to evaluate.

# Fail early, with a readable message, if the data does not fit the checkpoint.
# GraphCast consumes 2 input time steps, followed by `eval_steps` target steps.
assert example_batch.sizes["time"] >= 2 + eval_steps, (
    f"example batch has {example_batch.sizes['time']} time steps; "
    f"need 2 inputs + {eval_steps} targets")
assert example_batch.sizes["level"] == len(task_config.pressure_levels), (
    f"example batch has {example_batch.sizes['level']} pressure levels; "
    f"checkpoint expects {len(task_config.pressure_levels)}")

eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{eval_steps*6}h"),
    **dataclasses.asdict(task_config))

Model description:/n 
GraphCast model at 0.25deg resolution, with 37 pressure levels. This model is
trained on ERA5 data from 1979 to 2017, and can be causally evaluated on 2018
and later years. This model takes as inputs `total_precipitation_6hr`. This was
described in the paper
`GraphCast: Learning skillful medium-range global weather forecasting`
(https://arxiv.org/abs/2212.12794).
 /n
Model license:/n 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.
 /n


In [None]:
def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Builds the layered GraphCast predictor used for rollouts.

  Layers, from innermost to outermost:
    1. `graphcast.GraphCast`: the one-step core model.
    2. `casting.Bfloat16Cast`: converts float32 <-> BFloat16 around the core.
    3. `normalization.InputsAndResiduals`: normalizes inputs/targets with the
       module-level `diffs_stddev_by_level`, `mean_by_level` and
       `stddev_by_level` datasets. It sits outside the cast, so BFloat16
       casting only ever sees normalized values.
    4. `autoregressive.Predictor`: turns the one-step model into a trajectory
       predictor (gradient checkpointing disabled).

  Args:
    model_config: GraphCast model configuration (e.g. from a checkpoint).
    task_config: GraphCast task configuration (e.g. from a checkpoint).

  Returns:
    The outermost `autoregressive.Predictor` wrapper.
  """
  core = graphcast.GraphCast(model_config, task_config)
  bf16_core = casting.Bfloat16Cast(core)
  normalized_core = normalization.InputsAndResiduals(
      bf16_core,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level,
  )
  return autoregressive.Predictor(normalized_core, gradient_checkpointing=False)



def run_forward(model_config, task_config, inputs, targets_template, forcings):
  """Builds the wrapped GraphCast predictor and applies it to one batch.

  Meant to be wrapped with `hk.transform_with_state`: Haiku modules must be
  constructed inside a transformed function.

  Returns:
    Whatever the wrapped predictor returns for `inputs`, shaped like
    `targets_template`.
  """
  model = construct_wrapped_graphcast(model_config, task_config)
  call_kwargs = dict(targets_template=targets_template, forcings=forcings)
  return model(inputs, **call_kwargs)


def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
  """Evaluates the loss and its gradient with respect to `params`.

  NOTE(review): relies on a Haiku-transformed `loss_fn` (with `.apply`), which
  no visible cell defines. It is presumably provided by the
  `graphcast_functional` star-import; confirm.

  Returns:
    A 4-tuple `(loss, diagnostics, next_state, grads)`.
  """
  def _loss_with_aux(p, s, inputs_, targets_, forcings_):
    # value_and_grad(has_aux=True) differentiates only the first returned value;
    # the diagnostics and the new Haiku state ride along as auxiliary data.
    loss_outputs, new_state = loss_fn.apply(
        p, s, jax.random.PRNGKey(0), model_config, task_config,
        inputs_, targets_, forcings_)
    loss_value, diag = loss_outputs
    return loss_value, (diag, new_state)

  loss_and_grad = jax.value_and_grad(_loss_with_aux, has_aux=True)
  (loss, (diagnostics, next_state)), grads = loss_and_grad(
      params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing them
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  """Returns `fn` with `model_config=` and `task_config=` pre-bound as keywords.

  Both values are read from the module-level globals at the moment
  `with_configs` is called, not when the returned callable runs.
  """
  config_kwargs = dict(model_config=model_config, task_config=task_config)
  return functools.partial(fn, **config_kwargs)

# Always bind params and state, so the call sites below stay simple.
def with_params(fn):
  """Returns `fn` with `params=` and `state=` pre-bound as keywords.

  Both values are read from the module-level globals at the moment
  `with_params` is called, not when the returned callable runs.
  """
  bound_kwargs = dict(params=params, state=state)
  return functools.partial(fn, **bound_kwargs)

# Our models aren't stateful, so the state is always empty; just return the
# predictions. This is required by our rollout code, and generally simpler.
def drop_state(fn):
  """Wraps `fn` so that only element `[0]` of its result is returned.

  The wrapper accepts keyword arguments only and forwards them unchanged.
  """
  def _first_output(**kwargs):
    return fn(**kwargs)[0]
  return _first_output

In [None]:
# Transform the function with Haiku. The isinstance guard makes this cell safe
# to re-run: without it, a second run would try to transform the
# already-transformed `TransformedWithState`, which fails.
if not isinstance(run_forward, hk.TransformedWithState):
  run_forward = hk.transform_with_state(run_forward)

init_jitted = jax.jit(with_configs(run_forward.init))

# Gradient computation is disabled; enabling it also needs a Haiku-transformed
# `loss_fn`, which no visible cell defines.
#grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))


In [None]:
# Load the first prediction, saved earlier to netCDF (presumably by a previous
# rollout run). `load_dataset` reads the data into memory, so the file can be
# closed straight away.
src = "predicted_dataset.nc"
with open(src, mode="rb") as f:
    predicted_dataset = xarray.load_dataset(f).compute()

In [47]:
# Roll the wrapped model forward over the lead times covered by `eval_targets`.
# The NaN-filled copy of the targets only supplies the structure of the
# expected outputs.
# NOTE(review): with eval_steps = 1 this is a single 6h step, not 24h as the
# variable name suggests.
# NOTE(review): the recorded run raised a parameter-shape ValueError. The
# checkpoint's first grid-node encoder layer has 474 input features, but the
# forward pass built 10, so the inputs/forcings presumably do not match the
# task_config the checkpoint was trained with. Confirm, e.g. check whether the
# `graphcast_functional` star-import rebinds configs or data.
predictions24h_in6h_steps = rollout.chunked_prediction(
    run_forward_jitted,
    inputs=eval_inputs,
    forcings=eval_forcings,
    targets_template=np.nan * eval_targets,
    rng=jax.random.PRNGKey(0),
)

ValueError: 'grid2mesh_gnn/~_networks_builder/encoder_nodes_grid_nodes_mlp/~/linear_0/w' with retrieved shape (474, 512) does not match shape=[10, 512] dtype=dtype(bfloat16)