In [None]:
import sys
sys.path.append("../../")

import os
import warnings

import numpy as np
import xarray as xr

from tqdm.notebook import tqdm

import torch
import torch.nn
from torch.autograd.functional import jacobian

from hydra import initialize, compose
from hydra.utils import instantiate

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mpl_colors
import matplotlib.gridspec as mpl_gs
from matplotlib.patches import Rectangle
import cmocean
import cmcrameri

import src_screening.model.accessor
from src_screening.datasets import OfflineDataset

In [None]:
# Apply the project styles in order; entries later in the list override
# settings from earlier ones, exactly as consecutive `use` calls would.
plt.style.use(["paper", "egu_journals"])

# Load data

In [None]:
# Test split; judging by the directory names, inputs are stored as
# differences and targets in their "normal" form (TODO confirm).
dataset = OfflineDataset(
    target_path="../../data/raw/test/dataset/target_normal/",
    input_path="../../data/raw/test/dataset/input_difference/",
)

In [None]:
# Training-set mean/std of the difference inputs.
# NOTE(review): not referenced anywhere else in the visible notebook; consider
# removing it if it is truly unused.
climatology = dict(
    mean=xr.open_dataset("../../data/raw/train/climatology/input_difference_mean.nc"),
    std=xr.open_dataset("../../data/raw/train/climatology/input_difference_std.nc"),
)

In [None]:
# Low-resolution mesh template: provides the triangulation and face
# coordinates used by the plots below.
template = xr.open_dataset(filename_or_obj="../../data/interim/template_lr.nc")

## Load model

In [None]:
def load_model(
        model_checkpoint: str,
        path_prefix: str = "../../",
) -> torch.nn.Module:
    """Instantiate a trained network from its Hydra config and load its weights.

    The Hydra configuration is read from ``<checkpoint dir>/hydra/config.yaml``.
    Older configurations keep the network under the ``model`` key and point to
    classes in a ``.model.`` sub-package; those targets are rewritten to the
    ``.network.`` sub-package before instantiation.

    Parameters
    ----------
    model_checkpoint : str
        Path to the checkpoint file; its ``"state_dict"`` entry is loaded.
    path_prefix : str, default "../../"
        Prefix prepended to the backbone's ``cartesian_weights_path``. The
        config presumably stores that path relative to the project root, and
        this notebook sits two directories below it (TODO confirm).

    Returns
    -------
    torch.nn.Module
        The network on CPU, in eval mode, with gradients disabled for all
        parameters.
    """
    model_dir = os.path.dirname(model_checkpoint)
    with initialize(config_path=os.path.join(model_dir, 'hydra')):
        cfg = compose('config.yaml')

    if "model" in cfg.keys():
        # Legacy layout: the classes were later moved from `.model.` to `.network.`.
        net_key = "model"
        cfg[net_key]["_target_"] = cfg[net_key]["_target_"].replace(
            ".model.", ".network."
        )
        cfg[net_key]["backbone"]["_target_"] = cfg[net_key]["backbone"]["_target_"].replace(
            ".model.", ".network."
        )
    else:
        net_key = "network"

    cfg[net_key]["backbone"]["cartesian_weights_path"] = (
        path_prefix + cfg[net_key]["backbone"]["cartesian_weights_path"]
    )
    model = instantiate(
        cfg[net_key],
        optimizer_config=cfg.optimizer,
        _recursive_=False
    )

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_checkpoint, map_location=torch.device("cpu"))
    # strict=False tolerates key mismatches, but ignoring them silently could
    # leave layers randomly initialised, so surface any mismatch.
    incompatible = model.load_state_dict(state_dict["state_dict"], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {model_checkpoint!r} does not match the model: "
            f"missing keys {incompatible.missing_keys}, "
            f"unexpected keys {incompatible.unexpected_keys}"
        )
    model = model.eval().cpu()
    model = model.requires_grad_(False)
    return model

In [None]:
# Last checkpoint stored under models_jeanzay/input_difference/9/, moved to the GPU.
checkpoint_path = "../../data/models_jeanzay/input_difference/9/last.ckpt"
network = load_model(checkpoint_path).cuda()

