In [None]:
import sys
sys.path.append("../../")

import os
import numpy as np

from tqdm.notebook import tqdm

import torch
import torch.nn

import xarray as xr

from hydra import initialize, compose
from hydra.utils import instantiate

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mpl_colors
import matplotlib.gridspec as mpl_gs
import cmocean

from src_screening.datasets import OfflineDataset

In [None]:
# Project style sheets, applied first to last: later sheets override any
# settings they share with earlier ones.
plt.style.use(["paper", "egu_journals", "presentation"])

# Load data

In [None]:
# Test split of the offline dataset: input and target directories
# (the "_normal" suffix presumably denotes normalised fields — confirm).
TEST_INPUT_DIR = "../../data/raw/test/dataset/input_normal/"
TEST_TARGET_DIR = "../../data/raw/test/dataset/target_normal/"
dataset = OfflineDataset(TEST_INPUT_DIR, TEST_TARGET_DIR)

# Template dataset (the "_lr" suffix presumably means low resolution — confirm).
template = xr.open_dataset("../../data/interim/template_lr.nc")

# Load models

In [None]:
def load_model(
        model_checkpoint: str,
        path_prefix: str = "../../",
) -> torch.nn.Module:
    """Build a network from its stored Hydra config and load checkpoint weights.

    The config is composed from ``<checkpoint dir>/hydra/config.yaml``. The
    ``model`` section is tried first, with its targets overridden to the UNext
    backbone and the deterministic offline model. If that fails for any reason,
    the ``network`` section is instantiated instead.

    Parameters
    ----------
    model_checkpoint : str
        Path to the checkpoint file; its ``"state_dict"`` entry is loaded.
    path_prefix : str, optional
        Prefix joined in front of the config's ``cartesian_weights_path`` so
        that the path resolves from this notebook's working directory.
        Defaults to ``"../../"``.

    Returns
    -------
    torch.nn.Module
        The model, moved to CPU and switched to evaluation mode.

    Warns
    -----
    UserWarning
        If the checkpoint's state dict does not match the model's keys.
    """
    import warnings

    model_dir = os.path.dirname(model_checkpoint)
    with initialize(config_path=os.path.join(model_dir, 'hydra')):
        cfg = compose('config.yaml')

    try:
        cfg["model"]["backbone"]["_target_"] = 'src_screening.network.backbone.UNextBackbone'
        cfg["model"]["_target_"] = 'src_screening.network.offline.DeterministicOfflineModel'
        cfg["model"]["backbone"]["cartesian_weights_path"] = os.path.join(
            path_prefix, cfg["model"]["backbone"]["cartesian_weights_path"]
        )
        model: torch.nn.Module = instantiate(
            cfg.model,
            optimizer_config=cfg.optimizer,
            _recursive_=False
        )
    except Exception:
        # Fallback for configs that keep the network under `network`. If this
        # also fails, Python chains the first error into the traceback.
        cfg["network"]["backbone"]["cartesian_weights_path"] = os.path.join(
            path_prefix, cfg["network"]["backbone"]["cartesian_weights_path"]
        )
        model = instantiate(
            cfg.network,
            optimizer_config=cfg.optimizer,
            _recursive_=False
        )

    # NOTE(review): recent torch releases default to `weights_only=True`, which
    # may reject checkpoints that pickle non-tensor objects — confirm for the
    # torch version in use.
    state_dict = torch.load(model_checkpoint, map_location=torch.device("cpu"))
    # strict=False tolerates key mismatches, but any mismatched parameter keeps
    # its initial value, so surface the mismatch instead of ignoring it.
    incompatible = model.load_state_dict(state_dict["state_dict"], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"{model_checkpoint}: {len(incompatible.missing_keys)} missing and "
            f"{len(incompatible.unexpected_keys)} unexpected state-dict keys "
            f"(missing: {incompatible.missing_keys[:5]}, "
            f"unexpected: {incompatible.unexpected_keys[:5]})",
            stacklevel=2,
        )
    model = model.eval().cpu()
    return model

In [None]:
# Checkpoint of each network; the variable suffix gives the cartesian grid
# (n_x_n_y) that its features are plotted on further below.
CKPT_8_32 = "../../data/models_jeanzay/cartesian_32x8/9/last.ckpt"
CKPT_16_64 = "../../data/models_jeanzay/cartesian_64x16/9/last.ckpt"
CKPT_32_128 = "../../data/models_jeanzay/unext_small/9/last.ckpt"

network_8_32 = load_model(CKPT_8_32)
network_16_64 = load_model(CKPT_16_64)
network_32_128 = load_model(CKPT_32_128)

# Get cartesian features

In [None]:
# The first test sample is the common input fed to all three networks in the
# next cell.
test_sample = dataset[0]

In [None]:
def extract_cartesian_features(network, sample):
    """Project one sample onto the network's cartesian grid and return backbone features.

    Parameters
    ----------
    network : torch.nn.Module
        Model whose ``backbone`` provides ``to_cartesian`` and
        ``get_backbone_prediction``.
    sample : mapping
        Dataset item with ``"input_nodes"`` and ``"input_faces"`` entries; a
        leading axis of size one is prepended to each to form a batch of one.

    Returns
    -------
    Whatever ``network.backbone.get_backbone_prediction`` returns for that
    single-sample batch (computed without gradient tracking).
    """
    with torch.no_grad():
        predictor_cart = network.backbone.to_cartesian(
            sample["input_nodes"][None, ...], sample["input_faces"][None, ...]
        )
        return network.backbone.get_backbone_prediction(predictor_cart)


features_8_32 = extract_cartesian_features(network_8_32, test_sample)
features_16_64 = extract_cartesian_features(network_16_64, test_sample)
features_32_128 = extract_cartesian_features(network_32_128, test_sample)

