In [None]:
import sys
sys.path.append("../../")

import os

import numpy as np
import xarray as xr

from tqdm.notebook import tqdm

import torch
import torch.nn

from hydra import initialize, compose
from hydra.utils import instantiate

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mpl_colors
import matplotlib.gridspec as mpl_gs
from matplotlib.patches import Rectangle
import cmocean

import src_screening.model.accessor
from src_screening.datasets import OfflineDataset

In [None]:
# Apply the custom matplotlib style sheets one after another; a later sheet
# overrides settings of an earlier one, so the order matters.
for style_sheet in ("paper", "egu_journals"):
    plt.style.use(style_sheet)

# Load data

In [None]:
# Directories holding the inputs and the matching targets of the dataset
# under data/raw/test.
input_dir = "../../data/raw/test/dataset/input_normal/"
target_dir = "../../data/raw/test/dataset/target_normal/"
dataset = OfflineDataset(input_dir, target_dir)

In [None]:
# Template dataset; its `sinn.triangulation` is used by every tripcolor plot below.
# NOTE(review): the `sinn` accessor is presumably registered by importing
# `src_screening.model.accessor` in the import cell — confirm.
template = xr.open_dataset("../../data/interim/template_lr.nc")

## Load model

In [None]:
def load_model(
        model_checkpoint: str,
) -> torch.nn.Module:
    """Rebuild a model from its Hydra config and load the checkpoint weights.

    The configuration is composed from ``hydra/config.yaml`` located in the
    same directory as the checkpoint. Two configuration layouts are handled:

    * a top-level ``model`` entry: the backbone and model targets are
      overridden with ``UNextBackbone`` / ``DeterministicOfflineModel``
      before instantiation;
    * a top-level ``network`` entry: instantiated as configured.

    In both cases the backbone's ``cartesian_weights_path`` is prefixed with
    ``"../../"`` so that it resolves relative to this notebook.

    Parameters
    ----------
    model_checkpoint : str
        Path to the checkpoint file; must contain a ``"state_dict"`` entry.

    Returns
    -------
    torch.nn.Module
        The instantiated model, in eval mode and on the CPU.

    Raises
    ------
    KeyError
        If the configuration has neither a ``model`` nor a ``network`` entry.
    """
    model_dir = os.path.dirname(model_checkpoint)
    with initialize(config_path=os.path.join(model_dir, 'hydra')):
        cfg = compose('config.yaml')

    # Select the layout by key instead of a blanket ``except Exception``:
    # a genuine instantiation error must surface as-is and not be masked by
    # a follow-up failure in the fallback branch.
    if "model" in cfg:
        cfg["model"]["backbone"]["_target_"] = 'src_screening.network.backbone.UNextBackbone'
        cfg["model"]["_target_"] = 'src_screening.network.offline.DeterministicOfflineModel'
        cfg["model"]["backbone"]["cartesian_weights_path"] = "../../" + cfg["model"]["backbone"]["cartesian_weights_path"]
        model_cfg = cfg.model
    elif "network" in cfg:
        cfg["network"]["backbone"]["cartesian_weights_path"] = "../../" + cfg["network"]["backbone"]["cartesian_weights_path"]
        model_cfg = cfg.network
    else:
        raise KeyError(
            f"Config next to {model_checkpoint!r} has neither a 'model' "
            "nor a 'network' entry."
        )

    model: torch.nn.Module = instantiate(
        model_cfg,
        optimizer_config=cfg.optimizer,
        _recursive_=False
    )
    state_dict = torch.load(model_checkpoint, map_location=torch.device("cpu"))
    # strict=False: parameters missing from or unknown to the instantiated
    # model are skipped without an error.
    model.load_state_dict(state_dict["state_dict"], strict=False)
    model = model.eval().cpu()
    return model

In [None]:
# Model restored from the last checkpoint of the "gaussian_nll" run (sub-directory 0).
# Presumably trained with a Gaussian negative log-likelihood loss — confirm.
network_gaussian = load_model("../../data/models_jeanzay/gaussian_nll/0/last.ckpt")

