### Testing functions for NAO matching ###

In [None]:
# Load autoreload extension
%load_ext autoreload
%autoreload 2

# Import standard library modules
import glob
import os
import pathlib
import re
import sys
import time

# Import third party modules
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from scipy.stats import pearsonr
from tqdm import tqdm

In [None]:
# import local modules
sys.path.append('/home/users/benhutch/lagging-NAO-test-suite/alternate_lag_suite')

# Import alt lag functions
import alternate_lag_functions as funcs

In [None]:
%%time

# Settings for the NAO test run (test for a shorter time frame)
nao_settings = {
    "models_list": ["BCC-CSM2-MR"],
    "season": "ONDJFM",
    "forecast_range": "2-9",
    "start_year": 1961,
    "end_year": 2014,
    "plot": False,
}

# Run the NAO function and unpack the obs and model results
obs_nao, model_nao = funcs.calculate_nao_index(**nao_settings)


In [None]:

# Inspect the obs NAO returned by calculate_nao_index
obs_nao

In [None]:
# Inspect the model NAO returned by calculate_nao_index
model_nao

In [None]:
# Lag used by the trimming steps
lag = 4

# Drop the first lag - 1 time steps from the model data.
# NOTE: model_nao is overwritten, so re-running this cell trims a further
# lag - 1 steps; re-run the NAO calculation cell first to start afresh.
model_nao = model_nao.isel({"time": slice(lag - 1, None)})

In [None]:
# Inspect the model NAO after dropping the first lag - 1 time steps
model_nao

In [None]:
# Remove the final lag - 1 time steps from the end of the model data.
# When lag == 1 there is nothing to trim: a stop of -0 is just 0, which would
# select an empty slice, so fall back to None (keep everything) in that case.
# NOTE: model_nao is overwritten, so re-running this cell trims the end again.
model_nao = model_nao.isel(time=slice(None, -(lag - 1) or None))

In [None]:
# Inspect the model NAO after also dropping the final lag - 1 time steps
model_nao

In [None]:
# Drop the first lag - 1 time steps from the obs data.
# NOTE: obs_nao is overwritten, so re-running this cell trims a further
# lag - 1 steps from the start.
obs_nao = obs_nao.isel({"time": slice(lag - 1, None)})

In [None]:
# Inspect the obs NAO after dropping the first lag - 1 time steps
obs_nao

In [None]:
# Re-inspect the trimmed model NAO for comparison with the trimmed obs NAO
model_nao

In [None]:
# Prepare the inputs for the correlation between the model and obs NAO index

# Compare the shapes of the obs and model NAO before correlating them
print(f"obs NAO shape: {obs_nao.shape}")
print(f"model NAO shape: {model_nao['psl'].shape}")

# pearsonr needs two series of equal length, so fail early with a clear
# message if the trimming above has left obs and model out of step
assert obs_nao.sizes["time"] == model_nao.sizes["time"], (
    f"obs has {obs_nao.sizes['time']} time steps but the model has "
    f"{model_nao.sizes['time']}; align the two before correlating"
)

# Set up the ensemble mean of the model
model_nao_ens_mean = model_nao['psl'].mean(dim="ensemble_member")

In [None]:
%%time

# Calculate the correlation between the obs NAO and the model ensemble mean.
# The p-value gets a real name: assigning to `_` would shadow IPython's
# last-output variable for the rest of the session.
corr, p_value = pearsonr(obs_nao, model_nao_ens_mean)

In [None]:
# Standard deviation of the observations
sig_o_tot = np.std(obs_nao)

# Standard deviation over every value in the ensemble (no axis is given)
sig_f_tot = np.std(model_nao['psl'])

# Standard deviation of the ensemble mean
sig_f_sig = np.std(model_nao_ens_mean)

In [None]:
# Share of the total ensemble spread carried by the ensemble mean
signal_fraction = sig_f_sig / sig_f_tot

# rpc: the correlation divided by that share
# (presumably the ratio of predictable components - TODO confirm naming)
rpc = corr / signal_fraction

# Ratio of the obs spread to the total ensemble spread
spread_ratio = sig_o_tot / sig_f_tot

# rps: the rpc rescaled by the obs-to-ensemble spread ratio
rps = rpc * spread_ratio

In [None]:
# Scale the ensemble mean nao by the rps
# NOTE: this overwrites model_nao_ens_mean, so re-running the cell applies the
# rps scaling a second time; recompute the ensemble mean first if re-running.
model_nao_ens_mean = model_nao_ens_mean * rps

In [None]:
# Set up a figure
fig, ax = plt.subplots()

# Plot the obs NAO
obs_nao.plot(ax=ax, label="Obs")

# Plot the model NAO (ensemble mean after scaling by the rps)
model_nao_ens_mean.plot(ax=ax, label="Model")

# Label the figure so that it can be read on its own
ax.set_title("Observed vs signal-adjusted ensemble-mean NAO")
ax.set_xlabel("Time")
ax.set_ylabel("NAO index")

# Add a legend (trailing semicolon hides the Legend repr in the cell output)
ax.legend();

The NAO looks approximately correct, although the years still need to be aligned. Next, we want to find the ensemble members whose values are closest to the signal-adjusted NAO in each year.

In [None]:
%%time

# For each year, calculate the absolute difference between the NAO of each
# of the ensemble members and the signal adjusted ensemble mean, then sort the
# members by that difference (smallest first). A member's position in the
# sorted dataframe is its rank for that year.

# Set up the years
years = model_nao['time.year'].values

