In [None]:
import sys
sys.path.append("../../")

import os
from copy import deepcopy

import xarray as xr
import numpy as np
import pandas as pd

import torch
import torch.nn

import pytorch_lightning as pl

from hydra import initialize, compose
from hydra.utils import instantiate

import matplotlib.pyplot as plt
import matplotlib.colors as mpl_colors
import matplotlib.gridspec as mpl_gs
import cmocean

import src_screening.model.accessor
import src_screening.model.fem_interpolation as grid_utils
from src_screening.datasets import OfflineDataset


In [None]:
# Figure styles: "paper" is applied first, then "egu_journals" on top of it.
plt.style.use(["paper", "egu_journals"])

# fig.savefig() does not create missing directories; make sure the output
# folder used by the plotting cells below exists on a fresh checkout.
os.makedirs("figures/fig03_network", exist_ok=True)

# Load data

In [None]:
# Test-set input and target directories (paths relative to this notebook).
input_dir = "../../data/raw/test/dataset/input_normal"
target_dir = "../../data/raw/test/dataset/target_normal/"
dataset = OfflineDataset(input_dir, target_dir)

In [None]:
# Mesh template (presumably the low-resolution mesh, "_lr"); its `sinn` accessor
# provides the triangulation used by the tripcolor panels below.
template = xr.open_dataset(filename_or_obj="../../data/interim/template_lr.nc")

## Load model

In [None]:
def load_model(
        model_checkpoint: str,
        root_dir: str = "../../",
) -> torch.nn.Module:
    """Build a network from its Hydra config and load trained weights into it.

    The Hydra config is composed from ``<checkpoint dir>/hydra/config.yaml``.
    Two config layouts are supported:

    * configs with a ``model`` entry: its ``_target_`` and the backbone's
      ``_target_`` are overridden to ``DeterministicOfflineModel`` /
      ``UNextBackbone`` before instantiation;
    * otherwise the ``network`` entry is instantiated as configured.

    Parameters
    ----------
    model_checkpoint : str
        Path to a checkpoint file holding a ``"state_dict"`` entry
        (e.g. a Lightning ``last.ckpt``).
    root_dir : str, default "../../"
        Directory joined in front of the backbone's
        ``cartesian_weights_path``, which the config stores relative to
        the project root while this notebook runs from a sub-directory.

    Returns
    -------
    torch.nn.Module
        The model in eval mode on the CPU.
    """
    import warnings

    model_dir = os.path.dirname(model_checkpoint)
    with initialize(config_path=os.path.join(model_dir, 'hydra')):
        cfg = compose('config.yaml')

    # Choose the config layout explicitly instead of via try/except: a
    # blanket ``except Exception`` turned genuine instantiation errors of the
    # "model" layout into a misleading KeyError on the missing "network" key.
    if "model" in cfg:
        model_cfg = cfg["model"]
        model_cfg["backbone"]["_target_"] = 'src_screening.network.backbone.UNextBackbone'
        model_cfg["_target_"] = 'src_screening.network.offline.DeterministicOfflineModel'
    else:
        model_cfg = cfg["network"]

    backbone_cfg = model_cfg["backbone"]
    # os.path.join (rather than string concatenation) leaves absolute paths intact.
    backbone_cfg["cartesian_weights_path"] = os.path.join(
        root_dir, backbone_cfg["cartesian_weights_path"]
    )
    model: torch.nn.Module = instantiate(
        model_cfg,
        optimizer_config=cfg.optimizer,
        _recursive_=False
    )

    # NOTE(review): recent torch versions default to ``weights_only=True`` in
    # torch.load, which may reject checkpoints that pickle non-tensor objects — confirm.
    state_dict = torch.load(model_checkpoint, map_location=torch.device("cpu"))
    incompatible = model.load_state_dict(state_dict["state_dict"], strict=False)
    # strict=False is kept, but mismatches are reported so that layers left
    # at their initial values do not go unnoticed.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {model_checkpoint!r} does not match the model exactly: "
            f"missing keys={list(incompatible.missing_keys)}, "
            f"unexpected keys={list(incompatible.unexpected_keys)}"
        )
    return model.eval().cpu()

In [None]:
# Trained checkpoint under unext_small/9; load_model reads the Hydra config
# from the "hydra" folder next to it and returns the model in eval mode on CPU.
network = load_model(model_checkpoint="../../data/models_jeanzay/unext_small/9/last.ckpt")

# Run the network on one test sample to get the data to plot

In [None]:
# First test sample, with a leading batch axis of length 1 added to every entry.
data_dict = dict((name, value[None, ...]) for name, value in dataset[0].items())

In [None]:
# Run the network stage by stage so that each intermediate representation can be
# plotted: mesh input -> Cartesian grid -> backbone features -> mesh features -> head.
with torch.no_grad():
    predictor_cart = network.backbone.to_cartesian(data_dict["input_nodes"], data_dict["input_faces"])
    features_cart = network.backbone.get_backbone_prediction(predictor_cart)
    features_nodes, features_faces = network.backbone.from_cartesian(features_cart)
    error_faces = network.head_face(features_faces)
    # Apply the face head directly on the Cartesian feature map: flatten the two
    # trailing spatial axes so grid cells are treated like faces, then restore
    # the grid shape. Batch/channel sizes are taken from the tensors instead of
    # the former hard-coded (1, 128) and 7; flatten/reshape also work on
    # non-contiguous tensors, where .view would fail.
    errors_flat = network.head_face(features_cart.flatten(start_dim=2))
    errors_cart = errors_flat.reshape(*errors_flat.shape[:-1], *features_cart.shape[-2:])

