In [1]:
import sys
import os
import hydra

# Single source of truth for the CorrDiff exercise location. It is used both to
# make `generate` importable and to locate the Hydra config directory, so this
# cell no longer depends on the kernel's current working directory.
CORRDIFF_DIR = "/e2ws/exercises/corrdiff"
sys.path.append(os.path.abspath(CORRDIFF_DIR))
import generate

# Read the generation config with Hydra. Clear any previous global Hydra state
# first so this cell can be re-run in the same kernel.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# initialize_config_dir takes an absolute directory, avoiding Hydra's
# caller-relative resolution of `config_path`, which is ambiguous inside Jupyter.
hydra.initialize_config_dir(version_base="1.3", config_dir=os.path.join(CORRDIFF_DIR, "conf"))
cfg = hydra.compose(config_name='config_generate')
generate.main(cfg)

  warn("Distributed manager is already intialized")
[14:15:40 - generate - INFO] [94mPatch-based generation disabled[0m
[14:15:40 - generate - INFO] [94mtorch.__version__: 2.2.0a0+81ea7a4[0m
[14:15:40 - generate - INFO] [94mLoading residual network from "/e2ws/exercises/corrdiff/checkpoints/diffusion.mdlus"...[0m
[14:15:41 - generate - INFO] [94mLoading network from "/e2ws/exercises/corrdiff/checkpoints/regression.mdlus"...[0m
[14:15:42 - generate - INFO] [94mGenerating images...[0m
[14:15:42 - generate - INFO] [94mstarting index: 0[0m
[14:15:42 - generate - INFO] [94mseeds: [0, 1, 2, 3, 4, 5, 6, 7][0m
100%|███████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.17batch/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1/1 [00:09<00:00,  9.92s/batch]
[14:15:54 - generate - INFO] [94mstarting index: 1[0m
[14:15:54 - generate - INFO] [94mseeds: [0, 1, 2, 3, 4, 5, 6, 7][0m
100%|█████████

In [2]:
import xarray

def open_results(path):
    """Open a generation output file and return its ``(prediction, truth)`` datasets.

    The variables of the file's root group are merged into both the
    ``prediction`` and the ``truth`` groups, and ``lat``/``lon`` are promoted
    to coordinates on each result.
    """
    root_ds = xarray.open_dataset(path)

    def _with_root(group_name):
        # Attach the root-group variables to this group and promote lat/lon to coords.
        group_ds = xarray.open_dataset(path, group=group_name)
        combined = group_ds.merge(root_ds)
        return combined.set_coords(["lat", "lon"])

    prediction = _with_root("prediction")
    ground_truth = _with_root("truth")
    return prediction, ground_truth

# Load the output file with suffix `_0` under the configured image_outdir and show the prediction.
results_path = f"{cfg['image_outdir']}_0.nc"
pred, truth = open_results(results_path)
pred

In [10]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.animation as animation
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import numpy as np

# Which variable and which time step to visualize.
var_idx = 0
time_idx = 0

# NOTE(review): `vars` shadows the Python builtin of the same name; the name is
# kept for compatibility with any later cells that may reference it.
vars = ['maximum_radar_reflectivity', 'temperature_2m', 'eastward_wind_10m', 'northward_wind_10m']
var = vars[var_idx]
# Signed fields that get a symmetric, diverging color scale. Keyed by name so
# that reordering `vars` cannot silently change the color scheme.
diverging_vars = {'eastward_wind_10m', 'northward_wind_10m'}

# Number of ensemble members to animate. It is capped by the configured seed
# batch size and by the number of members actually present in `pred`, so
# make_frame never indexes past the end of the ensemble dimension.
n_samples = 8
n_samples = min(n_samples, cfg["seed_batch_size"], pred.sizes["ensemble"])
output_dir = './outputs'

# Concatenate the truth and the ensemble mean as extra "ensemble" members so
# that truth, mean and every individual member can be indexed uniformly along
# one `ensemble` dimension. The resulting order is [truth, ensemble_mean, 0, 1, ...].
truth_expanded = truth.assign_coords(ensemble="truth").expand_dims("ensemble")
ens_mean = (
    pred.mean("ensemble")
    .assign_coords(ensemble="ensemble_mean")
    .expand_dims("ensemble")
)
# Label the individual members "0", "1", "2", ... (strings, like the truth and
# mean labels). Rebind instead of assigning into the Dataset so that the object
# loaded in the previous cell is not mutated.
pred = pred.assign_coords(ensemble=[str(i) for i in range(pred.sizes["ensemble"])])
merged = xarray.concat([truth_expanded, ens_mean, pred], dim="ensemble")
projection = ccrs.PlateCarree()

# Color limits shared by all panels: min/max over every ensemble entry at the
# selected time step. Positional indexing assumes dims are
# (ensemble, time, y, x) — TODO confirm against the output file layout.
vmax = np.max(merged[var][:, time_idx, ...]).item()
vmin = np.min(merged[var][:, time_idx, ...]).item()
cmap = 'plasma'
if var in diverging_vars:
    # Center the scale on zero so positive/negative values are symmetric.
    vmax = max(abs(vmax), abs(vmin))
    vmin = -vmax
    cmap = 'RdBu_r'

# define plots
def make_figure():
    """Create the 1x3 map figure used by the animation.

    Panels 0 and 1 show the truth and the ensemble mean (entries 0 and 1 of
    ``merged``'s ensemble dimension) at ``time_idx``. Panel 2 only gets map
    decorations here; its data is drawn by ``make_frame`` during the animation.
    All panels share ``cmap``/``vmin``/``vmax``, and one colorbar is added below them.

    Reads the notebook globals ``merged``, ``var``, ``time_idx``, ``projection``,
    ``cmap``, ``vmin`` and ``vmax``.

    Returns:
        (fig, ax): the figure and the array of its three axes.
    """
    titles = ['truth', 'ensemble mean']
    fig, ax = plt.subplots(1, 3, figsize=(11, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    timestamp = np.datetime_as_string(merged.time[time_idx].values, unit='s')
    fig.suptitle(f"{var} at {timestamp}", fontsize=18)
    lon_formatter = LongitudeFormatter(zero_direction_label=False)
    lat_formatter = LatitudeFormatter()

    for mem in range(3):
        ax[mem].add_feature(cfeature.COASTLINE, lw=.5)
        ax[mem].add_feature(cfeature.RIVERS, lw=.5)
        ax[mem].add_feature(cfeature.BORDERS, linewidth=0.6, edgecolor='dimgray')
        ax[mem].xaxis.set_major_formatter(lon_formatter)
        ax[mem].yaxis.set_major_formatter(lat_formatter)

        if mem == 2:
            # The third panel's data is drawn per frame by make_frame.
            continue

        # Plot the same time step that the title and the color limits use.
        plot_ds = merged[var][mem, time_idx, :, :]
        pc = ax[mem].pcolormesh(merged.lon, merged.lat, plot_ds, transform=projection,
                                cmap=cmap, vmin=vmin, vmax=vmax)
        ax[mem].set_title(titles[mem])

    # One shared colorbar: `pc` is the last mesh drawn, and all meshes use the
    # same vmin/vmax, so it represents every panel.
    cbar = fig.colorbar(pc, extend='both', shrink=0.6, ax=ax, location='bottom')
    cbar.set_label(var, fontsize=12)

    return fig, ax

# plot the variables
def make_frame(mem):
    """Draw ensemble member ``mem`` (0-based) in the third panel and return its mesh.

    Entries 0 and 1 of ``merged``'s ensemble dimension are the truth and the
    ensemble mean, so member ``mem`` is stored at index ``mem + 2``. Any mesh
    left on the panel by a previous call is removed first, so meshes do not
    pile up on the axes from frame to frame.

    Reads the notebook globals ``ax``, ``merged``, ``var``, ``time_idx``,
    ``projection``, ``cmap``, ``vmin``, ``vmax`` and ``n_samples``.
    """
    from matplotlib.collections import QuadMesh

    # Copy to a list before removing, since removal mutates ax[2].collections.
    for old_mesh in [c for c in ax[2].collections if isinstance(c, QuadMesh)]:
        old_mesh.remove()

    plot_ds = merged[var][mem + 2, time_idx, :, :]  # +2 skips the truth and ensemble-mean entries
    pc = ax[2].pcolormesh(merged.lon, merged.lat, plot_ds, transform=projection,
                          cmap=cmap, vmin=vmin, vmax=vmax)
    ax[2].set_title(f'ensemble member {mem+1} of {n_samples}')
    return pc

def animate(frame):
    """FuncAnimation callback: draw ensemble member number ``frame`` and return its artist."""
    artist = make_frame(frame)
    return artist

def first_frame():
    """FuncAnimation ``init_func``: draw the first ensemble member in the third panel.

    Uses member 0. Passing -1 to make_frame would select merged entry 1, which
    is the ensemble mean, and label it "ensemble member 0".
    """
    return make_frame(0)

%matplotlib inline
plt.rcParams["animation.html"] = "jshtml"
fig, ax = make_figure()
ani = animation.FuncAnimation(fig,
                              animate,
                              n_samples,
                              init_func=first_frame,
                              blit=False,
                              repeat=False,
                              interval=.1)
plt.close('all')
ani