# SEEDS Demo

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/google-research/blob/master/seeds/SEEDS_Demo.ipynb)

**Enable TPU to run the notebook**

- This demo colab runs on the free hosted TPU colab kernel.
- Navigate to the menu item `Runtime` → `Change runtime type`.
- Select `TPU` and click `Save`.
- Click `Connect` at the top right. You are ready once you see `✓TPU` next to the RAM & Disk display.

The notebook also runs with "GPU" as the runtime type, but the checkpoints are currently not optimized for GPUs, so this will be much slower. In addition, the default TPU kernel has 8 cores while the GPU kernel has only 1. Because the inference batch size is set per core in the code below, running on GPU produces fewer generated samples.

## Copyright notice

Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product.


# Preparation

In [None]:
# @title Install external packages
!pip install ecmwflibs cfgrib eccodes
!pip install cartopy matplotlib numpy pandas scipy seaborn tqdm
!pip install xarray[complete]

In [None]:
# @title Install SEEDS package
%%shell
git clone -n --depth=1 --filter=tree:0 https://github.com/google-research/google-research
cd google-research
git sparse-checkout set --no-cone seeds
git checkout
cd seeds
pip install --no-deps .

# Basic SEEDS using example data

In [None]:
# @title Imports
import os

import cartopy.crs as ccrs
import cartopy.util
import gcsfs
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import tqdm.notebook as tqdm
import xarray as xr

from seeds import grid_lib

sns.set_theme(context='paper', style='white', font_scale=1.2)
sns.set_palette('colorblind')
fs = gcsfs.GCSFileSystem('anon')

In [None]:
# @title Initialize accelerators
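# List the logical devices visible to TensorFlow and use the first accelerator found.
devices = tf.config.list_logical_devices()
for device in devices: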
  if device.device_type == 'GPU':
    strategy = tf.distribute.MirroredStrategy()
    print(f'GPUs found: {strategy.num_replicas_in_sync}')
    print('The model checkpoints are not optimized for GPUs so this will be slow.')
    break
  if device.device_type == 'TPU':
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print(f'TPUs found: {strategy.num_replicas_in_sync}.')
    break

In [None]:
# @title List all available checkpoints
base_dir = 'gs://gresearch/seeds'

checkpoint_dir = f'{base_dir}/model_checkpoints'
for path in tf.io.gfile.glob(checkpoint_dir + '/*'):
  if not path.endswith('_$folder$'):
    print(os.path.basename(path))

Checkpoint naming convention:
- `gee_c2_s7`: SEEDS-GEE trained conditioning on 2 seeds for 7-day lead time.
- `gpp_c2_s7_g3_r4`: SEEDS-GPP trained conditioning on 2 seeds for 7-day lead time, where the label mixture is 3 GEFS members and 4 ERA reanalyses.
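
The naming convention can also be decoded programmatically. The following is a minimal sketch: `parse_checkpoint_name` and the keys of the returned dictionary are hypothetical names introduced for illustration and are not part of the SEEDS package. The meaning of each tag is inferred from the two examples above.

```python
def parse_checkpoint_name(name):
  """Decodes a checkpoint name into its components (hypothetical helper)."""
  variant, *tags = name.split('_')
  # Single-letter tag prefixes, following the naming convention above.
  keys = {
      'c': 'num_seeds',
      's': 'lead_days',
      'g': 'num_gefs_labels',
      'r': 'num_era_labels',
  }
  info = {'variant': variant}
  for tag in tags:
    if tag[0] not in keys:
      raise ValueError(f'Unknown tag {tag!r} in checkpoint name {name!r}.')
    info[keys[tag[0]]] = int(tag[1:])
  return info

print(parse_checkpoint_name('gee_c2_s7'))
print(parse_checkpoint_name('gpp_c2_s7_g3_r4'))
```

The loading cell further down extracts the same two numbers (seeds and lead days) by position; a keyed parser like this one would simply fail loudly on a tag it does not recognize.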

In [None]:
# @title Choose a model checkpoint
checkpoint = "gee_c2_s7" # @param {type:"string"}

In [None]:
# @title Load the checkpoint
name_parts = checkpoint.split('_')
num_seeds = int(name_parts[1][1:])
lead_days = int(name_parts[2][1:])
print('Number of seeds:', num_seeds)
print('Lead time (days):', lead_days)

with strategy.scope():
  model = tf.saved_model.load(f'{checkpoint_dir}/{checkpoint}')
print('Model total number of parameters:', sum([tf.size(var) for var in model._variables]).numpy())

In [None]:
# @title Inspect the GEFS data for 2022 already regridded to the cubed sphere at 2 degrees.
gefs = xr.open_zarr(fs.get_mapper(f'{base_dir}/data/gefs_forecast_2022_cubedsphere.zarr'))
gefs

In [None]:
# @title Load one GEFS snapshot and create the wrapper for plotting
base_time = pd.Timestamp('2022-01-01')

snapshot = gefs.sel(time=base_time, number=0, step=lead_days)['anomaly'].load()
grid = grid_lib.CubedSphere.on(snapshot.data)
plot_gridder = grid.plot_gridder()

def wrap(data):
  rec = plot_gridder(data)
  new_rec, new_lon = cartopy.util.add_cyclic_point(rec, rec.longitude)
  cyclic = xr.DataArray(new_rec, coords={'latitude': rec.latitude, 'longitude': new_lon})
  return cyclic

In [None]:
# @title Plot the forecasted anomaly from one GEFS member
fig, axes = plt.subplots(4, 2, figsize=(10, 12), sharex=True, sharey=True, subplot_kw=dict(projection=ccrs.Robinson()))
lon, lat = gefs.longitude.data, gefs.latitude.data
for i, ax in enumerate(axes.flat):
  wrap(snapshot.data[i]).plot(cmap='Spectral', transform=ccrs.PlateCarree(), ax=ax, add_colorbar=False)
  ax.set_title(snapshot.field.data[i])
  ax.coastlines()
fig.subplots_adjust(wspace=0, hspace=0.1)

In [None]:
# @title Load the base time climatology
climatology = xr.open_zarr(fs.get_mapper(f'{base_dir}/data/climatology_cubedsphere.zarr')).load()
monthday = base_time.month * 100 + base_time.day
clim_mean = climatology.sel(monthday=monthday)['mean'].load()

Note: SEEDS uses a fixed day-of-year climatology. The model takes the **base time** climatology as input and samples **valid time** anomalies. To convert the outputs back to raw values, the valid time climatology therefore also needs to be loaded (this is done later in this colab).
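
The bookkeeping between base time, valid time, and anomalies can be sketched with plain NumPy. This is only an illustration: `monthday_key` is a hypothetical helper, the climatology arrays contain made-up values, and the normalization `(raw - mean) / std` is assumed to be the inverse of the de-normalization used later in this colab.

```python
import datetime

import numpy as np

def monthday_key(date):
  """Encodes a date as month * 100 + day, e.g. January 8 -> 108."""
  return date.month * 100 + date.day

base_time = datetime.date(2022, 1, 1)
lead_days = 7
valid_time = base_time + datetime.timedelta(days=lead_days)
print('base key:', monthday_key(base_time), 'valid key:', monthday_key(valid_time))

# Toy stand-ins for the valid time climatology (one field, four locations).
clim_mean = np.array([280.0, 285.0, 290.0, 295.0])
clim_std = np.array([2.0, 4.0, 5.0, 10.0])

raw = np.array([282.0, 281.0, 300.0, 295.0])
anomaly = (raw - clim_mean) / clim_std       # Raw values -> normalized anomalies.
recovered = anomaly * clim_std + clim_mean   # Anomalies -> raw values.
print(anomaly, recovered)
```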

In [None]:
# @title Create conditioning information
# Take the first num_seeds GEFS anomalies as seeds.
cond_anomaly = gefs.sel(time=base_time, number=np.arange(num_seeds), step=lead_days)['anomaly'].load().data
# Concatenate those with the climatology to get the conditioning input
cond_clim_mean = clim_mean.data
cond = np.concatenate([cond_anomaly, cond_clim_mean[None, ...]], axis=0)
print('Conditioning shape (#inputs, #fields, #locations) =', cond.shape)

In [None]:
# @title Utility functions for distributing data across accelerators
def distribute(strategy, arr):
  if arr.shape[0] % strategy.num_replicas_in_sync != 0:
    raise ValueError('The batch size should be a multiple of num_replicas_in_sync.')
  local_size = arr.shape[0] // strategy.num_replicas_in_sync
  def value_fn(ctx):
    k = ctx.replica_id_in_sync_group
    return tf.cast(arr[k * local_size:(k + 1) * local_size], tf.float32)
  return strategy.experimental_distribute_values_from_function(value_fn)

def split(strategy, arr):
  if arr.shape[0] % strategy.num_replicas_in_sync != 0:
    raise ValueError('The batch size should be a multiple of num_replicas_in_sync.')
  def value_fn(ctx):
    return arr[ctx.replica_id_in_sync_group]
  return strategy.experimental_distribute_values_from_function(value_fn)

In [None]:
# @title Generate more ensemble members
batchsize = 2 * strategy.num_replicas_in_sync
# The sampling function is completely deterministic for a fixed model_rng. So
# each replica should have its own unique model_rng.
model_rng = tf.constant(np.arange(strategy.num_replicas_in_sync) + 42, tf.int64)
# Reducing num_diffusion_steps leads to faster generation but might degrade quality.
num_diffusion_steps = tf.constant(600, tf.int64)
min_diffusion_time = tf.constant(1e-3, tf.float32)

# To generate batchsize samples at a time, we duplicate cond as a batch.
tiled_cond = tf.cast(tf.tile(cond[None, ...], (batchsize, 1, 1, 1)), tf.float32)

# Run the sampler.
dist_model_rng = split(strategy, model_rng)
dist_conditioning = distribute(strategy, tiled_cond)
samples = strategy.run(model.sample, args=(dist_conditioning, dist_model_rng, num_diffusion_steps, min_diffusion_time))
samples = strategy.gather(samples, axis=0).numpy()

print('Samples shape: (#samples, 1, #fields, #locations) =', samples.shape)

In [None]:
# @title Plot the generated results
field_id = 1
plot_opts = dict(cmap='Spectral', transform=ccrs.PlateCarree(), add_colorbar=False)

seeds = cond[:num_seeds, field_id]
vmin, vmax = seeds.min() * 0.9, seeds.max() * 0.9

fig, axes = plt.subplots(4, 3, figsize=(12, 10), sharex=True, sharey=True, subplot_kw=dict(projection=ccrs.Robinson()))
for i, ax in enumerate(axes[0].flat):
  if i < num_seeds:
    wrap(seeds[i]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)
    ax.coastlines()
    ax.set_title(f'Cond #{i+1}')
for i, ax in enumerate(axes[1:].flat):
  if i < batchsize:
    wrap(samples[i, 0, field_id]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)
    ax.coastlines()
    ax.set_title(f'Generated #{i+1}')
fig.subplots_adjust(wspace=0, hspace=0.1)
fig.suptitle(f'field={gefs.field.data[field_id]}');

In [None]:
# @title Use the valid time climatology to map the anomalies to raw values
valid_time = base_time + pd.Timedelta(days=lead_days)
valid_monthday = valid_time.month * 100 + valid_time.day
valid_clim = climatology.sel(monthday=valid_monthday).load()
clim_mean = valid_clim['mean'].data
clim_std = valid_clim['std'].data

raw_samples = samples * clim_std + clim_mean
raw_cond = cond * clim_std + clim_mean

In [None]:
# @title Plot the generated results in raw values
field_id = 4
plot_opts = dict(cmap='Reds', transform=ccrs.PlateCarree(), add_colorbar=False)

seeds = raw_cond[:num_seeds, field_id]
vmin, vmax = seeds.min() * 0.9, seeds.max() * 0.9

fig, axes = plt.subplots(4, 3, figsize=(12, 10), sharex=True, sharey=True, subplot_kw=dict(projection=ccrs.Robinson()))
for i, ax in enumerate(axes[0].flat):
  if i < num_seeds:
    wrap(seeds[i]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)
    ax.coastlines()
    ax.set_title(f'Cond #{i+1}')
for i, ax in enumerate(axes[1:].flat):
  if i < batchsize:
    wrap(raw_samples[i, 0, field_id]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)
    ax.coastlines()
    ax.set_title(f'Generated #{i+1}')
fig.subplots_adjust(wspace=0, hspace=0.1)
fig.suptitle(f'field={gefs.field.data[field_id]}');

# Using SEEDS with live operational GEFS data

This section reads the operational GEFS data that NOAA publishes on Google Cloud. For more information, see the [dataset page](https://console.cloud.google.com/marketplace/product/noaa-public/gfs-ensemble-forecast-system).


In [None]:
# @title Operational GEFS data reader code
def download_file(path):
  local_name = os.path.basename(path)
  if not os.path.exists(local_name):
    tf.io.gfile.copy(path, local_name)

def make_operational_gefs_aws_url(date, lead_days, number, file='pgrb2a'):
  prefix = f'gep{number:02}' if number > 0 else 'gec00'
  # SEEDS models are trained only on the 00 hour forecast.
  return f'gs://gfs-ensemble-forecast-system/gefs.{date}/00/atmos/{file}p5/{prefix}.t00z.{file}.0p50.f{24 * lead_days}'

def make_hPa_getter(in_name, out_name, levels):
  def get(x):
    x = x[in_name].sel(isobaricInhPa=levels)
    x['isobaricInhPa'] = [f'{out_name}_{level}hPa' for level in levels]
    return x.to_dataset('isobaricInhPa')
  return get

def load_grib(path, getters):
  coords = {'latitude', 'longitude', 'step', 'time'}
  res = []
  for selector, getter in getters:
    res.append(getter(xr.open_dataset(path, engine='cfgrib', filter_by_keys=selector)))
  # Keep only the shared coordinates; `drop_vars` replaces the deprecated `drop`.
  res = xr.merge([part.drop_vars(set(part.coords.keys()) - coords) for part in res])
  return res.load()

def load_gefs_grib(path, file='pgrb2a'):
  if file == 'pgrb2a':
    getters = [
      ({'paramId': 167}, lambda x: x.rename({'t2m': 't_2m'})),
      ({'paramId': 3054}, lambda x: x.rename({'pwat': 'tcwv'})),
      ({'paramId': 130}, make_hPa_getter('t', 't', [850])),
      ({'paramId': 131}, make_hPa_getter('u', 'u', [850])),
      ({'paramId': 132}, make_hPa_getter('v', 'v', [850])),
      ({'paramId': 156}, make_hPa_getter('gh', 'z', [500])),
    ]
  else:
    getters = [
      ({'paramId': 151}, lambda x: x.rename({'msl': 'msl'})),
      ({'paramId': 133, 'typeOfLevel': 'isobaricInhPa'}, make_hPa_getter('q', 'q', [500])),
    ]
  return load_grib(path, getters)

g = 9.80665  # Standard gravitational acceleration (m/s^2).

def gefs_to_era5_units(ds):
  for field in ds.data_vars:
    if field.startswith('z_'):
      # GEFS provides geopotential height (gpm); ERA5 uses geopotential (m^2/s^2).
      ds[field] = ds[field] * g
  return ds

In [None]:
# @title Choose a base time and a lead time
base_date = "20231026" # @param {type:"string"}
lead_days = 7 # @param {type:"integer"}

print('valid_date is', (pd.Timestamp(base_date) + pd.Timedelta(days=lead_days)).strftime('%Y%m%d'))


To save time, we download only the first 8 members here. For the best results, download all 31 GEFS members instead.

In [None]:
# @title Load GEFS data
forecasts = []
for number in range(8):  # To download all GEFS members, change this to 31.
  forecast = []
  for file in ['pgrb2a', 'pgrb2b']:
    url = make_operational_gefs_aws_url(base_date, lead_days, number, file)
    print(f'Fetch {url}...', flush=True)
    download_file(url)
    filename = os.path.basename(url)
    forecast.append(load_gefs_grib(filename, file=file))
  forecast = xr.merge(forecast).assign_coords({'number': number})
  forecasts.append(forecast)
forecasts = xr.concat(forecasts, 'number')

In [None]:
# @title Process the raw GEFS to a single tensor
# Change to ERA5 units.
forecasts = gefs_to_era5_units(forecasts)
# Make sure the fields follow the order in the model.
fields = [field.decode() for field in model.field_tags.numpy()]
forecasts = forecasts[fields]
forecasts = forecasts.to_array('field').transpose('number', 'field', 'latitude', 'longitude')
print('(#members, #fields, #lats, #lons) =', forecasts.shape)

In [None]:
# @title Regrid to cubed sphere at 2 degrees (48 nodes for 90 degrees)
source_grid = grid_lib.Equirectangular.on(forecasts)
target_grid = grid_lib.CubedSphere(48)
gridder = source_grid.to(target_grid)
coords = {k: v.data for k, v in forecasts.coords.items()}
lon, lat = target_grid.grid_points
coords['latitude'] = ('values', lat)
coords['longitude'] = ('values', lon)
gridded = xr.DataArray(gridder(forecasts.data), dims=['number', 'field', 'values'], coords=coords)
print('(#members, #fields, #grid_points) =', gridded.shape)

In [None]:
# @title Get the climatologies and convert to anomalies
base_time = pd.Timestamp(base_date)
valid_time = base_time + pd.Timedelta(days=lead_days)
base_clim_mean = climatology.sel(monthday=base_time.month * 100 + base_time.day)['mean'].load().data
valid_clim_mean = climatology.sel(monthday=valid_time.month * 100 + valid_time.day)['mean'].load().data
valid_clim_std = climatology.sel(monthday=valid_time.month * 100 + valid_time.day)['std'].load().data
anomalies = (gridded - valid_clim_mean[None, ...]) / valid_clim_std[None, ...]

To save time, we sample 4 rounds here. In each round we generate 16 samples from 2 seeds drawn at random from the 8 downloaded GEFS members, which gives 4*16=64 samples in total.

For the best results, first download the full 31-member ensemble and run many rounds. For example, 32 rounds of 16 samples each, with the seeds drawn at random from the 31 members, give 32*16=512 samples. Scaling up to more TPUs can make this significantly faster.
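
The round structure described above can be sketched without a model. The helper `plan_rounds` below is hypothetical and only mirrors the bookkeeping: it draws the seed members for each round without replacement and reports the total number of samples.

```python
import numpy as np

def plan_rounds(num_members, num_seeds, rounds, samples_per_round, seed=42):
  """Returns the seed member indices for each round and the total sample count."""
  rng = np.random.default_rng(seed)
  seeds_per_round = [
      rng.choice(num_members, size=num_seeds, replace=False)
      for _ in range(rounds)
  ]
  return seeds_per_round, rounds * samples_per_round

# Demo setting: 8 downloaded members, 4 rounds of 16 samples.
seeds_per_round, total = plan_rounds(8, num_seeds=2, rounds=4, samples_per_round=16)
print(total)  # 64

# Full setting: 31 members, 32 rounds of 16 samples.
_, total_full = plan_rounds(31, num_seeds=2, rounds=32, samples_per_round=16)
print(total_full)  # 512
```

Drawing without replacement within a round ensures the two conditioning seeds are distinct members; different rounds may reuse the same member.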

In [None]:
# @title Generate more ensemble members
rounds = 4
samples_per_round = 2 * strategy.num_replicas_in_sync  # This is 16 in this demo.
rng = np.random.default_rng(42)
num_diffusion_steps = tf.constant(600, tf.int64)
min_diffusion_time = tf.constant(1e-3, tf.float32)

src_ensemble_size = anomalies.shape[0]
results = []
for _ in tqdm.tqdm(range(rounds)):
  # Draw the seeds from however many members were downloaded.
  seeds_idx = rng.choice(src_ensemble_size, size=num_seeds, replace=False)
  query = np.concatenate([anomalies[seeds_idx], base_clim_mean[None, ...]], axis=0)
  tiled_cond = tf.tile(query[None, ...], (samples_per_round, 1, 1, 1))

  model_rng = tf.constant(rng.integers(0, 2 ** 10, size=strategy.num_replicas_in_sync), tf.int64)
  dist_model_rng = split(strategy, model_rng)
  dist_conditioning = distribute(strategy, tiled_cond)
  samples = strategy.run(model.sample, args=(dist_conditioning, dist_model_rng, num_diffusion_steps, min_diffusion_time))
  samples = strategy.gather(samples, axis=0).numpy()
  results.append(samples)
results = np.concatenate(results, axis=0)
print('Samples shape: (#samples, 1, #fields, #locations) =', results.shape)

In [None]:
# @title Map back to raw values
results_raw = results[:, 0] * valid_clim_std[None, ...] + valid_clim_mean[None, ...]
gefs_raw = gridded.data

In [None]:
# @title Plot the generated results in raw values
field_id = 6
plot_opts = dict(cmap='Blues', transform=ccrs.PlateCarree(), add_colorbar=False)

seeds = gefs_raw[:num_seeds, field_id]
vmin, vmax = seeds.min() * 0.9, seeds.max() * 0.9

fig, axes = plt.subplots(4, 3, figsize=(12, 10), sharex=True, sharey=True, subplot_kw=dict(projection=ccrs.Robinson()))
for i, ax in enumerate(axes[0].flat):
  if i < num_seeds:
    wrap(seeds[i]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)
    ax.coastlines()
    ax.set_title(f'Cond #{i+1}')
for i, ax in enumerate(axes[1:].flat):
  if i < samples_per_round:
    wrap(results_raw[i, field_id]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)
    ax.coastlines()
    ax.set_title(f'Generated #{i+1}')
fig.subplots_adjust(wspace=0, hspace=0.1)
fig.suptitle(f'field={gefs.field.data[field_id]}');

In [None]:
# @title Postage stamp charts over Europe (first 2 rows are from GEFS, the next 4 rows are generated)
fig, axes = plt.subplots(6, 4, figsize=(12, 18), subplot_kw=dict(projection=ccrs.LambertConformal(5, 48)))
levels = 14

def level_styler(low, mid, high):
  lowbar = 8
  return [low] * lowbar + [mid] + [high] * (14 - lowbar)

vmin, vmax = None, None
fixed_levels = levels
zplot_opts = dict(cmap='Spectral_r', add_colorbar=False, transform=ccrs.PlateCarree())
pplot_opts = dict(
    transform=ccrs.PlateCarree(),
    colors='darkslategray',
    linewidths=level_styler(1, 1.5, 1),
    linestyles=level_styler('dashed', 'solid', 'solid'),
)
for i, ax in enumerate(axes.flat):
  if i // 4 < 2:
    ensemble = gefs_raw
    start = 0
  else:
    ensemble = results_raw
    start = 8
  ax.set_extent((-40, 50, 10, 86), crs=ccrs.PlateCarree())
  if i - start < ensemble.shape[0]:
    zplot = wrap(ensemble[i - start, 2] / g).plot(ax=ax, vmin=vmin, vmax=vmax, **zplot_opts)
    if vmin is None:
      vmin, vmax = zplot.get_clim()
    pplot = wrap(ensemble[i - start, 0]).plot.contour(ax=ax, levels=fixed_levels, **pplot_opts)
  if isinstance(fixed_levels, int):
    fixed_levels = pplot.levels
  ax.coastlines(resolution='110m', linewidth=1.5)
fig.subplots_adjust(wspace=0, hspace=0)
cbar_ax = fig.add_axes([0.96, 0.3, 0.02, 0.3])
mpl.colorbar.ColorbarBase(cbar_ax, orientation='vertical', cmap='Spectral_r', norm=mpl.colors.Normalize(vmin, vmax), extend='both')
cbar_ax.set_title('Geopotential height at 500 hPa (m)', rotation='vertical', x=-0.7, y=0.15)

# Advanced usages

The forward SDE is

$$
dX_t = g(t)\,dW_t,
$$
where $W_t$ is the standard Wiener process. The diffusion coefficient is given by $g(t)=b^t$ where the constant $b$ is the noise schedule base exponent. This is a Gaussian process with

$$
(X_t|X_0=x_0) \sim N(x_0, \sigma^2(t)I), \qquad \sigma^2(t)=\int_0^t g^2(s)\,ds.
$$
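
For the schedule $g(t)=b^t$ with $b>0$, $b\neq 1$, the integral above evaluates to $\sigma^2(t)=(b^{2t}-1)/(2\ln b)$. The sketch below checks this closed form against numerical integration. The value of `b` is a made-up example, and whether `model.marginal_std` evaluates exactly this expression is an assumption that the notebook itself does not confirm.

```python
import numpy as np

def marginal_variance(t, b):
  """Closed form of the integral of b**(2 s) over s in [0, t]; needs b > 0, b != 1."""
  return (b ** (2 * t) - 1) / (2 * np.log(b))

b = 10.0  # Hypothetical base exponent, chosen only for illustration.
t = 0.5

# Trapezoidal rule on a fine grid, written out explicitly.
s = np.linspace(0.0, t, 100001)
g_squared = b ** (2 * s)
numerical = np.sum((g_squared[1:] + g_squared[:-1]) / 2 * np.diff(s))

print('closed form:', marginal_variance(t, b), 'numerical:', numerical)
```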

Let $p(t,x)$ be the probability density for $X_t$. The reverse SDE is

$$
dY_t = -g^2(t)\nabla \log p(t,Y_t)\,dt+g(t)\,d\bar{W}_t,
$$

where $\bar{W}_t$ is the reverse Wiener process. In diffusion modeling, we use a neural net $s_\theta(t,x)$ to approximate the score $\nabla \log p(t,x)$.
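
To see the reverse SDE at work without a neural net, the following toy sketch integrates it with the Euler-Maruyama scheme for one-dimensional Gaussian data, where the score is known exactly. Everything here is an assumption made for illustration: the base exponent `b`, the data distribution $N(0, s_0^2)$, and the choice of integrator, which is not necessarily what `model.sample` uses.

```python
import numpy as np

b = 10.0      # Hypothetical base exponent, not the value stored in the checkpoints.
s0 = 1.0      # Standard deviation of the toy data distribution N(0, s0^2).
t_min = 1e-3  # Smallest diffusion time reached by the sampler.

def g(t):
  return b ** t

def sigma2(t):
  # Integral of g(s)^2 over s in [0, t] for g(s) = b**s.
  return (b ** (2 * t) - 1) / (2 * np.log(b))

def score(t, y):
  # Exact score of p(t, .) = N(0, s0^2 + sigma^2(t)); a neural net would replace this.
  return -y / (s0 ** 2 + sigma2(t))

def reverse_sample(num_samples, num_steps, seed=0):
  rng = np.random.default_rng(seed)
  times = np.linspace(1.0, t_min, num_steps + 1)
  # Start from the forward marginal at t = 1.
  y = rng.normal(size=num_samples) * np.sqrt(s0 ** 2 + sigma2(1.0))
  for t, t_next in zip(times[:-1], times[1:]):
    dt = t - t_next  # Positive step size; time runs backwards.
    noise = rng.normal(size=num_samples)
    y = y + g(t) ** 2 * score(t, y) * dt + g(t) * np.sqrt(dt) * noise
  return y

samples = reverse_sample(num_samples=20000, num_steps=600)
target_std = np.sqrt(s0 ** 2 + sigma2(t_min))
print('sample std:', samples.std(), 'target std:', target_std)
```

With the exact score, the samples at `t_min` should have a standard deviation close to that of the slightly perturbed data distribution, up to discretization and Monte Carlo error.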

In [None]:
# @title Plot the noise schedule
print('SDE noise schedule base exponent:', model.sde_base_exponent.numpy())
diffusion_time = tf.linspace(1e-3, 1.0, 32)
diffusion_coef = model.diffusion_coef(diffusion_time)
marginal_std = model.marginal_std(diffusion_time)

fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=True)
ax = axes[0]
ax.plot(diffusion_time, diffusion_coef, label='Diffusion coefficient')
ax.set_title('Diffusion coefficient $g(t)$')
ax = axes[1]
ax.plot(diffusion_time, marginal_std)
# Raw string so that `\s` is not treated as an (invalid) escape sequence.
ax.set_title(r'Marginal std $\sigma(t)$');

Because $(X_t|X_0=x_0) \sim N(x_0, \sigma^2(t)I)$, we see that $x_0+\sigma(t)Z$ for $Z\sim N(0,I)$ has the same distribution as $X_t$. The neural score function is trained on the denoising loss

$$
\mathbb{E}_{t\sim U(0,1]}\mathbb{E}_{x\sim p_{\textrm{data}}(x)}\mathbb{E}_{Z\sim N(0,I)}||s_\theta(t,x+\sigma(t)Z)\sigma(t) + Z||_2^2.
$$

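Hence, for any $x\sim p_{\textrm{data}}(x)$, $Z\sim N(0,I)$, and $t\in (0,1]$, we expect $s_\theta(t,x+\sigma(t)Z)\sigma(t)\approx -Z$, so that

$$
D(t, x+\sigma(t)Z):= x+\sigma(t)Z+\sigma^2(t)\,s_\theta(t, x+\sigma(t)Z) \approx x
$$

is a denoiser. We can thus visualize the learned score function by looking at the corresponding denoiser.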
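
The link between score and denoiser can be checked on a toy problem. If the loss above is small, then $s_\theta(t,y)\sigma(t)\approx -Z$ for $y=x+\sigma(t)Z$, so $y+\sigma^2(t)s_\theta(t,y)$ should move $y$ back towards $x$. The sketch below uses one-dimensional Gaussian data, for which the exact score is available in closed form; the values of `s0` and `sigma` are made up, and the exact score stands in for the trained network.

```python
import numpy as np

rng = np.random.default_rng(0)
s0 = 1.0     # Std of the toy data distribution N(0, s0^2).
sigma = 0.5  # Marginal std sigma(t) at some fixed diffusion time (made-up value).

x = rng.normal(size=100000) * s0  # Clean samples.
z = rng.normal(size=x.shape)
y = x + sigma * z                 # Perturbed samples.

# Exact score of the perturbed distribution N(0, s0^2 + sigma^2).
score = -y / (s0 ** 2 + sigma ** 2)
denoised = y + sigma ** 2 * score

mse_perturbed = np.mean((y - x) ** 2)
mse_denoised = np.mean((denoised - x) ** 2)
print('MSE perturbed:', mse_perturbed, 'MSE denoised:', mse_denoised)
```

In this Gaussian case the denoiser reduces to shrinking `y` by the factor `s0**2 / (s0**2 + sigma**2)`, which is the posterior mean of `x` given `y`. It therefore lowers the mean squared error but cannot recover `x` exactly.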

In [None]:
# @title Evaluate the denoiser at diffusion times for some random perturbations
batchsize = strategy.num_replicas_in_sync

# This continues from before using the 2022-01-01 example data, where we took
# the first num_seeds GEFS members for conditioning. Here we take the last
# GEFS member as the label for denoising.
label = gefs.sel(time='2022-01-01', step=lead_days).isel(number=slice(-1, None, None))['anomaly'].load().data
label = tf.tile(tf.cast(label, tf.float32)[None, ...], (batchsize, 1, 1, 1))
tiled_cond = tf.tile(tf.cast(cond, tf.float32)[None, ...], (batchsize, 1, 1, 1))

# Create batchsize copies of perturbed samples at different diffusion times.
diffusion_time = tf.linspace(1e-2, 1.0, batchsize)
marginal_std = model.marginal_std(diffusion_time)
noise = tf.random.normal((batchsize, 1) + label.shape[2:])
perturbed = label + noise * marginal_std[:, None, None, None]

# Evaluate the model score function and compute the denoiser result.
dist_conditioning = distribute(strategy, tiled_cond)
dist_diffusion_time = distribute(strategy, diffusion_time)
dist_perturbed = distribute(strategy, perturbed)
scores = strategy.run(model.score, args=(dist_conditioning, dist_diffusion_time, dist_perturbed))
scores = strategy.gather(scores, axis=0).numpy()

denoised = perturbed + scores * (marginal_std ** 2)[:, None, None, None]

In [None]:
# @title Plot the denoiser results
field_id = 1

fig, axes = plt.subplots(8, 3, figsize=(12, 16), sharex=True, sharey=True, subplot_kw=dict(projection=ccrs.Robinson()))
for i in range(batchsize):
  truth = label[i, 0, field_id].numpy()
  vmin, vmax = truth.min(), truth.max()
  for j, data in enumerate([label[i, 0, field_id], perturbed[i, 0, field_id], denoised[i, 0, field_id]]):
    wrap(data.numpy()).plot(vmin=vmin, vmax=vmax, cmap='Spectral', transform=ccrs.PlateCarree(), ax=axes[i][j], add_colorbar=False)
  axes[i][0].set_ylabel(f't={diffusion_time[i]:.2f}')
axes[0][0].set_title('Label')
axes[0][1].set_title('Perturbed')
axes[0][2].set_title('Denoised')
fig.suptitle(f'Denoising plots for the field {gefs.field.data[field_id]}')
fig.subplots_adjust(wspace=0, hspace=0.02)