In [None]:
import sys
import os
import hydra

# Root of the CorrDiff exercise: holds generate.py and its Hydra conf/ directory.
# Deriving both the import path and the config directory from this one constant
# keeps them in sync and makes the cell independent of the kernel's working dir.
CORRDIFF_DIR = "/e2ws/exercises/corrdiff"
if CORRDIFF_DIR not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(CORRDIFF_DIR)
import generate

# Read the config with Hydra; clear first in case Hydra has been initialised before
# (e.g. by a previous run of this cell).
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize_config_dir(version_base="1.3", config_dir=os.path.join(CORRDIFF_DIR, "conf"))
cfg = hydra.compose(config_name='config_generate')
generate.main(cfg)

In [None]:
import xarray

def open_results(path):
    """Open a results file and return its prediction and truth groups.

    The file's root group is merged into each of the ``prediction`` and
    ``truth`` groups, and ``lat``/``lon`` are promoted to coordinates.

    Returns:
        A ``(pred, truth)`` tuple of ``xarray.Dataset`` objects.
    """
    shared = xarray.open_dataset(path)

    def _with_shared(group):
        # Attach the root-level variables and mark lat/lon as coordinates.
        grouped = xarray.open_dataset(path, group=group)
        return grouped.merge(shared).set_coords(["lat", "lon"])

    prediction = _with_shared("prediction")
    target = _with_shared("truth")
    return prediction, target

# Presumably the file written by generate.main(cfg) in the first cell.
results_file = f"{cfg['image_outdir']}_0.nc"
pred, truth = open_results(results_file)
pred

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.animation as animation
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from numpy import datetime_as_string

var = 'maximum_radar_reflectivity'

# Number of ensemble members to animate. It can never exceed the members actually
# present in `pred`; the config's seed batch size is kept as an additional cap.
n_samples = 4
n_samples = min(n_samples, cfg["seed_batch_size"], pred.sizes["ensemble"])
output_dir = './outputs'

# Concatenate the truth data and the ensemble mean as extra "ensemble" members so
# everything can be indexed along one dimension: index 0 is the truth, index 1
# the ensemble mean, and indices 2.. the individual members.
truth_expanded = truth.assign_coords(ensemble="truth").expand_dims("ensemble")
ens_mean = (
    pred.mean("ensemble")
    .assign_coords(ensemble="ensemble_mean")
    .expand_dims("ensemble")
)
# Label the members "0", "1", "2", ... (strings, like the labels above).
pred["ensemble"] = [str(i) for i in range(pred.sizes["ensemble"])]
merged = xarray.concat([truth_expanded, ens_mean, pred], dim="ensemble")
projection = ccrs.PlateCarree()

# define plots
def make_figure():
    """Create the figure with truth, ensemble-mean and member panels.

    The first two panels show the truth and the ensemble mean of ``var``; the
    third panel is left without data for ``make_frame`` to fill in. The fourth
    grid slot hosts the shared colour bar.

    Returns:
        ``(fig, ax)`` where ``ax`` is the third (ensemble-member) GeoAxes.
    """
    title = ['truth', 'mean']
    # Build a fresh figure on each call rather than relying on a global `fig`,
    # which does not exist on a fresh kernel.
    fig = plt.figure(figsize=(11, 5))
    fig.suptitle(f"{var} at {datetime_as_string(merged.time[0], unit='s')}", fontsize=18)
    fig.set_tight_layout(True)
    lon_formatter = LongitudeFormatter(zero_direction_label=False)
    lat_formatter = LatitudeFormatter()

    for mem in range(3):
        ax = fig.add_subplot(1, 4, mem + 1, projection=projection)
        ax.add_feature(cfeature.COASTLINE, lw=.5)
        ax.add_feature(cfeature.RIVERS, lw=.5)
        ax.add_feature(cfeature.BORDERS, linewidth=0.6, edgecolor='dimgray')
        ax.xaxis.set_major_formatter(lon_formatter)
        ax.yaxis.set_major_formatter(lat_formatter)

        # The third panel is animated by make_frame, so it gets no static data.
        if mem == 2:
            continue

        # merged index 0 is the truth, 1 the ensemble mean; the second index
        # presumably selects the first time step (assumes (ensemble, time, y, x)).
        plot_ds = merged[var][mem, 0, :, :]
        pc = ax.pcolormesh(merged.lon, merged.lat, plot_ds, transform=projection,
                           cmap='plasma', vmin=0, vmax=100)
        ax.set_title(title[mem])

    # The fourth slot only hosts the colour bar; hide its own (empty) axes.
    axx = fig.add_subplot(1, 4, 4)
    axx.set_axis_off()
    cbar = fig.colorbar(pc, extend='both', shrink=0.6, ax=axx)
    # Label with the plotted variable, plus its units when the data carries them.
    units = merged[var].attrs.get('units')
    cbar.set_label(f"{var} [{units}]" if units else var, fontsize=12)

    return fig, ax

