In [None]:
import os

import xarray as xr
import numpy as np
import pandas as pd

from tqdm.notebook import tqdm

from distributed import LocalCluster, Client

import src_screening.model.accessor
from src_screening.model.wave_forcing import WaveForcing
from src_screening.model.post_processing import estimate_xr_grads, estimate_deform

import matplotlib.pyplot as plt
import cmocean
import matplotlib.gridspec as mpl_gs

In [None]:
# Project style sheets (must be installed in the matplotlib style library).
# They are applied in list order, so "egu_journals" wins wherever both set the
# same rcParam.
plt.style.use(["paper", "egu_journals"])

In [None]:
# Dask cluster that evaluates the lazily opened zarr stores below.
# The worker count is capped at the available CPUs so the notebook also runs on
# smaller machines without splitting memory across 64 workers; on machines with
# >= 64 cores this is identical to the original setting.
# os.cpu_count() may return None, hence the fallback to 1.
n_workers = min(64, os.cpu_count() or 1)
local_cluster = LocalCluster(n_workers=n_workers, local_directory="/tmp")
client = Client(local_cluster)
client

# Load data

In [None]:
# Case-study selection: ensemble member and hybrid-model seed (the seed picks
# the processed hybrid output directory), the initialisation time, and the lead
# times shown as figure columns (strings presumably parsed to timedeltas by `.sel`).
sel_ens, sel_seed = 9, 9
sel_time = "1970-01-02T00:00:00"
sel_leadtimes = [
    "9 min 52 s",
    "20 min 0 s",
    "30 min 8 s",
    "60 min",
]

In [None]:
def _open_case_study(run_name):
    """Lazily open ``../../data/raw/test/<run_name>/`` and select the case study.

    Time and lead time are kept as single dask chunks; each ensemble member is
    its own chunk. The selection uses the globals ``sel_leadtimes``,
    ``sel_ens`` and ``sel_time`` from the configuration cell.
    """
    return xr.open_zarr(
        f"../../data/raw/test/{run_name}/",
        chunks={"time": -1, "ensemble": 1, "lead_time": -1},
    ).sel(lead_time=sel_leadtimes, ensemble=sel_ens, time=sel_time)


# Baseline runs; "hr_nature_forecast" is plotted as the truth row below.
ds_hr_nature = _open_case_study("hr_nature_forecast")
ds_nature = _open_case_study("lr_nature_forecast")
ds_forecast = _open_case_study("lr_forecast")

In [None]:
# Hybrid-model trajectories for the selected seed. Unlike the raw runs, the
# mesh dimensions are explicitly requested as single chunks too.
hybrid_path = f"../../data/processed/unext_small/{sel_seed:d}/traj_short/"
hybrid_chunks = {
    "time": -1,
    "ensemble": 1,
    "lead_time": -1,
    "nMesh2_node": -1,
    "nMesh2_face": -1,
}
ds_hybrid = (
    xr.open_dataset(hybrid_path, engine="zarr", chunks=hybrid_chunks)
    .sel(lead_time=sel_leadtimes, ensemble=sel_ens, time=sel_time)
)

# Plot case study (Fig. 9)

In [None]:
# Truth run: gradients -> deformation rate scaled by 86400 (presumably 1/s ->
# 1/day; TODO confirm units of estimate_deform) -> log10.
hr_grads = estimate_xr_grads(ds_hr_nature)
hr_log_deform = np.log10(
    estimate_deform(hr_grads) * 86400
)

# Keep only what the figure needs (fields + mesh geometry) and load it eagerly.
plt_hr = (
    xr.merge([ds_hr_nature, hr_log_deform])
    [[
        "damage",
        "deform_tot",
        "Mesh2_face_nodes",
        "Mesh2_node_x",
        "Mesh2_node_y",
    ]]
    .compute()
)

In [None]:
# Low-resolution forecast: same deformation diagnostic as for the truth run
# (scaled by 86400, then log10).
fcst_grads = estimate_xr_grads(ds_forecast)
fcst_log_deform = np.log10(
    estimate_deform(fcst_grads) * 86400
)

# Restrict to plotted fields and mesh geometry, then materialise in memory.
plt_fcst = (
    xr.merge([ds_forecast, fcst_log_deform])
    [[
        "damage",
        "deform_tot",
        "Mesh2_face_nodes",
        "Mesh2_node_x",
        "Mesh2_node_y",
    ]]
    .compute()
)