## Cartesian grid coordinates

In [None]:
# Cartesian grid coordinates with a fixed (128, 32) target shape derived from the
# mesh template (presumably the grid the backbone works on — confirm).
cart_coords_mesh, cart_coords_xy = grid_utils.gen_cartesian_coords(
    template,
    target_shape=(128, 32),
    cartesian_res=None,
)

In [None]:
# Convert cell-centre coordinates into cell edges for pcolormesh. The first
# spacing is taken as the (assumed uniform) grid step; edges sit half a step
# away from the centres.
x_centres, y_centres = cart_coords_xy[0], cart_coords_xy[1]
res_x = x_centres[1] - x_centres[0]
res_y = y_centres[1] - y_centres[0]

# One leading edge before the first centre, then the upper edge of every cell.
cart_bounds_x = [x_centres[0] - res_x / 2, *(x_centres + res_x / 2)]
cart_bounds_y = [y_centres[0] - res_y / 2, *(y_centres + res_y / 2)]

# Plot

In [None]:
# Panel 1: face input channel 8 on the triangular mesh, symmetric colour scale.
fig, ax = plt.subplots(figsize=(1, 5), dpi=300)
ax.set_axis_off()
input_faces_channel = data_dict["input_faces"][0, 8].numpy()
ax.tripcolor(
    template.sinn.triangulation,
    input_faces_channel,
    cmap="cmo.balance",
    vmin=-4,
    vmax=4,
)
ax.set(xlim=(-20000, 20000), ylim=(-100000, 100000))
fig.savefig("figures/fig03_network/01_input.png", bbox_inches='tight', pad_inches=0)

In [None]:
# Panel 2: channel 14 of the Cartesian-grid predictors (output of
# backbone.to_cartesian), same colour scale as panel 1.
fig, ax = plt.subplots(figsize=(1, 5), dpi=300)
ax.set_axis_off()
input_cart_channel = predictor_cart[0, 14].numpy()
ax.pcolormesh(
    cart_bounds_x, cart_bounds_y, input_cart_channel,
    cmap="cmo.balance", vmin=-4, vmax=4,
)
ax.set(xlim=(-20000, 20000), ylim=(-100000, 100000))
fig.savefig("figures/fig03_network/02_input_cart.png", bbox_inches='tight', pad_inches=0)

In [None]:
# Panel 3: backbone feature channel 83 on the Cartesian grid, divided by its
# maximum so the largest value maps to the top of the [0, 1] colour scale.
fig, ax = plt.subplots(dpi=300, figsize=(1, 5))
ax.set_axis_off()
norm_features = features_cart[0, 83].numpy()
peak = norm_features.max()
# Guard the rescaling: a zero maximum would divide by zero (NaNs) and a
# negative one would flip the sign and saturate the panel.
if peak > 0:
    norm_features = norm_features / peak
ax.pcolormesh(cart_bounds_x, cart_bounds_y, norm_features, cmap="cmo.thermal", vmin=0, vmax=1.)
# NOTE(review): unlike the other panels, the x-limits are deliberately not fixed
# to (-20000, 20000) here — confirm this is intended.
ax.set_ylim(-100000, 100000)
fig.savefig("figures/fig03_network/03_features_cart.png", bbox_inches='tight', pad_inches=0)

In [None]:
# Panel 4: feature channel 83 after mapping back to the mesh faces
# (backbone.from_cartesian), divided by its maximum for a [0, 1] colour scale.
fig, ax = plt.subplots(dpi=300, figsize=(1, 5))
ax.set_axis_off()
norm_features = features_faces[0, 83].numpy()
peak = norm_features.max()
# Guard the rescaling: a zero maximum would divide by zero (NaNs) and a
# negative one would flip the sign and saturate the panel.
if peak > 0:
    norm_features = norm_features / peak
ax.tripcolor(template.sinn.triangulation, norm_features, cmap="cmo.thermal", vmin=0, vmax=1.)
ax.set_xlim(-20000, 20000)
ax.set_ylim(-100000, 100000)
fig.savefig("figures/fig03_network/04_features_tri.png", bbox_inches='tight', pad_inches=0)

In [None]:
# Panel 5: channel 3 of the face-head output on the mesh, colour scale [-1, 1].
fig, ax = plt.subplots(figsize=(1, 5), dpi=300)
ax.set_axis_off()
prediction_channel = error_faces[0, 3].numpy()
ax.tripcolor(
    template.sinn.triangulation, prediction_channel,
    cmap="cmo.balance", vmin=-1., vmax=1.,
)
ax.set(xlim=(-20000, 20000), ylim=(-100000, 100000))
fig.savefig("figures/fig03_network/05_prediction.png", bbox_inches='tight', pad_inches=0)