# **0. Import libraries**

In [None]:
from dataset import load_dataset
# from models.model_2_dropout_imp import ClimatePINN
from train import plot_comparison
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
import xarray as xr
from visualisation import visualize_predictions
import json
import seaborn as sns
import pandas as pd
from scipy.stats import pearsonr
from tqdm import tqdm
%matplotlib inline

# **1. Load data**

In [None]:
# Load the dataset and split the training samples into one array per variable.
# NOTE(review): the meaning of the first positional argument (1) is defined by
# the project's `load_dataset` and is not visible from this notebook.
dataset = load_dataset(
    1,
    train_val_split=None,
    year0=2000,
    root_dir="./data/era_5_data",
    normalize=True,
)

idx = [0, 1, 2]

# Fetch the whole training split once; indexing it separately for 'input' and
# 'target' would build the full slice twice.
train_all = dataset['train'][:]
inputs, targets = train_all['input'], train_all['target']

# Axis 1 selects the variable. The names below assume the channel order
# geo500, t850 for inputs and t2m, u, v for targets — TODO confirm against
# the dataset definition.
geo500, t850 = inputs[:, 0, ...], inputs[:, 1, ...]
t2m, u, v = targets[:, 0, ...], targets[:, 1, ...], targets[:, 2, ...]

# From here on `inputs` / `targets` are tuples of per-variable arrays.
inputs = (geo500, t850)
targets = (t2m, u, v)

# Coordinates are read from the first training sample only.
coords = dataset['train'][0]['coords']
lon, lat = coords[0], coords[1]

In [None]:
import matplotlib.colors as mcolors

# Correlation between every pair of variables, keeping the trailing (grid)
# dimensions: corr_matrix[x, y] holds one correlation map for the pair (x, y).
# NOTE: `vars` shadows the builtin of the same name; the name is kept because
# other cells may refer to it.
vars = (geo500, t850, t2m, u, v)
vars_name = ['input:geo500', 'input:t850', 'target:t2m', 'target:u', 'target:v']

# Size the result from the data instead of hard-coding 5 x 5 x 32 x 64.
# Assumes every variable has the same shape and that axis 0 is the sample
# axis being correlated over — TODO confirm.
n_vars = len(vars)
grid_shape = tuple(geo500.shape[1:])
corr_matrix = np.zeros((n_vars, n_vars) + grid_shape)

# Colour normalisation spanning the full range of a correlation coefficient;
# defined here but not used within this cell.
norm = mcolors.Normalize(vmin=-1, vmax=1)

# NOTE(review): this passes N-D arrays to pearsonr and relies on its default
# axis handling — confirm the installed SciPy version supports that.
# Pearson's r is symmetric, so only the upper triangle (diagonal included) is
# evaluated and each result is mirrored.
# total= lets tqdm show a real progress bar; enumerate() alone has no length.
for x, var_x in tqdm(enumerate(vars), total=n_vars):
    for y in range(x, n_vars):
        corr = pearsonr(var_x, vars[y])[0]
        corr_matrix[x, y] = corr
        corr_matrix[y, x] = corr

# **2. Correlation**

In [None]:
# Pearson correlation between every pair of variables, with all values of each
# variable flattened into a single 1-D sample, drawn as an annotated heat map.
# NOTE: `vars` shadows the builtin of the same name; the name is kept because
# other cells may refer to it.
vars = (geo500.flatten(), t850.flatten(), t2m.flatten(), u.flatten(), v.flatten())
vars_name = ['geo500', 't850', 't2m', 'u', 'v']

# Compute the correlation matrix. Pearson's r is symmetric, so only the upper
# triangle (diagonal included) is evaluated and each value is mirrored.
n_vars = len(vars)
corr_matrix = np.zeros((n_vars, n_vars))
# total= lets tqdm show a real progress bar; enumerate() alone has no length.
for x, var_x in tqdm(enumerate(vars), total=n_vars):
    for y in range(x, n_vars):
        corr = pearsonr(var_x, vars[y])[0]
        corr_matrix[x, y] = corr
        corr_matrix[y, x] = corr

# Create a figure and axes
fig, ax = plt.subplots(figsize=(10, 8))

# Plot the correlation matrix on a fixed [-1, 1] colour scale
im = ax.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)

# Add color bar
cbar = fig.colorbar(im, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
cbar.set_label('Correlation Coefficient')

# Label rows and columns with the variable names
tick_positions = np.arange(len(vars_name))
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(vars_name)
ax.set_yticklabels(vars_name)

# Rotate the tick labels for better readability
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Separator lines between the first two variables and the remaining three
# (the notebook's input and target variables, respectively)
ax.axvline(x=1.5, color='black', linewidth=2)
ax.axhline(y=1.5, color='black', linewidth=2)

# Annotate the cells with the correlation coefficients; imshow puts the row
# index on the y axis, hence text(j, i, ...).
for (i, j), value in np.ndenumerate(corr_matrix):
    ax.text(j, i, f"{value:.2f}", ha="center", va="center", color="black")

# Add title
ax.set_title('Correlation Matrix')

# Show plot
plt.show()