In [None]:
# Model restored from the last checkpoint of the "unext_small" run (sub-directory 0).
# NOTE(review): this run is treated as the Laplace model throughout the
# notebook; the path does not say so — confirm it was trained with a Laplace loss.
network_laplace = load_model("../../data/models_jeanzay/unext_small/0/last.ckpt")

# Get cartesian features

In [None]:
# Run the backbone of the Gaussian model over the whole dataset, batch by
# batch, and stack the per-batch features along the first axis.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=128)
gaussian_backbone = network_gaussian.backbone
with torch.no_grad():
    gaussian_batches = []
    for batch in tqdm(data_loader):
        cart_input = gaussian_backbone.to_cartesian(batch["input_nodes"], batch["input_faces"])
        cart_features = gaussian_backbone.get_backbone_prediction(cart_input)
        # Keep only element 1 of what `from_cartesian` returns
        # (presumably the face-based features — confirm).
        gaussian_batches.append(gaussian_backbone.from_cartesian(cart_features)[1])
    features_gaussian = torch.concat(gaussian_batches, dim=0)

In [None]:
# Same feature extraction for the Laplace model: project to the Cartesian
# grid, run the backbone, project back and keep element 1 of the result
# (presumably the face-based features — confirm).
data_loader = torch.utils.data.DataLoader(dataset, batch_size=128)
with torch.no_grad():
    features_laplace = torch.concat(
        [
            network_laplace.backbone.from_cartesian(
                network_laplace.backbone.get_backbone_prediction(
                    network_laplace.backbone.to_cartesian(
                        chunk["input_nodes"], chunk["input_faces"]
                    )
                )
            )[1]
            for chunk in tqdm(data_loader)
        ],
        dim=0,
    )

# Plot single feature maps

In [None]:
# Overview grid for the Gaussian model: 4 x 32 panels, panel i shows
# features_gaussian[0, i] (first sample; axis 1 is presumably the feature channel).
n_rows, n_cols = 4, 32
fig, ax = plt.subplots(
    nrows=n_rows, ncols=n_cols, dpi=60,
    figsize=(1*n_cols, 5*n_rows), sharex=True, sharey=True
)
for i in tqdm(range(n_rows * n_cols)):
    panel = ax[divmod(i, n_cols)]  # (row, column) of the i-th panel
    panel.set_axis_off()
    feature_map = features_gaussian[0, i].numpy()
    # Colour range from 0 to the 99 % quantile of this map.
    panel.tripcolor(
        template.sinn.triangulation, feature_map,
        cmap="cmo.thermal", vmin=0, vmax=np.quantile(feature_map, 0.99)
    )
# The axes are shared, so setting the limits once applies to every panel.
ax[0, 0].set_xlim(-20000, 20000)
ax[0, 0].set_ylim(-100000, 100000)

In [None]:
# Overview grid for the Laplace model: 4 x 32 panels, panel i shows
# features_laplace[0, i] (first sample; axis 1 is presumably the feature channel).
fig, ax = plt.subplots(nrows=4, ncols=32, dpi=60, figsize=(32, 20), sharex=True, sharey=True)
panels = ax.ravel()  # row-major, so panels[i] is ax[i // 32, i % 32]
for i in tqdm(range(128)):
    current_panel = panels[i]
    current_panel.set_axis_off()
    laplace_map = features_laplace[0, i].numpy()
    upper = np.quantile(laplace_map, 0.99)  # clip the colour range at the 99 % quantile
    current_panel.tripcolor(template.sinn.triangulation, laplace_map, cmap="cmo.thermal", vmin=0, vmax=upper)
# The axes are shared, so setting the limits once applies to every panel.
ax[0, 0].set_xlim(-20000, 20000)
ax[0, 0].set_ylim(-100000, 100000)

# Estimation of contrast

In [None]:
# Contrast of every feature map: standard deviation divided by mean, both
# taken over the last axis. Maps with zero mean give inf/nan here.
contrast_gaussian, contrast_laplace = (
    feats.std(dim=-1) / feats.mean(dim=-1)
    for feats in (features_gaussian, features_laplace)
)

