In [None]:
import sys
sys.path.append("../../")

import os
from functools import partial

import numpy as np
import xarray as xr

from tqdm.notebook import tqdm

import torch
import torch.nn
from torch.utils.data import DataLoader

from hydra import initialize, compose
from hydra.utils import instantiate

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mpl_colors
import matplotlib.gridspec as mpl_gs
from matplotlib.patches import Rectangle
import cmocean

import src_screening.model.accessor
from src_screening.datasets import OfflineDataset

In [None]:
# Apply both style sheets in one call; later entries override earlier ones.
# NOTE(review): "paper" and "egu_journals" are not matplotlib built-ins, so
# they presumably live in the local matplotlib stylelib — confirm.
plt.style.use(["paper", "egu_journals"])

# Load model

In [None]:
def load_model(
        model_checkpoint: str,
) -> torch.nn.Module:
    """Rebuild a network from its hydra config and restore checkpoint weights.

    The hydra configuration is read from the ``hydra`` sub-directory next to
    the checkpoint file. Configs that keep the network under a top-level
    ``model`` key get ``.model.`` rewritten to ``.network.`` in the
    ``_target_`` of the network and of its backbone; otherwise the
    ``network`` key is used unchanged. In both cases the backbone's
    ``cartesian_weights_path`` is prefixed with ``../../``.

    Parameters
    ----------
    model_checkpoint : str
        Path to the checkpoint file; it has to contain a ``"state_dict"``
        entry.

    Returns
    -------
    torch.nn.Module
        The instantiated network on the CPU, in eval mode and with
        ``requires_grad`` disabled for all parameters.

    Warns
    -----
    UserWarning
        If the checkpoint and the network disagree on parameter names.
        Loading stays non-strict, so mismatches never raise.
    """
    import warnings

    model_dir = os.path.dirname(model_checkpoint)
    with initialize(config_path=os.path.join(model_dir, 'hydra')):
        cfg = compose('config.yaml')

    # Both config layouts are handled by one code path; only the `model`
    # layout needs its `_target_` entries rewritten.
    cfg_key = "model" if "model" in cfg.keys() else "network"
    net_cfg = cfg[cfg_key]
    if cfg_key == "model":
        for node in (net_cfg, net_cfg["backbone"]):
            node["_target_"] = node["_target_"].replace(".model.", ".network.")
    # NOTE(review): presumably the stored path is relative to the repository
    # root, two levels above this notebook — confirm.
    net_cfg["backbone"]["cartesian_weights_path"] = (
        "../../" + net_cfg["backbone"]["cartesian_weights_path"]
    )
    model = instantiate(
        net_cfg,
        optimizer_config=cfg.optimizer,
        _recursive_=False
    )

    # torch.load unpickles the file: only load checkpoints from trusted sources.
    state_dict = torch.load(model_checkpoint, map_location=torch.device("cpu"))
    load_result = model.load_state_dict(state_dict["state_dict"], strict=False)
    # strict=False would otherwise hide a checkpoint that restores (almost)
    # nothing, which silently invalidates every score computed afterwards.
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint '{model_checkpoint}' does not fully match the network: "
            f"{len(load_result.missing_keys)} missing key(s) "
            f"{list(load_result.missing_keys)}, "
            f"{len(load_result.unexpected_keys)} unexpected key(s) "
            f"{list(load_result.unexpected_keys)}."
        )
    model = model.eval().cpu()
    model = model.requires_grad_(False)
    return model

In [None]:
# Restore the trained network, then move it to the GPU for the evaluation below.
checkpoint_path = "../../data/models_jeanzay/input_difference/9/last.ckpt"
network = load_model(checkpoint_path)
network = network.cuda()

# Estimate original error

In [None]:
# NOTE(review): the two positional arguments look like the input and the
# target directory of the test set — confirm against OfflineDataset.
input_dir = "../../data/raw/test/dataset/input_difference/"
target_dir = "../../data/raw/test/dataset/target_normal/"
dataset = OfflineDataset(input_dir, target_dir)

# Shuffled mini-batches; the scores below average over all samples.
data_loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Reference score with untouched inputs: squared errors are averaged over the
# last axis per sample, node and face outputs are joined along the remaining
# output axis, and the RMSE per output is taken over all samples.
per_sample_mse = []
for curr_batch in tqdm(data_loader):
    gpu_batch = {name: tensor.cuda() for name, tensor in curr_batch.items()}
    pred_nodes, pred_faces = network(gpu_batch["input_nodes"], gpu_batch["input_faces"])
    residual_nodes = pred_nodes - gpu_batch["error_nodes"]
    residual_faces = pred_faces - gpu_batch["error_faces"]
    batch_mse = torch.cat(
        (
            residual_nodes.pow(2).mean(dim=-1),
            residual_faces.pow(2).mean(dim=-1),
        ),
        dim=-1,
    )
    per_sample_mse.append(batch_mse.cpu())
