In [None]:
import sys
sys.path.append("../../")

import os

import numpy as np
import xarray as xr
import scipy.stats

from tqdm.notebook import tqdm

import torch
import torch.nn
from torch.distributions import Cauchy

from hydra import initialize, compose
from hydra.utils import instantiate

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mpl_colors
import matplotlib.gridspec as mpl_gs
from matplotlib.patches import Rectangle
import cmocean

import src_screening.model.accessor
from src_screening.datasets import OfflineDataset

In [None]:
# Apply the two custom matplotlib style sheets in order; settings from
# "egu_journals" take precedence over those from "paper" where they overlap.
plt.style.use(["paper", "egu_journals"])

# Load data

In [None]:
# Test dataset built from the stored input and target directories; it is
# iterated in batches by the feature-extraction cells below.
dataset = OfflineDataset(
    "../../data/raw/test/dataset/input_normal/",
    "../../data/raw/test/dataset/target_normal/",
)

In [None]:
# Template dataset; its `sinn.triangulation` accessor provides the mesh that
# the feature maps are drawn on in the final figure.
template = xr.open_dataset("../../data/interim/template_lr.nc")

## Load model

In [None]:
def load_model(
        model_checkpoint: str,
) -> torch.nn.Module:
    """Rebuild a network from its stored hydra config and load its weights.

    The hydra configuration is read from the ``hydra`` sub-directory next to
    the checkpoint. Configurations come in two layouts: one stores the network
    under the ``model`` key (its ``_target_`` entries still point to
    ``.model.`` and are rewritten to ``.network.``), the other stores it
    under the ``network`` key and is used as is.

    Parameters
    ----------
    model_checkpoint : str
        Path to the checkpoint file; ``config.yaml`` is expected in
        ``<checkpoint dir>/hydra``.

    Returns
    -------
    torch.nn.Module
        The instantiated network on the CPU in evaluation mode.
    """
    model_dir = os.path.dirname(model_checkpoint)
    with initialize(config_path=os.path.join(model_dir, 'hydra')):
        cfg = compose('config.yaml')

    # Pick the config section that holds the network definition.
    section = "model" if "model" in cfg.keys() else "network"
    net_cfg = cfg[section]
    if section == "model":
        # Only this layout needs its import paths redirected.
        for node in (net_cfg, net_cfg["backbone"]):
            node["_target_"] = node["_target_"].replace(".model.", ".network.")

    # The stored weights path is prefixed so that it resolves from the
    # notebook's working directory.
    backbone_cfg = net_cfg["backbone"]
    backbone_cfg["cartesian_weights_path"] = (
        "../../" + backbone_cfg["cartesian_weights_path"]
    )

    model = instantiate(
        net_cfg,
        optimizer_config=cfg.optimizer,
        _recursive_=False
    )

    checkpoint = torch.load(model_checkpoint, map_location=torch.device("cpu"))
    # strict=False: keys that do not match between checkpoint and network are
    # skipped instead of raising.
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model.eval().cpu()

In [None]:
# Network whose features are labelled "relu" in the final figure.
network_relu = load_model("../../data/models_jeanzay/unext_small/0/last.ckpt")

In [None]:
# Network whose features are labelled "Gelu" in the final figure.
network_gelu = load_model("../../data/models_jeanzay/gelu_all/0/last.ckpt")

In [None]:
# Network whose features are labelled "w/o" in the final figure
# (presumably trained without an output activation -- confirm against config).
network_no = load_model("../../data/models_jeanzay/no_output/0/last.ckpt")

# Get Cartesian features

In [None]:
# Run only the backbone of the "relu" network over the whole dataset, batch by
# batch, and collect the features after projecting them back from Cartesian
# space.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=128)
features_relu = []
with torch.no_grad():
    for chunk in tqdm(data_loader):
        predictor_cart = network_relu.backbone.to_cartesian(
            chunk["input_nodes"], chunk["input_faces"]
        )
        features_cart = network_relu.backbone.get_backbone_prediction(predictor_cart)
        # Element 1 of the back-projection is kept
        # (presumably the face-valued features -- TODO confirm).
        features_relu.append(
            network_relu.backbone.from_cartesian(features_cart)[1]
        )
# Stack all batches along the sample dimension.
features_relu = torch.concat(features_relu, dim=0)

In [None]:
# Run only the backbone of the "Gelu" network over the whole dataset, batch by
# batch, and collect the features after projecting them back from Cartesian
# space.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=128)
features_gelu = []
with torch.no_grad():
    for chunk in tqdm(data_loader):
        predictor_cart = network_gelu.backbone.to_cartesian(
            chunk["input_nodes"], chunk["input_faces"]
        )
        features_cart = network_gelu.backbone.get_backbone_prediction(predictor_cart)
        # Element 1 of the back-projection is kept
        # (presumably the face-valued features -- TODO confirm).
        features_gelu.append(
            network_gelu.backbone.from_cartesian(features_cart)[1]
        )