In [None]:
# Boolean masks marking the feature maps whose mean over the last axis is
# strictly positive; the remaining maps are shown as "inactive" in the figure.
nonzero_gaussian, nonzero_laplace = (
    feats.mean(dim=-1) > 0
    for feats in (features_gaussian, features_laplace)
)

In [None]:
# Keep the contrast of the maps with positive mean only. The contrast is
# recomputed from the features instead of filtering `contrast_*` in place,
# so this cell can be re-run: indexing the already filtered (1-D) tensor
# with the full-shape mask again would fail.
contrast_gaussian = (
    features_gaussian.std(dim=-1) / features_gaussian.mean(dim=-1)
)[nonzero_gaussian]
contrast_laplace = (
    features_laplace.std(dim=-1) / features_laplace.mean(dim=-1)
)[nonzero_laplace]

# Median contrast and summary figure

In [None]:
# Median contrast of the Gaussian model's feature maps with positive mean.
np.median(contrast_gaussian.numpy())

In [None]:
# Median contrast of the Laplace model's feature maps with positive mean.
np.median(contrast_laplace.numpy())

In [None]:
# Summary figure: (a)/(b) one example feature map per model, (c) cumulative
# distribution of the activation contrast over all feature maps.
fig = plt.figure(figsize=(5, 5.5), dpi=150)
gs = mpl_gs.GridSpec(nrows=5, ncols=1, hspace=0.05)

# --- (a) example feature map of the Gaussian model ---------------------------
ax_gaussian = fig.add_subplot(gs[0, :])
ax_gaussian.set_axis_off()
# First sample, index 40 along axis 1 (hand-picked example).
sel_gaussian = features_gaussian[0, 40].numpy()
# Scale by the 99 % quantile so both panels share the 0..1 colour range.
norm_gaussian = np.quantile(sel_gaussian, 0.99)
plt_gaussian = ax_gaussian.tripcolor(
    template.sinn.triangulation, sel_gaussian / norm_gaussian,
    cmap="cmo.thermal", vmin=0, vmax=1, rasterized=True
)
# Rotate the mesh by -90 degrees in data coordinates: the long axis of the
# domain becomes horizontal (see the swapped x/y limits set below).
t2 = mpl.transforms.Affine2D().rotate_deg(-90) + ax_gaussian.transData
plt_gaussian.set_transform(t2)
ax_gaussian.text(0.01, 0.98, s="(a)", ha="left", va="top", transform=ax_gaussian.transAxes, fontsize=10, color="white")
ax_gaussian.text(0.99, 0.98, s="Gaussian", ha="right", va="top", transform=ax_gaussian.transAxes, fontsize=10, color="white")
# Raw f-strings: "\m" and "\s" are invalid escape sequences in normal strings.
ax_gaussian.text(0.92, 0.25, s=rf"$\mu={(sel_gaussian / norm_gaussian).mean():.2f}$", ha="center", va="bottom", transform=ax_gaussian.transAxes, color="white")
ax_gaussian.text(0.92, 0.05, s=rf"$\sigma={(sel_gaussian / norm_gaussian).std(ddof=1):.2f}$", ha="center", va="bottom", transform=ax_gaussian.transAxes, color="white")

# --- (b) example feature map of the Laplace model ----------------------------
ax_laplace = fig.add_subplot(gs[1, :], sharex=ax_gaussian, sharey=ax_gaussian)
ax_laplace.set_axis_off()
# First sample, index 49 along axis 1 (hand-picked example).
sel_laplace = features_laplace[0, 49].numpy()
norm_laplace = np.quantile(sel_laplace, 0.99)
plt_laplace = ax_laplace.tripcolor(
    template.sinn.triangulation, sel_laplace/norm_laplace,
    cmap="cmo.thermal", vmin=0, vmax=1, rasterized=True
)
t2 = mpl.transforms.Affine2D().rotate_deg(-90) + ax_laplace.transData
plt_laplace.set_transform(t2)