# plot the variables
def make_frame(mem):
    """Draw ensemble member ``mem`` (0-based) into the member panel ``ax``.

    Any mesh drawn by a previous frame is removed first, so the panel holds a
    single mesh instead of accumulating one per frame.

    Returns:
        The mesh artist returned by ``pcolormesh`` for this frame.
    """
    from matplotlib.collections import QuadMesh

    # Offset by 2 to skip the truth (index 0) and ensemble mean (index 1) in merged.
    plot_ds = merged[var][mem + 2, 0, :, :]
    # Snapshot the list before removing, since removal mutates ax.collections.
    for old_mesh in [c for c in ax.collections if isinstance(c, QuadMesh)]:
        old_mesh.remove()
    pc = ax.pcolormesh(merged.lon, merged.lat, plot_ds, transform=projection,
                       cmap='plasma', vmin=0, vmax=100)
    ax.set_title(f'ensemble member {mem+1} of {n_samples}')
    return pc

def animate(frame):
    """FuncAnimation callback: draw ensemble member ``frame`` (0-based)."""
    return make_frame(frame)

def first_frame():
    """FuncAnimation init_func: draw the first ensemble member.

    Draws member 0 so the panel's data matches its "ensemble member 1 of N"
    title (index -1 would have drawn the ensemble mean under a member title).
    """
    return make_frame(0)

%matplotlib inline
plt.rcParams["animation.html"] = "jshtml"
fig, ax = make_figure()
ani = animation.FuncAnimation(fig,
                              animate,
                              n_samples,
                              init_func=first_frame,
                              blit=False,
                              repeat=False,
                              interval=.1)
plt.close('all')
ani

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.animation as animation
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from numpy import datetime_as_string

var = 'maximum_radar_reflectivity'

# Number of ensemble members to animate. It can never exceed the members actually
# present in `pred`; the config's seed batch size is kept as an additional cap.
n_samples = 4
n_samples = min(n_samples, cfg["seed_batch_size"], pred.sizes["ensemble"])
output_dir = './outputs'

# Concatenate the truth data and the ensemble mean as extra "ensemble" members so
# everything can be indexed along one dimension: index 0 is the truth, index 1
# the ensemble mean, and indices 2.. the individual members.
truth_expanded = truth.assign_coords(ensemble="truth").expand_dims("ensemble")
ens_mean = (
    pred.mean("ensemble")
    .assign_coords(ensemble="ensemble_mean")
    .expand_dims("ensemble")
)
# Label the members "0", "1", "2", ... (strings, like the labels above).
pred["ensemble"] = [str(i) for i in range(pred.sizes["ensemble"])]
merged = xarray.concat([truth_expanded, ens_mean, pred], dim="ensemble")
projection = ccrs.PlateCarree()

