In [None]:
import sys
# Extend the module search path two directories up so that the project-local
# `src_screening` package imported further down in this cell can be resolved.
# NOTE(review): "../../" is presumably the repository root relative to this
# notebook's working directory — confirm.
sys.path.append("../../")

import os


from tqdm.notebook import tqdm


import torch
import torch.nn
from torch.utils.data import DataLoader, Dataset

import xarray as xr
import numpy as np
import pandas as pd

from distributed import LocalCluster, Client

from hydra import initialize, compose
from hydra.utils import instantiate

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mpl_colors
import matplotlib.gridspec as mpl_gs
import cmocean
import cmcrameri

from src_screening.datasets import OfflineDataset
from src_screening.model.wave_forcing import WaveForcing

In [None]:
# Apply both style sheets in a single call; matplotlib applies the entries of
# a list from first to last, so "egu_journals" is layered on top of "paper".
plt.style.use(["paper", "egu_journals"])

In [None]:
# Settings for the local Dask cluster: 48 single-threaded worker processes,
# with worker scratch space placed under /tmp.
cluster_options = dict(
    n_workers=48,
    threads_per_worker=1,
    local_directory="/tmp",
)
local_cluster = LocalCluster(**cluster_options)

# Connect a client to the cluster and show its summary as the cell output.
client = Client(local_cluster)
client

# Load data

In [None]:
def _open_zarr(store_path):
    """Lazily open a zarr store as a chunked xarray dataset.

    Every dimension except ``ensemble`` is kept in a single chunk (``-1``);
    ``ensemble`` is split into one member per chunk.
    """
    return xr.open_dataset(
        store_path,
        engine="zarr",
        chunks={
            "time": -1,
            "ensemble": 1,
            "lead_time": -1,
            "nMesh2_node": -1,
            "nMesh2_face": -1,
        },
    )


# Nature-run dataset and the forecast trajectories it is compared against.
ds_nature = _open_zarr("../../data/raw/test/lr_nature_forecast/")
ds_forecast = _open_zarr("../../data/processed/unext_small/9/traj_short/")

## Estimate forecast error and forecast update

In [None]:
# Lead times at which the error and the update are evaluated.
# NOTE(review): presumably the update lead times are the steps directly after
# the error lead times — confirm against the forecast's lead_time axis.
error_lead_times = ["9 min 52 s", "50 min 24 s"]
update_lead_times = ["10 min 8 s", "50 min 40 s"]

# Forecast error: forecast minus nature run at the selected lead times.
# The lead_time labels are dropped afterwards so that error and update, which
# are selected at different labels, can later be combined position by position.
forecast_at_error_leads = ds_forecast.sel(lead_time=error_lead_times)
forecast_error = (forecast_at_error_leads - ds_nature).reset_index(
    "lead_time", drop=True
)

# Forecast update: rolling by one position without rolling the coordinates
# puts the previous lead time's values at each lead_time label (the first
# label receives the wrapped-around last entry), so the difference below is
# "previous step minus current step".
forecast_previous_step = ds_forecast.roll(lead_time=1, roll_coords=False)
step_difference = forecast_previous_step - ds_forecast
forecast_update = step_difference.sel(lead_time=update_lead_times).reset_index(
    "lead_time", drop=True
)

In [None]:
# Dimensions that are averaged out; only lead_time is left in the result.
rms_dims = ["ensemble", "time", "nMesh2_node", "nMesh2_face"]


def _root_mean_square(fields):
    """Return the root-mean-square of ``fields`` over ``rms_dims``."""
    mean_of_squares = (fields**2).mean(rms_dims)
    return np.sqrt(mean_of_squares)


rms_error = _root_mean_square(forecast_error)
rms_update = _root_mean_square(forecast_update)

In [None]:
# Ratio of RMS update to RMS error, evaluated (and computed eagerly) for the
# first and the second selected lead time.
first_ratio, second_ratio = (
    (rms_update.isel(lead_time=lead_idx) / rms_error.isel(lead_time=lead_idx)).compute()
    for lead_idx in (0, 1)
)

# Estimate correlation between update and error

In [None]:
# Remove the mean over the mesh dimensions so that only the spatial
# perturbations of the error and of the update remain.
spatial_dims = ["nMesh2_face", "nMesh2_node"]

error_spatial_mean = forecast_error.mean(spatial_dims)
error_perts = forecast_error - error_spatial_mean

update_spatial_mean = forecast_update.mean(spatial_dims)
update_perts = forecast_update - update_spatial_mean

In [None]:
spatial_dims = ["nMesh2_face", "nMesh2_node"]

# Sample covariance over the mesh: sum of perturbation products divided by
# (number of mesh points - 1). The point count is obtained by summing an
# array of ones with the same layout as the perturbations.
perturbation_product_sum = (error_perts * update_perts).sum(spatial_dims)
degrees_of_freedom = xr.ones_like(error_perts).sum(spatial_dims) - 1
pattern_covariance = perturbation_product_sum / degrees_of_freedom

# Matching sample standard deviations (ddof=1) over the same dimensions.
error_std = forecast_error.std(spatial_dims, ddof=1)
update_std = forecast_update.std(spatial_dims, ddof=1)

In [None]:
# Small offset added to each standard deviation to avoid division by zero.
std_offset = 1E-9

# Pattern correlation: covariance normalised by both standard deviations
# (divided one after the other, left to right).
covariance_over_error_std = pattern_covariance / (error_std + std_offset)
pattern_correlation = covariance_over_error_std / (update_std + std_offset)

In [None]:
# Average the correlations in Fisher-z space: arctanh -> mean -> tanh.
# All variables are stacked along "var_names" and averaged together with the
# ensemble and time dimensions.
fisher_z_stacked = np.arctanh(pattern_correlation).to_array("var_names")
fisher_z_mean = fisher_z_stacked.mean(["var_names", "ensemble", "time"])
mean_pattern_correlation = np.tanh(fisher_z_mean).compute()

In [None]:
# Per-variable average correlation, again averaged in Fisher-z space
# (arctanh -> mean over ensemble and time -> tanh).
fisher_z_per_variable = np.arctanh(pattern_correlation)
fisher_z_per_variable_mean = fisher_z_per_variable.mean(["ensemble", "time"])
average_correlation = np.tanh(fisher_z_per_variable_mean).compute()

In [None]:
# Variables shown as individual columns of the correlation table.
table_variables = ["v", "stress_yy", "damage", "area"]

# Stack the selected variables along "var_names", then transpose so that the
# variables become the columns of the resulting pandas object.
correlation_stacked = average_correlation[table_variables].to_array("var_names")
pandas_correlation = correlation_stacked.T.to_pandas()

In [None]:
# Append the correlation averaged over all variables as an extra "mean" column.
mean_correlation_column = mean_pattern_correlation.to_pandas()
pandas_correlation["mean"] = mean_correlation_column

In [None]:
# Show the correlation table rounded to two decimals.
pandas_correlation.round(decimals=2)

In [None]:
# Render the rounded correlation table as LaTeX source (shown as cell output).
rounded_correlation = pandas_correlation.round(decimals=2)
rounded_correlation.to_latex()

#### 