In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell is executed, so edits to imported source files (e.g. the `qgsw`
# package used below) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import os
from qgsw.utils.sorting import sort_files
from qgsw.run_summary import RunSummary
import torch
from matplotlib.axes import Axes
from qgsw.comparison.comparators import RMSE, absolute_difference
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.figure import Figure

In [None]:
# Parent of the current working directory; the output folders read further
# down are resolved relative to it.
ROOT = Path.cwd().parent

In [None]:


def make_axes(figure: Figure, row:int, shape:tuple[int,int]):
    """Build the axes of one comparison row inside ``figure``.

    The row is laid out on a ``shape`` grid: columns 0, 1 and 2 each hold an
    image panel with a narrow colorbar axes attached on its right, and the
    RMSE panel starts at column 3 and spans two columns.

    Args:
        figure: Figure in which the axes are created.
        row: Index of the grid row to fill.
        shape: ``(n_rows, n_columns)`` of the subplot grid.

    Returns:
        ``(axes_1, axes_2, axes_diff, axes_rmse, axes_cbar_1, axes_cbar_2,
        axes_cbar_diff)``. Note that the difference panel sits in column 1,
        between the two data panels, although it comes third in the tuple.
    """

    def _panel_with_colorbar(column: int) -> tuple[Axes, Axes]:
        """Create an image panel at ``column`` and a colorbar axes on its right."""
        panel: Axes = plt.subplot2grid(shape, (row, column), fig=figure)
        colorbar_axes: Axes = make_axes_locatable(panel).append_axes(
            "right",
            size="7%",
            pad="2%",
        )
        return panel, colorbar_axes

    axes_1, axes_cbar_1 = _panel_with_colorbar(0)
    axes_diff, axes_cbar_diff = _panel_with_colorbar(1)
    axes_2, axes_cbar_2 = _panel_with_colorbar(2)
    axes_rmse: Axes = plt.subplot2grid(shape, (row, 3), fig=figure, colspan=2)

    return (
        axes_1,
        axes_2,
        axes_diff,
        axes_rmse,
        axes_cbar_1,
        axes_cbar_2,
        axes_cbar_diff,
    )

def extract(file_1:Path, file_2:Path, field: str, border: int = 24) -> tuple[torch.Tensor,torch.Tensor]:
    """Load the same field from two ``.npz`` archives and crop its borders.

    For each archive, the array stored under ``field`` is read, index
    ``[0, 0]`` of its two leading dimensions is selected, and ``border``
    points are removed on both sides of the first two remaining dimensions.
    Each archive is opened in a ``with`` block so that its file handle is
    released as soon as the array has been read; this function is called once
    per frame, so unclosed archives would otherwise pile up.

    Args:
        file_1: Path to the first ``.npz`` archive.
        file_2: Path to the second ``.npz`` archive.
        field: Name of the array to read from both archives.
        border: Number of points cropped on each side. Defaults to 24 (the
            value previously hard-coded); 0 keeps the whole field.

    Returns:
        The cropped fields read from ``file_1`` and ``file_2``, in that order.

    Raises:
        ValueError: If ``border`` is negative.
    """
    if border < 0:
        msg = f"border must be non-negative, got {border}."
        raise ValueError(msg)

    # ``x[0:-0]`` is empty, so a zero border must map to an open-ended slice.
    crop = slice(border, -border or None)

    def _load(file: Path) -> torch.Tensor:
        """Read ``field`` from ``file``, select index [0, 0] and crop."""
        with np.load(file) as archive:
            values = archive[field]
        return torch.tensor(values)[0, 0, ...][crop, crop]

    return _load(file_1), _load(file_2)

def plot(file_1:Path, file_2:Path, rmse:list[float], field: str, axes_1:Axes, axes_2:Axes, axes_diff:Axes, axes_rmse:Axes, axes_cbar_1:Axes, axes_cbar_2:Axes,axes_cbar_diff:Axes) -> list[float] :
    """Redraw one comparison row for a pair of output files.

    All axes are cleared, then ``field`` is extracted from both files and
    drawn on ``axes_1`` / ``axes_2`` with a shared symmetric color range. The
    scaled squared difference is drawn on ``axes_diff`` and the scaled
    root-mean-square error is appended to ``rmse`` and plotted on
    ``axes_rmse``.

    Args:
        file_1: Output file of the first run.
        file_2: Output file of the second run.
        rmse: Errors accumulated over the previous calls. The list is
            modified in place: the new error is appended to it.
        field: Name of the field to read from both files.
        axes_1: Axes showing the field of the first run.
        axes_2: Axes showing the field of the second run.
        axes_diff: Axes showing the difference panel.
        axes_rmse: Axes showing the error history.
        axes_cbar_1: Colorbar axes attached to ``axes_1``.
        axes_cbar_2: Colorbar axes attached to ``axes_2``.
        axes_cbar_diff: Colorbar axes attached to ``axes_diff``.

    Returns:
        The same ``rmse`` list, with the new error appended as a Python float.
    """
    # Scale dividing both the difference panel and the error.
    # NOTE(review): hard-coded; presumably a physical reference value for the
    # plotted field — confirm and move it to a configuration constant.
    scale = 9.375e-5

    # Clear everything first so the previous frame does not show through.
    titled_axes = (
        (axes_1, f"1 - {field}"),
        (axes_2, f"2 - {field}"),
        (axes_diff, f"|1-2| - {field}"),
        (axes_rmse, f"RMSE - {field}"),
    )
    for axes, title in titled_axes:
        axes.cla()
        axes.set_title(title)
    for colorbar_axes in (axes_cbar_1, axes_cbar_2, axes_cbar_diff):
        colorbar_axes.cla()

    data_1, data_2 = extract(file_1, file_2, field)
    difference = data_1 - data_2

    # Shared, symmetric color range so both panels are directly comparable.
    # Plain floats are passed to matplotlib rather than 0-dim tensors.
    vmax = max(data_1.abs().max().item(), data_2.abs().max().item())

    image_1 = axes_1.imshow(data_1, vmin=-vmax, vmax=vmax)
    plt.colorbar(image_1, cax=axes_cbar_1)
    image_2 = axes_2.imshow(data_2, vmin=-vmax, vmax=vmax)
    plt.colorbar(image_2, cax=axes_cbar_2)

    # NOTE(review): the panel title reads "|1-2|" but the image shown is the
    # squared difference divided by `scale` — confirm which one is intended.
    image_diff = axes_diff.imshow(torch.square(difference) / scale)
    plt.colorbar(image_diff, cax=axes_cbar_diff)

    # Store a Python float so that `rmse` really is a list[float].
    err = torch.sqrt(torch.mean(torch.square(difference))).item() / scale

    rmse.append(err)
    axes_rmse.plot(rmse, c='blue')

    return rmse


In [None]:
%matplotlib tk

folder_1 = "sf_barotropic"
folder_2 = "two_layers_barotropic_100km"

path_1 = ROOT.joinpath(f"output/g5k/{folder_1}")
run_1 = RunSummary.from_file(path_1.joinpath("_summary.toml"))
steps_1, files_1 = sort_files(list(path_1.glob(f"{run_1.configuration.model.prefix}*.npz")),run_1.configuration.model.prefix,".npz")
path_2 = ROOT.joinpath(f"output/g5k/{folder_2}")
run_2 = RunSummary.from_file(path_2.joinpath("_summary.toml"))
steps_2, files_2 = sort_files(list(path_2.glob(f"{run_2.configuration.model.prefix}*.npz")),run_2.configuration.model.prefix,".npz")

if not (run_2.dt == run_1.dt):
    print("Timesteps are not matching.")
if not (run_2.duration == run_1.duration):
    print("Duration are not matching.")

plt.ion()

shape = (4,5)

fig = plt.figure(figsize = (4*shape[1],3*shape[0]))
fig.suptitle(f"1: {folder_1} | 2: {folder_2}")

# pv_axes = make_axes(fig,0,shape)
u_axes = make_axes(fig,1,shape)
# v_axes = make_axes(fig,2,shape)
# h_axes = make_axes(fig,3,shape)

rmse_pv = []
rmse_u = []
rmse_v = []
rmse_h = []

for i in range(len(files_1)):
    if not (steps_2[i] == steps_1[i]):
        msg = f"Impossible to match steps {steps_2[i]} and {steps_1[i]}."
        raise ValueError(msg)
    
    file_1 = files_1[i]
    file_2 = files_2[i]
    
    # rmse_pv = plot(file_1,file_2,rmse_pv,"pv",*pv_axes)
    rmse_u = plot(file_1,file_2,rmse_u,"pv",*u_axes)
    # rmse_v = plot(file_1,file_2,rmse_v,"v",*v_axes)
    # rmse_h = plot(file_1,file_2,rmse_h,"h",*h_axes)

    plt.pause(0.01)

plt.close()
plt.ioff()