# define plots
def make_figure():
    """Create a 1x3 map figure: truth, ensemble mean, and a member panel.

    The truth and ensemble mean of ``var`` are drawn in the first two panels;
    the third is left without data for ``make_frame`` to fill in. A horizontal
    colour bar shared by all three panels sits underneath.

    Returns:
        ``(fig, ax)`` where ``ax`` is the array of the three GeoAxes.
    """
    title = ['truth', 'mean']
    # Reuse the cell's `projection`, which make_frame also uses as its transform.
    fig, ax = plt.subplots(1, 3, figsize=(11, 5), subplot_kw={'projection': projection})
    fig.suptitle(f"{var} at {datetime_as_string(merged.time[0], unit='s')}", fontsize=18)
    lon_formatter = LongitudeFormatter(zero_direction_label=False)
    lat_formatter = LatitudeFormatter()

    for mem in range(3):
        ax[mem].add_feature(cfeature.COASTLINE, lw=.5)
        ax[mem].add_feature(cfeature.RIVERS, lw=.5)
        ax[mem].add_feature(cfeature.BORDERS, linewidth=0.6, edgecolor='dimgray')
        ax[mem].xaxis.set_major_formatter(lon_formatter)
        ax[mem].yaxis.set_major_formatter(lat_formatter)

        # The third panel is animated by make_frame, so it gets no static data.
        if mem == 2:
            continue

        # merged index 0 is the truth, 1 the ensemble mean; the second index
        # presumably selects the first time step (assumes (ensemble, time, y, x)).
        plot_ds = merged[var][mem, 0, :, :]
        pc = ax[mem].pcolormesh(merged.lon, merged.lat, plot_ds, transform=projection,
                                cmap='plasma', vmin=0, vmax=100)
        ax[mem].set_title(title[mem])

    cbar = fig.colorbar(pc, extend='both', shrink=0.6, ax=ax, location='bottom')
    # Label with the plotted variable, plus its units when the data carries them.
    units = merged[var].attrs.get('units')
    cbar.set_label(f"{var} [{units}]" if units else var, fontsize=12)

    return fig, ax

# plot the variables
def make_frame(mem):
    """Draw ensemble member ``mem`` (0-based) into the third panel ``ax[2]``.

    Any mesh drawn by a previous frame is removed first, so the panel holds a
    single mesh instead of accumulating one per frame.

    Returns:
        The mesh artist returned by ``pcolormesh`` for this frame.
    """
    from matplotlib.collections import QuadMesh

    member_ax = ax[2]
    # Offset by 2 to skip the truth (index 0) and ensemble mean (index 1) in merged.
    plot_ds = merged[var][mem + 2, 0, :, :]
    # Snapshot the list before removing, since removal mutates ax.collections.
    for old_mesh in [c for c in member_ax.collections if isinstance(c, QuadMesh)]:
        old_mesh.remove()
    pc = member_ax.pcolormesh(merged.lon, merged.lat, plot_ds, transform=projection,
                              cmap='plasma', vmin=0, vmax=100)
    member_ax.set_title(f'ensemble member {mem+1} of {n_samples}')
    return pc

def animate(frame):
    """FuncAnimation callback: draw ensemble member ``frame`` (0-based)."""
    return make_frame(frame)

def first_frame():
    """FuncAnimation init_func: draw the first ensemble member.

    Draws member 0 so the panel's data matches its "ensemble member 1 of N"
    title (index -1 would have drawn the ensemble mean under a member title).
    """
    return make_frame(0)

%matplotlib inline
plt.rcParams["animation.html"] = "jshtml"
fig, ax = make_figure()
ani = animation.FuncAnimation(fig,
                              animate,
                              n_samples,
                              init_func=first_frame,
                              blit=False,
                              repeat=False,
                              interval=.1)
plt.close('all')
ani

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.animation as animation
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from numpy import datetime_as_string

var = 'maximum_radar_reflectivity'

# Number of ensemble members to animate. It can never exceed the members actually
# present in `pred`; the config's seed batch size is kept as an additional cap.
n_samples = 2
n_samples = min(n_samples, cfg["seed_batch_size"], pred.sizes["ensemble"])
output_dir = './outputs'

# Concatenate the truth data and the ensemble mean as extra "ensemble" members so
# everything can be indexed along one dimension: index 0 is the truth, index 1
# the ensemble mean, and indices 2.. the individual members.
truth_expanded = truth.assign_coords(ensemble="truth").expand_dims("ensemble")
ens_mean = (
    pred.mean("ensemble")
    .assign_coords(ensemble="ensemble_mean")
    .expand_dims("ensemble")
)
# Label the members "0", "1", "2", ... (strings, like the labels above).
pred["ensemble"] = [str(i) for i in range(pred.sizes["ensemble"])]
merged = xarray.concat([truth_expanded, ens_mean, pred], dim="ensemble")
projection = ccrs.PlateCarree()