# Stack all batches along the sample dimension.
features_gelu = torch.concat(features_gelu, dim=0)

In [None]:
# Run only the backbone of the "w/o" network over the whole dataset, batch by
# batch, and collect the features after projecting them back from Cartesian
# space.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=128)
features_no = []
with torch.no_grad():
    for chunk in tqdm(data_loader):
        predictor_cart = network_no.backbone.to_cartesian(
            chunk["input_nodes"], chunk["input_faces"]
        )
        features_cart = network_no.backbone.get_backbone_prediction(predictor_cart)
        # Element 1 of the back-projection is kept
        # (presumably the face-valued features -- TODO confirm).
        features_no.append(
            network_no.backbone.from_cartesian(features_cart)[1]
        )
# Stack all batches along the sample dimension.
features_no = torch.concat(features_no, dim=0)

# Build the colormap for the feature maps

In [None]:
# Negative half of the combined colormap: "cmo.balance" sampled from its
# midpoint (0.5) towards its lower end (0), 128 colours.
colors1 = plt.get_cmap("cmo.balance")(np.linspace(0.5, 0, num=128))
# Positive half: the full range of "cmo.thermal", 128 colours.
colors2 = plt.get_cmap("cmo.thermal")(np.linspace(0, 1, num=128))

In [None]:
# Concatenate both halves (256 colours in total) and turn them into a single
# colormap that is used for the normalised activations in [-1, 1].
colors = np.vstack([colors1, colors2])
thermal_ice = mpl_colors.LinearSegmentedColormap.from_list(
    name='thermal_ice', colors=colors
)

# Fit Cauchy distribution

In [None]:
# Fit the location and scale of a Cauchy distribution to all activations of
# the "w/o" network by minimising the mean negative log-likelihood with Adam.
# latent_dist_params[0] is the location, latent_dist_params[1] parametrises
# the scale as exp(0.5 * value) so that the scale stays positive.
latent_dist_params = torch.nn.Parameter(torch.zeros(2))
latent_optim = torch.optim.Adam([latent_dist_params], lr=1E-1)

# The samples do not change during the optimisation, flatten them only once.
features_no_flat = features_no.flatten()

pbar = tqdm(range(100))

for _ in pbar:
    latent_optim.zero_grad()
    scale = (latent_dist_params[1]*0.5).exp()
    cauchy_dist = Cauchy(latent_dist_params[0], scale)
    nll = -cauchy_dist.log_prob(features_no_flat).mean()
    nll.backward()
    latent_optim.step()
    pbar.set_postfix(nll=nll.item(), loc=latent_dist_params[0].item(), scale=scale.item())

# The distribution built inside the loop uses the parameters from *before*
# the last optimiser step and is still attached to the autograd graph.
# Rebuild it from the final, detached parameters so that the density plotted
# later corresponds to the fitted values.
with torch.no_grad():
    scale = (latent_dist_params[1]*0.5).exp()
    cauchy_dist = Cauchy(latent_dist_params[0].clone(), scale)

# Create final figures

In [None]:
def draw_feature_panel(ax, feature_map, panel_label, xlim=None, ylim=None):
    """Draw one normalised feature map together with its quartile labels.

    The map is divided by the 99 % quantile of its absolute values, drawn on
    the template triangulation with a colour range of [-1, 1] and rotated by
    -90 degrees in data space. The 75 %, 50 % and 25 % quantiles of the
    normalised values are written into the panel.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into; its frame and ticks are switched off.
    feature_map : numpy.ndarray
        Feature values, one per element of ``template.sinn.triangulation``
        as expected by ``tripcolor``.
    panel_label : str
        Text written into the upper left corner of the panel.
    xlim, ylim : tuple of float, optional
        Data limits to fix on ``ax`` right after drawing. Only needed for the
        first panel, the others share their axes with it.

    Returns
    -------
    The collection created by ``tripcolor``.
    """
    ax.set_axis_off()
    normalised = feature_map / np.quantile(np.abs(feature_map), 0.99)
    mesh = ax.tripcolor(
        template.sinn.triangulation, normalised,
        cmap=thermal_ice, vmin=-1, vmax=1, rasterized=True
    )
    if ylim is not None:
        ax.set_ylim(*ylim)
    if xlim is not None:
        ax.set_xlim(*xlim)
    # Rotate the drawn mesh clockwise by a quarter turn before mapping it to
    # display coordinates.
    mesh.set_transform(mpl.transforms.Affine2D().rotate_deg(-90) + ax.transData)
    ax.text(x=0.005, y=0.975, s=panel_label, transform=ax.transAxes, ha="left", va="top", c="white")
    quartile_labels = (
        (0.55, ".75", np.percentile(normalised, 75)),
        (0.3, ".50", np.median(normalised)),
        (0.05, ".25", np.percentile(normalised, 25)),
    )
    for y_pos, quantile_name, value in quartile_labels:
        ax.text(
            0.7, y_pos, s=f"${quantile_name}={value:.2f}$",
            ha="left", va="bottom", transform=ax.transAxes, color="white"
        )
    return mesh