# Estimate saliency maps

In [None]:
# Fetch the first test sample once (it was previously loaded twice) and add a
# leading batch axis of size 1 to each input before moving it to the GPU.
test_sample = dataset[0]
input_nodes = test_sample["input_nodes"].cuda()[None, ...]
input_faces = test_sample["input_faces"].cuda()[None, ...]

In [None]:
# Plain forward pass for the prediction panel; the sensitivities are
# computed separately, so no autograd graph is needed here.
with torch.no_grad():
    pred_nodes, pred_faces = network(input_nodes, input_faces)

In [None]:
# Index of the mesh face whose prediction is explained and marked in the plots.
# NOTE(review): the name suggests it was picked via a 95 % quantile -- document how.
idx_q95: int = 266

In [None]:
def get_reduced_prediction(*args):
    """Reduce the network's face output to one scalar for attribution.

    Runs ``network(*args)``, picks channel 5 of the face output (the field
    plotted as "A" in the figure) at face ``idx_q95`` for every sample, and
    sums over the batch axis. Because the result is a scalar, ``jacobian``
    yields one gradient per input sample, assuming samples do not interact
    inside the network.
    """
    _, face_prediction = network(*args)
    per_sample_value = face_prediction[:, 5, idx_q95]
    return per_sample_value.sum(dim=0)

In [None]:
# SmoothGrad-style sampling: perturb the test sample with Gaussian noise; the
# resulting gradients are averaged later on.
N_NOISE_SAMPLES = 128  # number of noisy copies of the sample
NOISE_STD = 0.1  # noise std in input units (presumably normalised inputs -- TODO confirm)
NOISE_SEED = 42  # fixed seed so the saliency maps are reproducible
torch.manual_seed(NOISE_SEED)

# Rebuild the clean sample here instead of reusing `input_nodes`/`input_faces`:
# this cell overwrites those names, so re-running it must not stack noise on
# top of already perturbed (and already batched) inputs.
clean_sample = dataset[0]
clean_nodes = clean_sample["input_nodes"].cuda()[None, ...]
clean_faces = clean_sample["input_faces"].cuda()[None, ...]

input_nodes_noise = torch.randn(N_NOISE_SAMPLES, *clean_nodes.shape, device=clean_nodes.device)
input_faces_noise = torch.randn(N_NOISE_SAMPLES, *clean_faces.shape, device=clean_faces.device)

# (N, 1, ...) -> (N, ...): merge the noise axis with the singleton batch axis
# without hard-coding the per-sample shapes.
input_nodes = (clean_nodes + input_nodes_noise * NOISE_STD).flatten(0, 1)
input_faces = (clean_faces + input_faces_noise * NOISE_STD).flatten(0, 1)

In [None]:
# The reduced prediction is a scalar, so each Jacobian has the shape of its
# input: one gradient per noisy sample (given independent samples).
jac_nodes, jac_faces = jacobian(
    func=get_reduced_prediction,
    inputs=(input_nodes, input_faces),
)

In [None]:
# Average the per-sample gradients over the noise samples and move the
# result to the CPU for plotting.
plot_nodes, plot_faces = (jac.mean(dim=0).cpu() for jac in (jac_nodes, jac_faces))

# Plot gradient-based saliency maps

# Reduced figure for the main text

In [None]:
# Reduced saliency figure. Panel (a) shows the predicted field f(x); panels
# (b)-(f) show the sensitivity to the initial state x_0 and panels (g)-(k)
# the sensitivity to the difference dx. Every panel marks face `idx_q95`.
fig, ax = plt.subplots(ncols=11, figsize=(6.5, 1.2), dpi=300, sharex=True, sharey=True)
fig.subplots_adjust(wspace=0.04)
for curr_ax in ax:
    curr_ax.set_axis_off()


def _add_panel_label(curr_ax, label):
    """Write a panel label in the upper-left corner of ``curr_ax``."""
    curr_ax.text(x=0.05, y=0.9, s=label, ha="left", va="center", transform=curr_ax.transAxes)