# Extract the values for the data
# (indexed below as [time, ensemble_member] - TODO confirm the axis order)
model_nao_values = model_nao['psl'].values

# Extract the signal adjusted ensemble mean values once, outside the loop
ens_mean_values = model_nao_ens_mean.values

# The member labels are the same for every year, so extract them once too
member_labels = model_nao['ensemble_member'].values

# limit years for testing
years = years[:10]

# Set up a dictionary mapping each year to its sorted dataframe
rank_list = {}

# Loop over the years; tqdm wraps `years` itself (rather than the enumerate
# object, which has no length) so that the progress bar knows the total
for i, year in enumerate(tqdm(years, desc="Ranking members")):
    # Absolute difference between each member and the adjusted ensemble mean
    abs_diff = np.abs(model_nao_values[i, :] - ens_mean_values[i])

    # Store the members for this year, sorted from closest to furthest
    rank_list[year] = pd.DataFrame(
        {
            "ensemble_member": member_labels,
            "abs_diff": abs_diff,
        }
    ).sort_values(by="abs_diff")


In [None]:
# Look at the sorted members for a single year.
# NOTE(review): only the first 10 model years are ranked above, so this raises
# a KeyError if 1972 is not among them.
rank_list[1972] 

#### Matching for the variable ####

Now we want to pick a variable to match, say `sfcWind`. We then want to find the members of the sfcWind ensemble which correspond to the 20 highest-ranked members in the ranked list. First, we need to assemble the list of ensemble members for this variable.

In [None]:
# Define a function for finding the ensemble members for a given variable
def find_ens_members(
    variable: str,
    models_list: list,
    season: str,
    forecast_range: str,
    start_year: int = 1961,
    end_year: int = 2014,
    lag: int = 4,
    region: str = "global",
    base_dir: str = "/gws/nopw/j04/canari/users/benhutch/skill-maps-processed-data",
) -> list:
    """
    Forms a list of all of the ensemble members for a given variable.

    For each model, the directory
    ``{base_dir}/{variable}/{model}/{region}/{forecast_range}/{season}/outputs/``
    is searched for anomaly files whose name contains ``s{start_year}-``, and
    the member label is parsed out of each matching filename.

    Args:
        variable (str): The variable to find the ensemble members for.
        models_list (list): A list of the models to find the ensemble members for.
        season (str): The season, used to build the directory path.
        forecast_range (str): The forecast range, used in both the directory
            path and the filename pattern.
        start_year (int): The start year of the data.
        end_year (int): The end year of the data.
        lag (int): The number of lag indices. Each unique member yields one
            label per lag index, ``{model}_{member}_lag_{lag_idx}`` for
            ``lag_idx`` in ``range(lag)``. If None, each unique member yields
            a single ``{model}_{member}`` label.
        region (str): The region to find the ensemble members for.
        base_dir (str): The base directory of the data.

    Returns:
        ens_members (list): A list of the ensemble member labels for the given
            variable, grouped by model in the order of ``models_list``. Within
            a model the labels are ordered by lag index and then by member
            name, so the result is reproducible between runs. Models with no
            matching files contribute nothing.

    Raises:
        IndexError: If a matching filename does not have the expected
            ``_``/``-`` separated structure.
    """

    # Set up a list to store the ensemble members
    ens_members = []

    # Loop over the models
    for model in tqdm(models_list):
        # Set up the fstem
        fstem = (
            f"{base_dir}/{variable}/{model}/{region}/{forecast_range}/{season}/outputs/"
        )

        # Set up the filename
        fname = f"*s{start_year}-*years_{forecast_range}_start_{start_year}_end_{end_year}_anoms.nc"

        # Find the files
        files = glob.glob(f"{fstem}{fname}")

        # The fifth "_"-separated field of the filename is split on "-" and
        # its second part is taken as the member label
        members = [
            os.path.basename(file).split("_")[4].split("-")[1] for file in files
        ]

        # Find the unique members, sorted so that the order is deterministic
        # (a bare set has no guaranteed ordering)
        unique_members = sorted(set(members))

        if lag is not None:
            # One label per lag index for each unique member
            new_members = [
                f"{model}_{member}_lag_{lag_idx}"
                for lag_idx in range(lag)
                for member in unique_members
            ]
        else:
            # One label per unique member; using the de-duplicated list keeps
            # this branch consistent with the lagged case
            new_members = [f"{model}_{member}" for member in unique_members]

        # Add the members to the list
        ens_members.extend(new_members)

    return ens_members


In [None]:
# Add the test suite directory to the path so that dictionaries can be imported
sys.path.append('/home/users/benhutch/lagging-NAO-test-suite/')

# Import dictionaries
import dictionaries as dicts

In [None]:
# Settings for the sfcWind ensemble member search
sfcwind_settings = {
    "variable": "sfcWind",
    "models_list": dicts.sfcWind_models,
    "season": "ONDJFM",
    "forecast_range": "2-9",
    "start_year": 1961,
    "end_year": 2014,
    "region": "global",
    "base_dir": "/gws/nopw/j04/canari/users/benhutch/skill-maps-processed-data",
}

# Try the function out on sfcWind
sfcwind_ens = find_ens_members(**sfcwind_settings)

In [None]:
# Number of ensemble member labels found for sfcWind (lag is left at its
# default of 4 in the call above, so each unique member gives four labels)
len(sfcwind_ens)

Now that we have this list, we want to find the members which match those in the NAO `rank_list` for each year, and then extract them.