# Grid rows 0-2 hold the three feature maps, row 3 the upper part of the
# broken histogram axis and rows 4-6 its lower part.
fig = plt.figure(figsize=(3, 4), dpi=300)
gs = mpl_gs.GridSpec(nrows=7, ncols=1, hspace=0.1)

# No activation map

ax0 = fig.add_subplot(gs[0, :])
plt_no = draw_feature_panel(
    ax0, features_no[0, 89].numpy(), "(a) w/o",
    xlim=(-100000, 100000), ylim=(-20000, 20000)
)

# RELU map

ax1 = fig.add_subplot(gs[2, :], sharex=ax0, sharey=ax0)
plt_relu = draw_feature_panel(ax1, features_relu[0, 25].numpy(), "(c) relu")

# GELU map

ax2 = fig.add_subplot(gs[1, :], sharex=ax0, sharey=ax0)
plt_gelu = draw_feature_panel(ax2, features_gelu[0, 60].numpy(), "(b) Gelu")

# Colorbar

ax_cbar = fig.add_axes([1, 0.6294, 0.02, 0.3408])
norm = mpl_colors.Normalize(vmin=-1, vmax=1)
cbar = mpl.colorbar.ColorbarBase(ax_cbar, cmap=thermal_ice, norm=norm, label="Normalised activation", orientation="vertical")


# Histogram

# Settings shared by all histograms so that their densities are comparable.
hist_style = dict(
    bins=np.linspace(-2, 2, 200), histtype="stepfilled", density=True, lw=1
)
relu_values = features_relu.numpy().flatten()

ax = fig.add_subplot(gs[4:, :])
_ = ax.hist(
    features_no.numpy().flatten(),
    facecolor=mpl_colors.to_rgba("black", 0.5),
    zorder=1, **hist_style
)
# Density of the fitted Cauchy distribution on top of the "w/o" histogram.
cauchy_x = torch.linspace(-1, 2, 1000)
with torch.no_grad():
    cauchy_y = cauchy_dist.log_prob(cauchy_x).exp()
ax.plot(cauchy_x, cauchy_y, c="black", zorder=2)

_ = ax.hist(
    features_gelu.numpy().flatten(),
    edgecolor="deepskyblue",
    facecolor=mpl_colors.to_rgba("deepskyblue", 0.5),
    zorder=3, **hist_style
)
hist_relu = ax.hist(
    relu_values,
    edgecolor="firebrick",
    facecolor=mpl_colors.to_rgba("firebrick", 0.5),
    zorder=4, **hist_style
)
ax.set_ylim(0, 6)
ax.set_xlabel("Activation value")
ax.set_ylabel("Probability density")
ax.set_yticks(np.arange(0, 6))


# Half length (in axes coordinates) of the diagonal marks at the axis break.
d = .01

handles = [
    Rectangle((0,0), 1, 1, facecolor=mpl_colors.to_rgba("black", 0.5), edgecolor="black", lw=1),
    Rectangle((0,0), 1, 1, facecolor=mpl_colors.to_rgba("deepskyblue", 0.5), edgecolor="deepskyblue", lw=1),
    Rectangle((0,0), 1, 1, facecolor=mpl_colors.to_rgba("firebrick", 0.5), edgecolor="firebrick", lw=1)
]
ax.legend(handles=handles, labels=["w/o", "Gelu", "relu"])

# Upper part of the broken y-axis: shows only the density range 19-21 of the
# "relu" histogram, which lies outside the 0-6 range of the lower part.
ax_broken = fig.add_subplot(gs[3, :], sharex=ax)
_ = ax_broken.hist(
    relu_values,
    edgecolor="firebrick",
    facecolor=mpl_colors.to_rgba("firebrick", 0.5),
    **hist_style
)
ax_broken.set_ylim(19, 21)
ax_broken.spines['bottom'].set_visible(False)
ax_broken.xaxis.tick_top()
ax_broken.tick_params(labeltop=False, top=False)
ax_broken.set_yticks([20, 21])
ax_broken.text(x=0.005, y=0.975, s="(d)", transform=ax_broken.transAxes, ha="left", va="top")

# Diagonal break marks. The upper axes spans one grid row, the lower one
# three, hence the factor 3 to give both marks the same slope on the page.
ax.plot((-d, +d), (1-d, 1+d), transform=ax.transAxes, color='k', clip_on=False)
ax_broken.plot((-d, +d), (-d*3, +d*3), transform=ax_broken.transAxes, color='k', clip_on=False)

ax.set_xlim(-1, 2)

# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/figdfigc02_visualise_activation_features01_hist_activation_features.pdf", bbox_inches='tight', pad_inches = 0)