# Plot features

In [None]:
def get_cart_bounds(n_x, n_y, half_width_x=20000, half_width_y=100000):
    """Return the cell-edge coordinates of a regular cartesian grid.

    The grid covers the rectangle ``[-half_width_x, half_width_x] x
    [-half_width_y, half_width_y]`` centred on the origin. The edges can be
    passed as the ``X``/``Y`` arguments of ``pcolormesh`` for a field with
    ``n_y`` rows and ``n_x`` columns.

    Parameters
    ----------
    n_x, n_y : int
        Number of cells along x and y.
    half_width_x, half_width_y : float, optional
        Half extents of the domain. The defaults reproduce the
        40000 x 200000 domain plotted in this notebook (units presumably
        metres — confirm).

    Returns
    -------
    tuple of numpy.ndarray
        ``(bounds_x, bounds_y)`` holding ``n_x + 1`` and ``n_y + 1`` equally
        spaced edges.
    """
    bounds_x = np.linspace(-half_width_x, half_width_x, n_x + 1)
    bounds_y = np.linspace(-half_width_y, half_width_y, n_y + 1)
    return bounds_x, bounds_y

In [None]:
# Overview of every channel of the 32x128 backbone output, 32 panels per row.
# Each panel has its own colour scale, capped at that channel's 99th percentile.
OVERVIEW_NCOLS = 32
n_channels_32_128 = features_32_128.shape[1]
overview_nrows = -(-n_channels_32_128 // OVERVIEW_NCOLS)  # ceiling division
fig, ax = plt.subplots(
    nrows=overview_nrows, ncols=OVERVIEW_NCOLS, dpi=60,
    figsize=(1*OVERVIEW_NCOLS, 5*overview_nrows),
    sharex=True, sharey=True, squeeze=False,
)
bounds_32_128 = get_cart_bounds(32, 128)  # loop-invariant, computed once
for i, panel in enumerate(ax.flat):
    panel.set_axis_off()
    if i >= n_channels_32_128:
        continue  # unused slot in a partially filled last row
    sel_feature = features_32_128[0, i].numpy()
    panel.pcolormesh(*bounds_32_128, sel_feature, cmap="cmo.thermal", vmin=0, vmax=np.quantile(sel_feature, 0.99))
# Axes are shared, so these limits apply to every panel.
ax[0, 0].set_xlim(-20000, 20000)
ax[0, 0].set_ylim(-100000, 100000);

# Normalised activations at three resolutions

In [None]:
def plot_normalised_feature(ax, feature, n_x, n_y, panel_label, grid_label):
    """Draw one feature map, scaled by its 99th percentile and rotated by -90 degrees.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Panel to draw into; its axis decorations are switched off.
    feature : numpy.ndarray
        Feature map passed to ``pcolormesh`` together with the cell edges from
        ``get_cart_bounds(n_x, n_y)``.
    n_x, n_y : int
        Number of cartesian cells along x and y.
    panel_label : str
        Tag drawn in the upper-left corner, e.g. ``"(a)"``.
    grid_label : str
        Resolution tag drawn in the upper-right corner.

    Returns
    -------
    matplotlib.collections.QuadMesh
        The mesh created by ``pcolormesh``.
    """
    ax.set_axis_off()
    scale = np.quantile(feature, 0.99)
    # An all-zero (dead) channel would otherwise cause a division by zero.
    if scale == 0:
        scale = 1.0
    mesh = ax.pcolormesh(
        *get_cart_bounds(n_x, n_y), feature / scale,
        cmap="cmo.thermal", vmin=0, vmax=1, rasterized=True
    )
    # Rotate by -90 degrees so the long y-extent of the domain lies horizontally.
    mesh.set_transform(mpl.transforms.Affine2D().rotate_deg(-90) + ax.transData)
    text_kwargs = dict(transform=ax.transAxes, va="top", color="white")
    ax.text(x=0.005, y=0.98, s=panel_label, ha="left", **text_kwargs)
    ax.text(x=0.995, y=0.98, s=grid_label, ha="right", **text_kwargs)
    return mesh


fig, ax = plt.subplots(nrows=3, dpi=150, figsize=(12, 12*3/5), sharex=True, sharey=True)

# Channels hand-picked for display at each resolution.
lr_feature = features_8_32[0, 69].numpy()
med_feature = features_16_64[0, 6].numpy()
hr_feature = features_32_128[0, 38].numpy()

plt_lr = plot_normalised_feature(ax[0], lr_feature, 8, 32, r"(a)", r"$8 \times 32$")
plt_med = plot_normalised_feature(ax[1], med_feature, 16, 64, r"(b)", r"$16 \times 64$")
plt_hr = plot_normalised_feature(ax[2], hr_feature, 32, 128, r"(c)", r"$32 \times 128$")

# Axes are shared, so these limits apply to all three panels. After the -90
# degree rotation, the original y-range (+-100000) lies along the horizontal
# axis and the x-range (+-20000) along the vertical axis.
ax[0].set_ylim(-20000, 20000)
ax[0].set_xlim(-100000, 100000)

# Hand-tuned colorbar position [left, bottom, width, height] in figure
# coordinates, aligned with the stacked panels.
ax_cbar = fig.add_axes([1, 0.1648, 0.02, 0.803])
norm = mpl_colors.Normalize(vmin=0, vmax=1)
cbar = mpl.colorbar.ColorbarBase(ax_cbar, cmap="cmo.thermal", norm=norm, label="Normalised activation", orientation="vertical")
plt.subplots_adjust(hspace=0.15)

# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/fig05_normalised_activations.pdf", bbox_inches='tight', pad_inches=0)

#### 