In [None]:
# Hybrid run: identical deformation diagnostic (scaled by 86400, then log10).
hybrid_grads = estimate_xr_grads(ds_hybrid)
hybrid_log_deform = np.log10(
    estimate_deform(hybrid_grads) * 86400
)

# Only the plotted fields and the mesh geometry are loaded into memory.
plt_hybrid = (
    xr.merge([ds_hybrid, hybrid_log_deform])
    [[
        "damage",
        "deform_tot",
        "Mesh2_face_nodes",
        "Mesh2_node_x",
        "Mesh2_node_y",
    ]]
    .compute()
)

In [None]:
# Figure row order (top to bottom): truth, forecast, hybrid.
list_of_plts = [
    plt_hr,
    plt_fcst,
    plt_hybrid,
]

In [None]:
# Case-study figure: one row per dataset in `list_of_plts`, one column per lead
# time in `sel_leadtimes`. Each grid cell holds two panels: left shows
# `1 - damage`, right shows the log10 total deformation (`deform_tot`), both
# zoomed into the same window.
row_labels = ["Truth (4 km)", "Forecast (8 km)", "Hybrid (8 km)"]
col_labels = ["+10 min", "+20 min", "+30 min", "+60 min"]  # rounded sel_leadtimes
n_rows, n_cols = len(list_of_plts), len(sel_leadtimes)
assert len(row_labels) == n_rows, "need one row label per plotted dataset"
assert len(col_labels) == n_cols, "need one column label per selected lead time"

# Zoom window in mesh coordinates (same units as Mesh2_node_x / Mesh2_node_y).
x_window = (-20000, 20000)
y_window = (-60000, 40000)
# Layout of the two panels inside every subfigure (loop-invariant).
panel_gridspec = dict(hspace=0, wspace=0.05, left=0.03, bottom=0.03, top=0.97, right=0.97)

fig = plt.figure(figsize=(3, 2.5*n_rows/2), dpi=150)

# squeeze=False always returns a 2-D array, so subfigs[k, t] also works when
# only a single row or column is plotted.
subfigs = fig.subfigures(n_rows, n_cols, squeeze=False)
for k, ds in enumerate(list_of_plts):
    for t in range(n_cols):
        curr_ax = subfigs[k, t].subplots(1, 2, sharex=True, sharey=True, gridspec_kw=panel_gridspec)
        curr_ax[0].set_axis_off()
        curr_ax[1].set_axis_off()
        # NOTE(review): this panel shows 1 - damage while the colorbar below is
        # labelled "Damage (1)" -- confirm the intended convention.
        curr_plt_dam = curr_ax[0].tripcolor(
            ds.sinn.triangulation, 1-ds["damage"].isel(lead_time=t),
            vmin=0., vmax=1, cmap="cmo.ice_r", rasterized=True
        )
        curr_plt_deform = curr_ax[1].tripcolor(
            ds.sinn.triangulation, ds["deform_tot"].isel(lead_time=t),
            vmin=-3, vmax=0., cmap="cmo.thermal", rasterized=True
        )
        # Axes are shared, so limiting the left panel also limits the right one.
        curr_ax[0].set_xlim(*x_window)
        curr_ax[0].set_ylim(*y_window)

# Column titles (lead times) above the first row.
for t, label in enumerate(col_labels):
    subfigs[0, t].text(0.5, 1., s=label, ha="center", va="bottom", transform=subfigs[0, t].transSubfigure)

# Row titles (datasets) to the left of the first column.
for k, label in enumerate(row_labels):
    subfigs[k, 0].text(0., 0.5, s=label, ha="right", va="center", transform=subfigs[k, 0].transSubfigure, rotation=90)

# vmin/vmax are fixed in every tripcolor call, so the mappables from the last
# panel are valid colorbars for all panels.
ax_cbar = fig.add_axes([1.01, 0.55, 0.02, 0.4])
cbar = fig.colorbar(curr_plt_dam, cax=ax_cbar)
cbar.set_label("Damage (1)")

ax_cbar = fig.add_axes([1.01, 0.05, 0.02, 0.4])
cbar = fig.colorbar(curr_plt_deform, cax=ax_cbar)
# Raw string: "\l", "\d" and "\e" are invalid Python escape sequences.
cbar.set_label(r"$\log_{10}(\dot{\epsilon}_{tot})$")

# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/fig09_case_study.pdf", dpi=300, bbox_inches='tight', pad_inches=0)