ax_laplace.text(0.01, 0.98, s="(b)", ha="left", va="top", transform=ax_laplace.transAxes, fontsize=10, color="white")
ax_laplace.text(0.99, 0.98, s="Laplace", ha="right", va="top", transform=ax_laplace.transAxes, fontsize=10, color="white")
ax_laplace.text(0.92, 0.25, s=rf"$\mu={(sel_laplace/norm_laplace).mean():.2f}$", ha="center", va="bottom", transform=ax_laplace.transAxes, color="white")
ax_laplace.text(0.92, 0.05, s=rf"$\sigma={(sel_laplace/norm_laplace).std(ddof=1):.2f}$", ha="center", va="bottom", transform=ax_laplace.transAxes, color="white")
# Limits are shared with panel (a); x and y are swapped relative to the
# overview grids because of the -90 degree rotation.
ax_laplace.set_xlim(-100000, 100000)
ax_laplace.set_ylim(-20000, 20000)


# --- shared colour bar for (a) and (b) ---------------------------------------
# Axes rectangle in figure coordinates [left, bottom, width, height];
# left=1 places the bar just outside the right edge of the figure.
ax_cbar = fig.add_axes([1, 0.652, 0.02, 0.318])
norm = mpl_colors.Normalize(vmin=0, vmax=1)
cbar = mpl.colorbar.ColorbarBase(ax_cbar, cmap="cmo.thermal", norm=norm, label="Normalised activation", orientation="vertical")

# --- (c) cumulative distribution of the contrast -----------------------------
# 200 bins of width 0.1 centred on multiples of 0.1.
bins = np.linspace(-0.05, 19.95, 201)

ax_hist = fig.add_subplot(gs[2:, :])
hist_gauss = ax_hist.hist(
    contrast_gaussian.numpy().flatten(),
    bins=bins,
    histtype="stepfilled", edgecolor="#8F0685",
    facecolor=mpl_colors.to_rgba("#8F0685", 0.3),
    lw=1, cumulative=True, density=True
)
ax_hist.hist(
    contrast_laplace.numpy().flatten(),
    bins=bins,
    histtype="stepfilled", edgecolor="#6EC940",
    facecolor=mpl_colors.to_rgba("#6EC940", 0.3),
    lw=1, cumulative=True, density=True
)

# Separate bar left of the histogram for the maps excluded by the
# positive-mean mask.
# NOTE(review): the bar height is inactive / active, not inactive / total;
# confirm this ratio is intended, as it can exceed 1 on the CDF axis.
ax_hist.bar(
    -2.5, (~nonzero_gaussian).sum()/(nonzero_gaussian).sum(),
    width=0.3,
    edgecolor="#8F0685",
    facecolor=mpl_colors.to_rgba("#8F0685", 0.3),
    lw=1
)
ax_hist.bar(
    -2.5, (~nonzero_laplace).sum()/(nonzero_laplace).sum(),
    width=0.3,
    edgecolor="#6EC940",
    facecolor=mpl_colors.to_rgba("#6EC940", 0.3),
    lw=1
)

ax_hist.set_xlim(-2.9, 18)
# The tick at -2.5 labels the "inactive" bar; the others are contrast values.
ax_hist.set_xticks([-2.5] + list(np.arange(0, 20, 2.5)))
ax_hist.set_xticklabels(["inactive"] + list(np.arange(0, 20, 2.5)))
ax_hist.set_xlabel(r"Contrast of activations ($\sigma$/$\mu$)")

# Proxy rectangles so the legend shows the same face/edge colours as the histograms.
handles = [
    Rectangle((0,0), 1, 1, facecolor=mpl_colors.to_rgba("#8F0685", 0.3), edgecolor="#8F0685", lw=1),
    Rectangle((0,0), 1, 1, facecolor=mpl_colors.to_rgba("#6EC940", 0.3), edgecolor="#6EC940", lw=1)
]
ax_hist.legend(handles=handles, labels=["Gaussian", "Laplace"], loc=2, bbox_to_anchor=(0.02, 0.95))
ax_hist.text(x=0.005, y=1-0.025/3, s="(c)", ha="left", va="top", transform=ax_hist.transAxes)

ax_hist.set_ylabel("CDF")

# Create the output directory first so that saving works on a fresh checkout.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/figc01_loss_features.pdf", bbox_inches='tight', pad_inches = 0, dpi=300)