# define plots
def make_figure():
    """Create a 1x3 map figure: truth, ensemble mean, and a member panel.

    The truth and ensemble mean of ``var`` are drawn in the first two panels;
    the third is left without data for ``make_frame`` to fill in. A labelled
    horizontal colour bar shared by all three panels sits underneath.

    Returns:
        ``(fig, ax)`` where ``ax`` is the array of the three GeoAxes.
    """
    title = ['truth', 'mean']
    # Reuse the cell's `projection`, which make_frame also uses as its transform.
    fig, ax = plt.subplots(1, 3, figsize=(11, 5), subplot_kw={'projection': projection})
    fig.suptitle(f"{var} at {datetime_as_string(merged.time[0], unit='s')}", fontsize=18)
    lon_formatter = LongitudeFormatter(zero_direction_label=False)
    lat_formatter = LatitudeFormatter()

    for mem in range(3):
        ax[mem].add_feature(cfeature.COASTLINE, lw=.5)
        ax[mem].add_feature(cfeature.RIVERS, lw=.5)
        ax[mem].add_feature(cfeature.BORDERS, linewidth=0.6, edgecolor='dimgray')
        ax[mem].xaxis.set_major_formatter(lon_formatter)
        ax[mem].yaxis.set_major_formatter(lat_formatter)

        # The third panel is animated by make_frame, so it gets no static data.
        if mem == 2:
            continue

        # merged index 0 is the truth, 1 the ensemble mean; the second index
        # presumably selects the first time step (assumes (ensemble, time, y, x)).
        plot_ds = merged[var][mem, 0, :, :]
        pc = ax[mem].pcolormesh(merged.lon, merged.lat, plot_ds, transform=projection,
                                cmap='plasma', vmin=0, vmax=100)
        ax[mem].set_title(title[mem])

    cbar = fig.colorbar(pc, ax=ax, shrink=.6, location='bottom', pad=.2)
    # Label with the plotted variable, plus its units when the data carries them.
    units = merged[var].attrs.get('units')
    cbar.set_label(f"{var} [{units}]" if units else var, fontsize=12)

    return fig, ax

# plot the variables
def make_frame(mem):
    """Draw ensemble member ``mem`` (0-based) into the third panel ``ax[2]``.

    Any mesh drawn by a previous frame is removed first, so the panel holds a
    single mesh instead of accumulating one per frame.

    Returns:
        The mesh artist returned by ``pcolormesh`` for this frame.
    """
    from matplotlib.collections import QuadMesh

    member_ax = ax[2]
    # Offset by 2 to skip the truth (index 0) and ensemble mean (index 1) in merged.
    plot_ds = merged[var][mem + 2, 0, :, :]
    # Snapshot the list before removing, since removal mutates ax.collections.
    for old_mesh in [c for c in member_ax.collections if isinstance(c, QuadMesh)]:
        old_mesh.remove()
    pc = member_ax.pcolormesh(merged.lon, merged.lat, plot_ds, transform=projection,
                              cmap='plasma', vmin=0, vmax=100)
    member_ax.set_title(f'ensemble member {mem+1} of {n_samples}')
    return pc

def animate(frame):
    """FuncAnimation callback: draw ensemble member ``frame`` (0-based)."""
    return make_frame(frame)

def first_frame():
    """FuncAnimation init_func: draw the first ensemble member.

    Draws member 0 so the panel's data matches its "ensemble member 1 of N"
    title (index -1 would have drawn the ensemble mean under a member title).
    """
    return make_frame(0)

%matplotlib inline
plt.rcParams["animation.html"] = "jshtml"
fig, ax = make_figure()
ani = animation.FuncAnimation(fig,
                              animate,
                              n_samples,
                              init_func=first_frame,
                              blit=False,
                              repeat=False,
                              interval=.1)
plt.close('all')
ani