original_rmse = torch.cat(per_sample_mse, dim=0).mean(dim=0).sqrt()

# Get scores for permuted inputs

In [None]:
# Number of channels (axis 1) of `input_nodes` / `input_faces` that get
# permuted one at a time in the following cells.
n_input_nodes, n_input_faces = 6, 14

# Collects one RMSE vector per permuted input: node inputs first, then faces.
rmse_permuted = []

In [None]:
# Permutation importance of the node inputs: one node channel at a time is
# rolled by one sample along the batch axis, so each sample receives that
# channel from another sample of the same shuffled batch while all other
# inputs stay untouched. The RMSE is then computed as for the original inputs.
if not isinstance(rmse_permuted, list):
    raise RuntimeError(
        "rmse_permuted is no longer a list; re-run the cell that initialises "
        "it before recomputing the permuted scores."
    )

# The loader does not depend on the permuted channel; every pass over it
# draws a fresh shuffle anyway.
data_loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4, pin_memory=True)
node_rmse_permuted = []
for idx_node in tqdm(range(n_input_nodes)):
    curr_node_rmse = []
    for curr_batch in data_loader:
        curr_batch = {k: v.cuda() for k, v in curr_batch.items()}
        # torch.roll returns a new tensor, so the in-place write is safe.
        curr_batch["input_nodes"][:, idx_node, :] = torch.roll(curr_batch["input_nodes"][:, idx_node, :], 1, dims=0)
        fcst_nodes, fcst_faces = network(curr_batch["input_nodes"], curr_batch["input_faces"])
        mse_nodes = (fcst_nodes-curr_batch["error_nodes"]).pow(2).mean(dim=-1)
        mse_faces = (fcst_faces-curr_batch["error_faces"]).pow(2).mean(dim=-1)
        curr_mse = torch.cat((mse_nodes, mse_faces), dim=-1).cpu()
        curr_node_rmse.append(curr_mse)
    curr_node_rmse = torch.cat(curr_node_rmse, dim=0)
    curr_node_rmse = curr_node_rmse.mean(dim=0).sqrt()
    node_rmse_permuted.append(curr_node_rmse)

# Slice assignment instead of append: the node scores always occupy the first
# `n_input_nodes` slots, so re-running this cell replaces them rather than
# adding duplicates that would shift all later columns.
rmse_permuted[:n_input_nodes] = node_rmse_permuted

In [None]:
# Permutation importance of the face inputs: one face channel at a time is
# rolled by one sample along the batch axis, so each sample receives that
# channel from another sample of the same shuffled batch while all other
# inputs stay untouched. The RMSE is then computed as for the original inputs.
# The face scores have to follow the node scores, so fail early (before the
# expensive loop) if the node entries are missing.
if not isinstance(rmse_permuted, list) or len(rmse_permuted) < n_input_nodes:
    raise RuntimeError(
        "rmse_permuted must be a list that already holds the node scores; "
        "run the initialisation and node permutation cells first."
    )

# The loader does not depend on the permuted channel; every pass over it
# draws a fresh shuffle anyway.
data_loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4, pin_memory=True)
face_rmse_permuted = []
for idx_face in tqdm(range(n_input_faces)):
    curr_face_rmse = []
    for curr_batch in data_loader:
        curr_batch = {k: v.cuda() for k, v in curr_batch.items()}
        # torch.roll returns a new tensor, so the in-place write is safe.
        curr_batch["input_faces"][:, idx_face, :] = torch.roll(curr_batch["input_faces"][:, idx_face, :], 1, dims=0)
        fcst_nodes, fcst_faces = network(curr_batch["input_nodes"], curr_batch["input_faces"])
        mse_nodes = (fcst_nodes-curr_batch["error_nodes"]).pow(2).mean(dim=-1)
        mse_faces = (fcst_faces-curr_batch["error_faces"]).pow(2).mean(dim=-1)
        curr_mse = torch.cat((mse_nodes, mse_faces), dim=-1).cpu()
        curr_face_rmse.append(curr_mse)
    curr_face_rmse = torch.cat(curr_face_rmse, dim=0)
    curr_face_rmse = curr_face_rmse.mean(dim=0).sqrt()
    face_rmse_permuted.append(curr_face_rmse)

# Slice assignment instead of append: everything after the node scores is
# replaced, so re-running this cell does not add duplicate face entries.
rmse_permuted[n_input_nodes:] = face_rmse_permuted

In [None]:
# Stack the per-input RMSE vectors into a matrix with one column per permuted
# input. The cell is a no-op once the list has been stacked, so it can be
# re-run safely. A wrong number of entries (e.g. a permutation cell executed
# twice) is reported instead of silently producing shifted columns.
if isinstance(rmse_permuted, list):
    n_expected_inputs = n_input_nodes + n_input_faces
    if len(rmse_permuted) != n_expected_inputs:
        raise RuntimeError(
            f"Expected {n_expected_inputs} permuted scores, found "
            f"{len(rmse_permuted)}; reset rmse_permuted and re-run the "
            "permutation cells once, nodes first."
        )
    rmse_permuted = torch.stack(rmse_permuted, dim=-1)