def _mark_explained_face(curr_ax, color):
    """Mark face ``idx_q95``, whose prediction is being explained."""
    curr_ax.scatter(
        template.Mesh2_face_x[idx_q95], template.Mesh2_face_y[idx_q95],
        marker=".", color=color, s=8
    )


# (a) The prediction itself, on its own linear colour range.
ax[0].tripcolor(template.sinn.triangulation, pred_faces.mean(dim=0).cpu()[5], cmap="cmo.balance", vmin=-2, vmax=2, rasterized=True)
ax[0].text(x=0.5, y=1.02, s=r"$f(\mathbf{x})$", ha="center", va="bottom", transform=ax[0].transAxes)
_add_panel_label(ax[0], "(a) A")
_mark_explained_face(ax[0], "white")
# Vertical separator between the prediction and the sensitivity panels.
fig.add_artist(mpl.lines.Line2D([1.025, 1.025], [-0.2, 1.05], transform=ax[0].transAxes, c="black"))

# Shared symmetric-log normalisation (linthresh=0.05, linscale=0.2) for all sensitivity panels.
norm = mpl_colors.SymLogNorm(0.05, 0.2, vmin=-3, vmax=3)

# (axis index, sensitivity field, panel label, marker colour). The difference
# panels use the initial-state channel index offset by 3 (nodes) / 7 (faces).
sensitivity_panels = [
    (1, plot_nodes[1], "(b) v", "black"),
    (2, plot_faces[2], r"(c) $\sigma_{yy}$", "black"),
    (3, plot_faces[3], "(d) d", "black"),
    (4, plot_faces[5], "(e) A", "white"),
    (5, plot_faces[6], "(f) h", "white"),
    (6, plot_nodes[4], "(g) v", "black"),
    (7, plot_faces[9], r"(h) $\sigma_{yy}$", "black"),
    (8, plot_faces[10], "(i) d", "white"),
    (9, plot_faces[12], "(j) A", "white"),
    (10, plot_faces[13], "(k) h", "white"),
]
sensitivity_mappables = {}
for ax_idx, field, label, marker_color in sensitivity_panels:
    sensitivity_mappables[ax_idx] = ax[ax_idx].tripcolor(
        template.sinn.triangulation, field.numpy(), cmap="cmo.balance", norm=norm, rasterized=True
    )
    _add_panel_label(ax[ax_idx], label)
    _mark_explained_face(ax[ax_idx], marker_color)
# All panels share `norm`, so any of them can drive the colorbar.
plt_sal = sensitivity_mappables[1]

# Group headers and the separator between initial-state and difference panels.
# NOTE(review): the "Initial" header is anchored at x=0.05 with ha="center",
# unlike the "Difference" header at x=0.5 -- confirm the offset is intended.
ax[3].text(x=0.05, y=1.02, s=r"Initial: $\mathbf{x}_{0}$", ha="center", va="bottom", transform=ax[3].transAxes)
fig.add_artist(mpl.lines.Line2D([1.025, 1.025], [0, 1.05], transform=ax[5].transAxes, c="black"))
ax[8].text(x=0.5, y=1.02, s=r"Difference: $\Delta \mathbf{x} = \mathbf{x}_{1} - \mathbf{x}_{0}$", ha="center", va="bottom", transform=ax[8].transAxes)

# Axes are shared, so these limits apply to every panel.
ax[0].set_xlim(-20000, 20000)
ax[0].set_ylim(-80000, 0)

# Horizontal colorbar below the sensitivity panels, zoomed to [-1, 1].
ax_cbar = fig.add_axes([0.2514, 0.1, 0.715, 0.03])
cbar = fig.colorbar(plt_sal, cax=ax_cbar, orientation="horizontal")
cbar.set_label(r"Sensitivity: $\partial f(\mathbf{x})/\partial \mathbf{x}$")
cbar.set_ticks([-0.5, -0.1, 0, 0.1, 0.5])
cbar.set_ticklabels([-0.5, -0.1, 0, 0.1, 0.5])
ax_cbar.set_xlim(-1, 1)
ax_cbar.minorticks_off()

# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/fig07_saliency_map_area.pdf", dpi=300, bbox_inches="tight", pad_inches=0.0)