# Generative Downscaling of Atmospheric Data using CorrDiff

The following notebook demonstrates how to use Modulus's [CorrDiff](https://arxiv.org/abs/2309.15214)
implementation. Like diagnostic models, CorrDiff does not predict a future weather state. Instead,
it derives a diagnostic set of variables at a finer resolution than the original atmospheric data.
In this case, we will super-resolve data over Taiwan from a 25km to a 2km resolution.
CorrDiff produces an ensemble of high-resolution realisations, accounting for the stochastic
nature of small-scale structures. The ensemble output is then loaded into an xarray Dataset
and plotted with a few sample routines.

In summary, this notebook covers the following topics:

- Configuring and setting up CorrDiff
- Producing a downscaled ensemble from several realisations of the stochastic features
- Plotting the result and comparing the different ensemble members


## Set Up and Execute Downscaling
As always, we start with some library imports and the `generate` module from
CorrDiff. CorrDiff is currently not implemented in E2MIP, so we use the Modulus
implementation. The config file is located under `/e2ws/exercises/corrdiff/conf`.
Note that `hydra.initialize`, which locates the config, only accepts a relative path. When
running a notebook in a container, the path has to be defined relative to the location at
which the jupyter server was launched. This notebook assumes that jupyter was launched in `/e2ws`; if
you launched it from a different location, you have to edit the `config_path` in the `hydra.initialize`
call.
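
If you are unsure which relative path to pass, it can be computed rather than typed by hand. The snippet below is a minimal sketch using only the standard library; `relative_config_path` is a hypothetical helper, not part of hydra or Modulus, and `start` should be whichever directory the relative path is resolved against in your setup.

```python
import os


def relative_config_path(config_dir, start=None):
    """Express config_dir relative to start (default: current working directory)."""
    # hypothetical helper for illustration; not part of hydra or Modulus
    start = os.getcwd() if start is None else start
    return os.path.relpath(config_dir, start=start)


# launched in /e2ws, as this notebook assumes
print(relative_config_path("/e2ws/exercises/corrdiff/conf", start="/e2ws"))
# launched one level deeper, so the path has to climb up first
print(relative_config_path("/e2ws/exercises/corrdiff/conf", start="/e2ws/notebooks"))
```

The returned string can then be used as the `config_path` argument.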

Under the hood, `generate.main` loads snapshots at the times specified in
the config file. The default values select three snapshots on 12 September 2021,
when Typhoon Chanthu passed over Taiwan. Chanthu was the second most intense tropical cyclone (TC) of 2021.
The coarse data comes from the ERA5 data set and has a resolution of 25km. First, the regression
model predicts the _ensemble mean_ of the downscaled variables at a resolution
of 2km. In a second step, the diffusion model restores the fine structures that the original data
does not resolve. Relative to the coarse resolution, these structures are stochastic in nature.
The diffusion model can therefore predict a range of realisations of these structures, producing an
ensemble.
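
The two-stage idea can be illustrated with a toy NumPy sketch. It is not CorrDiff: nearest-neighbour upsampling stands in for the regression model, Gaussian noise stands in for the diffusion model, and `toy_downscale`, `factor` and `noise_std` are hypothetical names chosen for illustration only.

```python
import numpy as np


def toy_downscale(coarse, factor, n_members, noise_std=0.1, seed=0):
    """Toy analogue of 'deterministic mean + stochastic fine structure'."""
    rng = np.random.default_rng(seed)
    # stage 1 (stand-in for the regression model): upsample the coarse field
    mean = np.kron(coarse, np.ones((factor, factor)))
    # stage 2 (stand-in for the diffusion model): one random residual per member
    residuals = rng.normal(0.0, noise_std, size=(n_members, *mean.shape))
    return mean, mean[None, ...] + residuals


coarse = np.arange(6, dtype=float).reshape(2, 3)
mean, ensemble = toy_downscale(coarse, factor=4, n_members=16)
print(mean.shape, ensemble.shape)
# every member shares the same mean field but differs in its fine structure
print(np.abs(ensemble.mean(axis=0) - mean).max())
```

Averaging over the members recovers the mean field up to sampling noise, which is why the ensemble mean can serve as a rough proxy for the regression output when the ensemble is large.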

In [None]:
import sys
import os
import hydra
import xarray

sys.path.append(os.path.abspath("/e2ws/exercises/corrdiff"))
import generate

# read config with hydra, clear in case hydra has been initialised before
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize(version_base="1.3", config_path='./exercises/corrdiff/conf')
cfg = hydra.compose(config_name='config_generate')
generate.main(cfg)

## Explore Downscaled Data
CorrDiff outputs a netCDF file containing the input, the prediction and the ground truth.
The prediction contains variables which are not included in the input set, making the model
fully generative.

In [None]:
def open_results(path):
    # the root group is merged into each subgroup so that lat/lon can be set as coordinates
    root = xarray.open_dataset(path)
    input_ds = (xarray.open_dataset(path, group="input")
                .merge(root).set_coords(["lat", "lon"]))
    pred = (xarray.open_dataset(path, group="prediction")
            .merge(root).set_coords(["lat", "lon"]))
    truth = (xarray.open_dataset(path, group="truth")
             .merge(root).set_coords(["lat", "lon"]))
    return input_ds, pred, truth

# input_ds rather than input, to avoid shadowing the Python builtin
input_ds, pred, truth = open_results(f"{cfg['image_outdir']}_0.nc")
print(f"input: {input_ds.data_vars}")
print(f"\nprediction: {pred}")

## Visualising the Result

Finally, let us visualise the result. The routine below plots the ground truth
for comparison, the ensemble mean and a single ensemble member. The regression result
would be a better reference than the ensemble mean, but it is not plotted here yet; for
large ensemble sizes the two should be visually identical.
Use the arrows at the bottom of the plot to click through the ensemble members.
Change `var_idx` to plot a different variable (see `vars` for the index list) and `time_idx`
to plot a different time step.

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.animation as animation
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import numpy as np

var_idx = 0
time_idx = 0

vars = ['maximum_radar_reflectivity', 'temperature_2m', 'eastward_wind_10m', 'northward_wind_10m']
var = vars[var_idx]

n_samples = 8
n_samples = min(n_samples, cfg["seed_batch_size"])
output_dir = './outputs'

# concatenate truth data and ensemble mean as extra "ensemble" members for easy plotting
truth_expanded = truth.assign_coords(ensemble="truth").expand_dims("ensemble")
ens_mean = (
    pred.mean("ensemble")
    .assign_coords(ensemble="ensemble_mean")
    .expand_dims("ensemble")
)
# add [0, 1, 2, ...] to ensemble dim
pred["ensemble"] = [str(i) for i in range(pred.sizes["ensemble"])]
merged = xarray.concat([truth_expanded, ens_mean, pred], dim="ensemble")
projection=ccrs.PlateCarree()

vmax = np.max(merged[var][:, time_idx, ...]).item()
vmin = np.min(merged[var][:, time_idx, ...]).item()
cmap = 'plasma'
if var_idx > 1:
    vmax = max(abs(vmax), abs(vmin))
    vmin = -vmax
    cmap = 'RdBu_r'

# define plots
def make_figure():
    title = ['truth', 'ensemble mean']
    fig, ax = plt.subplots(1, 3, figsize=(11,5), subplot_kw={'projection': ccrs.PlateCarree()})
    fig.suptitle(f"{var} at {np.datetime_as_string(merged.time[time_idx], unit='s')}", fontsize=18)
    lon_formatter = LongitudeFormatter(zero_direction_label=False)
    lat_formatter = LatitudeFormatter()

    for mem in range(3):
        ax[mem].add_feature(cfeature.COASTLINE,lw=.5)
        ax[mem].add_feature(cfeature.RIVERS,lw=.5)
        ax[mem].add_feature(cfeature.BORDERS, linewidth=0.6, edgecolor='dimgray')
        ax[mem].xaxis.set_major_formatter(lon_formatter)
        ax[mem].yaxis.set_major_formatter(lat_formatter)

        if mem==2:
            continue

        # select the requested time step, consistent with vmin/vmax and the title
        plot_ds = merged[var][mem, time_idx, :, :]
        pc = ax[mem].pcolormesh(merged.lon, merged.lat, plot_ds, transform=projection,
                        cmap=cmap, vmin=vmin, vmax=vmax)
        ax[mem].set_title(title[mem])

    cbar = fig.colorbar(pc, extend='both', shrink=0.6, ax=ax, location='bottom')
    cbar.set_label(var, fontsize=12)

    return fig, ax

# plot the variables
def make_frame(mem):
    plot_ds = merged[var][mem+2, time_idx, :, :]  # offset of 2 skips the truth and ensemble mean
    pc = ax[2].pcolormesh(merged.lon, merged.lat, plot_ds, transform=projection,
                       cmap=cmap, vmin=vmin, vmax=vmax)
    ax[2].set_title(f'ensemble member {mem+1} of {n_samples}')
    return pc

def animate(frame):
    return make_frame(frame)

def first_frame():
    return make_frame(0)

%matplotlib inline
plt.rcParams["animation.html"] = "jshtml"
fig, ax = make_figure()
ani = animation.FuncAnimation(fig,
                              animate,
                              n_samples,
                              init_func=first_frame,
                              blit=False,
                              repeat=False,
                              interval=.1)
plt.close('all')
ani
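
The colour-scale logic in the cell above (a sequential map for reflectivity and temperature, a zero-centred diverging map for the wind components) can be isolated as follows. This is a small sketch; `colour_limits` is a hypothetical helper and the sample values are made up for illustration.

```python
import numpy as np


def colour_limits(values, signed):
    """Return (vmin, vmax, cmap) for a field; signed fields are centred on zero."""
    vmin = float(np.min(values))
    vmax = float(np.max(values))
    if signed:
        # symmetric limits, so that zero maps to the neutral colour of the diverging map
        vmax = max(abs(vmin), abs(vmax))
        return -vmax, vmax, "RdBu_r"
    return vmin, vmax, "plasma"


wind = np.array([-12.0, 3.5, 7.0])        # made-up wind component in m/s
temperature = np.array([288.0, 301.5])    # made-up temperature in K
print(colour_limits(wind, signed=True))          # (-12.0, 12.0, 'RdBu_r')
print(colour_limits(temperature, signed=False))  # (288.0, 301.5, 'plasma')
```

Keeping the limits fixed across all panels and animation frames is what makes the ensemble members visually comparable.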

<div class="alert alert-warning"><h4>Exercise</h4><p>

Inverse diffusion is an iterative process controlled by a range of parameters. &sigma;<sub>max</sub>
sets the noise level (standard deviation) added at the beginning, &sigma;<sub>min</sub> is the target
noise level, i.e. the noise remaining at the end of the process.
The number of iterations for moving from &sigma;<sub>max</sub> to &sigma;<sub>min</sub> can be
controlled through `num_steps`. Open the config file `e2workshop/exercises/corrdiff/conf/config_generate.yaml`,
change the number of steps and see what happens. As a first try, consider doubling or halving the number of
steps. It is not required to stay within powers of two. Is the result as you expected?
</p></div>
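
To build some intuition for what `num_steps` changes, the sketch below computes the noise levels visited between &sigma;<sub>max</sub> and &sigma;<sub>min</sub>. It assumes an EDM-style schedule (Karras et al., 2022) with exponent `rho`; whether the Modulus sampler uses exactly this spacing is an assumption, and the default values below are illustrative rather than taken from `config_generate.yaml`. `edm_sigma_schedule` is a hypothetical helper.

```python
def edm_sigma_schedule(num_steps, sigma_min=0.002, sigma_max=800.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min (EDM-style spacing, assumed)."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    max_inv = sigma_max ** (1.0 / rho)
    min_inv = sigma_min ** (1.0 / rho)
    # interpolate linearly in sigma**(1/rho), then map back with the power rho
    return [
        (max_inv + i / (num_steps - 1) * (min_inv - max_inv)) ** rho
        for i in range(num_steps)
    ]


# doubling the number of steps refines the path between the same two end points
for n in (9, 18):
    sigmas = edm_sigma_schedule(n)
    print(n, [f"{s:.3g}" for s in sigmas[:3]], "...", f"{sigmas[-1]:.3g}")
```

The end points stay the same when the number of steps changes; only the size of the individual denoising steps differs.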