# Plot permutation feature importance

## Normalise permuted RMSE by the original RMSE

In [None]:
# Ratio of permuted to original RMSE per output (rows) and permuted input
# (columns); values above 1 mean the permutation increased the error.
original_rmse_column = original_rmse.unsqueeze(1)
norm_rmse = torch.div(rmse_permuted, original_rmse_column)

# Reorder inputs

In [None]:
# Column layout of `norm_rmse`: the node inputs (0-5) come first, the face
# inputs (6 onwards) after them. Each half of the plot takes one node column
# first, then two further node columns, then seven face columns.
# NOTE(review): per the plot labels the leading column is the forcing and the
# halves are the initial state and the difference — confirm channel order.
n_columns = norm_rmse.shape[-1]
columns_t = [2, 0, 1] + list(range(6, 13))
columns_t1 = [5, 3, 4] + list(range(13, n_columns))

norm_rmse_t = norm_rmse[:, columns_t]
norm_rmse_t1 = norm_rmse[:, columns_t1]
norm_rmse_reordered = torch.cat((norm_rmse_t, norm_rmse_t1), dim=-1)

# Colour scale: relative RMSE increase, scaled per row by its largest value.
rmse_increase = norm_rmse_reordered - 1
row_maximum = rmse_increase.max(dim=-1, keepdim=True).values
plt_rmse_reordered = rmse_increase / row_maximum

# Plot matrix

In [None]:
# Heat map of the permutation feature importance: rows are the network
# outputs, columns the permuted inputs. Colours show the row-normalised RMSE
# increase, the printed numbers the ratio of permuted to original RMSE.
output_labels = [
    "u - Velocity", "v - Velocity",
    r"$\sigma_{xx}$", r"$\sigma_{xy}$", r"$\sigma_{yy}$",
    "Damage", "Cohesion", "Area", "Thickness",
]
# The same ten inputs appear twice: once per half of the matrix.
state_labels = ["Forcing"] + output_labels
input_labels = state_labels * 2
n_outputs, n_inputs = len(output_labels), len(input_labels)

fig, ax = plt.subplots(figsize=(6, 4.5), dpi=300)

plt_sal = ax.matshow(plt_rmse_reordered, cmap="cmo.balance", vmin=-1, vmax=1)

for row in range(n_outputs):
    for col in range(n_inputs):
        # White text on the dark cells of the colour map, black otherwise.
        is_dark_cell = plt_rmse_reordered[row, col].item() > 0.5
        ax.text(
            col, row, f"{norm_rmse_reordered[row, col].item():.1f}",
            ha="center", va="center", fontsize=9,
            c="white" if is_dark_cell else "black")

ax.spines.left.set_visible(False)
ax.spines.bottom.set_visible(False)
ax.xaxis.set_ticks_position('top')
ax.xaxis.set_label_position('top')

ax.set_yticks(list(range(n_outputs)))
_ = ax.set_yticklabels(output_labels)
ax.set_xticks(list(range(n_inputs)))
_ = ax.set_xticklabels(input_labels, rotation = 45, horizontalalignment="left")
ax.set_ylabel(r"Output: $f(\mathbf{x})$")
ax.set_xlabel(r"Input: $\mathbf{x}$")

# Separator between the two input halves. ymin/ymax of axvline are axes
# fractions in [0, 1]; the defaults already span the full axes height.
ax.axvline(len(state_labels) - 0.5, c="black", lw=1.5)

# Captions in the extra row below the matrix (made visible by the y-limits).
ax.text(x=4.75, y=n_outputs, s=r"Initial: $\mathbf{x}_{0}$", ha="center", va="center")
ax.text(x=14.5, y=n_outputs, s=r"Difference: $\Delta \mathbf{x} = \mathbf{x}_{1}-\mathbf{x}_{0}$", ha="center", va="center")

ax.set_ylim(n_outputs + 0.2, -0.5)


ax_cbar = fig.add_axes([1.00, 0.38, 0.01, 0.4])
cbar = fig.colorbar(plt_sal, cax=ax_cbar, orientation="vertical")
# Show only the upper half (0 to 1) of the [-1, 1] colour range.
ax_cbar.set_ylim(0, 1)
cbar.set_ticks([])
ax_cbar.text(x=1.18, y=1, s="Important", rotation=90, transform=ax_cbar.transAxes, ha="left", va="top")
ax_cbar.text(x=1.18, y=0, s="Unimportant", rotation=90, transform=ax_cbar.transAxes, ha="left", va="bottom")

# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/fig06_permutation_feature_importance.pdf", bbox_inches='tight', pad_inches = 0, dpi=300, facecolor='white', transparent=False)