### Testing functions for NAO matching ###

In [1]:
# Load autoreload extension
%load_ext autoreload
%autoreload 2

# Import local modules
import sys
import os
import pathlib
import glob
import re
import time

# Importing third party modules
import pandas as pd
import numpy as np
import xarray as xr
from tqdm import tqdm

In [2]:
# import local modules
sys.path.append("/home/users/benhutch/lagging-NAO-test-suite/alternate_lag_suite")

# Import alt lag functions
import alternate_lag_functions as funcs

In [3]:
# Print the files in the directory
# NOTE(review): `dir` shadows Python's built-in dir(); later cells read this
# name, so it is kept here — consider renaming it everywhere (e.g. DATA_DIR).
dir = "/gws/nopw/j04/canari/users/benhutch/nao_stats_df/testing_nao_matching"

# List the files in the directory
# Loop over and print the files (os.listdir returns them in arbitrary order)
for file in os.listdir(dir):
    print(file)

ens_members_psl_DJFM_2-9_1711021423.4131927.csv
ens_members_sfcWind_DJFM_2-9_1711016955.3786013.csv
ens_members_sfcWind_DJFM_2-9_1711017782.0252235.csv
ens_members_sfcWind_ONDJFM_2-9_1711026006.2899933.csv
overlapping_members_1969_1711017114.6193454.csv
overlapping_members_1969_1711017782.3638115.csv
overlapping_members_1969_1711021423.5805624.csv
overlapping_members_1969_1711026006.3853872.csv
overlapping_members_1970_1711017123.3044014.csv
overlapping_members_1970_1711017782.554255.csv
overlapping_members_1970_1711021423.741904.csv
overlapping_members_1970_1711026006.4870756.csv
overlapping_members_1971_1711017123.5466068.csv
overlapping_members_1971_1711017782.8350928.csv
overlapping_members_1971_1711021423.926739.csv
overlapping_members_1971_1711026006.7036777.csv
overlapping_members_1972_1711017123.7911603.csv
overlapping_members_1972_1711017782.9624166.csv
overlapping_members_1972_1711021424.1527839.csv
overlapping_members_1972_1711026006.7757576.csv
overlapping_members_1973_1711

In [13]:
# load in overlapping members 1969
# (the filename hard-codes one run's timestamp; no index_col is passed, so the
# saved index comes back as extra "Unnamed" columns)
overlapping_members_1969 = pd.read_csv(
    "/gws/nopw/j04/canari/users/benhutch/nao_stats_df/testing_nao_matching/overlapping_members_1969_1711021423.5805624.csv"
)

In [11]:
# load in rank_df_1969
# (members ranked by abs_diff for 1969; filename hard-codes the run timestamp)
rank_df_1969 = pd.read_csv(
    "/gws/nopw/j04/canari/users/benhutch/nao_stats_df/testing_nao_matching/rank_df_psl_1969_1711021421.9636939.csv"
)

In [15]:
# Show the first 20 ranked members for 1969 (rows are in ascending abs_diff)
rank_df_1969[:20]

Unnamed: 0.1,Unnamed: 0,ensemble_member,abs_diff,rank
0,576,CESM1-1-CAM5-CMIP5_r34i1p1f1_lag_0,0.985504,573
1,468,IPSL-CM6A-LR_r9i1p1f1_lag_0,2.374329,412
2,559,CESM1-1-CAM5-CMIP5_r2i1p1f1_lag_3,2.981232,154
3,650,NorCPM1_r2i1p1f1_lag_2,5.747742,130
4,250,EC-Earth3_r3i1_lag_2,5.95343,541
5,442,IPSL-CM6A-LR_r2i1p1f1_lag_2,6.549408,88
6,398,MIROC6_r1i1p1f1_lag_2,7.547699,115
7,485,CESM1-1-CAM5-CMIP5_r13i1p1f1_lag_1,8.159821,556
8,5,BCC-CSM2-MR_r2i1p1f1_lag_1,8.340851,442
9,220,HadGEM3-GC31-MM_r7i1_lag_0,8.463043,97


In [14]:
# overlapping members 1969
# (compare with rank_df_1969: the order should follow its ranking)
overlapping_members_1969

Unnamed: 0.1,Unnamed: 0,0
0,0,CESM1-1-CAM5-CMIP5_r34i1p1f1_lag_0
1,1,IPSL-CM6A-LR_r9i1p1f1_lag_0
2,2,CESM1-1-CAM5-CMIP5_r2i1p1f1_lag_3
3,3,NorCPM1_r2i1p1f1_lag_2
4,4,EC-Earth3_r3i1_lag_2
5,5,IPSL-CM6A-LR_r2i1p1f1_lag_2
6,6,MIROC6_r1i1p1f1_lag_2
7,7,CESM1-1-CAM5-CMIP5_r13i1p1f1_lag_1
8,8,BCC-CSM2-MR_r2i1p1f1_lag_1
9,9,HadGEM3-GC31-MM_r7i1_lag_0


In [12]:
# have a look at rank df for 1970
# Find the rank_df file for 1970 in {dir}. Match the delimited year token
# "_1970_" rather than a bare "1970", so that a Unix-timestamp suffix that
# happens to contain those digits cannot produce a false match. Sort the
# candidates so the pick does not depend on os.listdir's arbitrary ordering.
rank_df_1970_files = sorted(
    file for file in os.listdir(dir) if "rank_df" in file and "_1970_" in file
)
# Fail loudly rather than silently reusing a stale file_1970 from an earlier cell
assert rank_df_1970_files, f"No rank_df file for 1970 found in {dir}"

# If several runs wrote a file, take the last one after sorting, i.e. the
# largest timestamp suffix (the visible suffixes all have 10 integer digits)
file_1970 = rank_df_1970_files[-1]

In [13]:
# Import the rank_df_1970 (file_1970 is set by the previous cell)
rank_df_1970 = pd.read_csv(f"{dir}/{file_1970}")

In [14]:
# Find the overlapping_members file for 1970 in {dir}. Match the delimited
# year token "_1970_" rather than a bare "1970", so that a Unix-timestamp
# suffix that happens to contain those digits cannot produce a false match.
# Sort the candidates so the pick does not depend on os.listdir's ordering.
overlap_1970_files = sorted(
    file
    for file in os.listdir(dir)
    if "overlapping_members" in file and "_1970_" in file
)
# Fail loudly rather than silently reusing a stale file_1970 from an earlier cell
assert overlap_1970_files, f"No overlapping_members file for 1970 found in {dir}"

# Several runs share this prefix; take the last after sorting, i.e. the
# largest timestamp suffix (the visible suffixes all have 10 integer digits)
file_1970 = overlap_1970_files[-1]

# Import the overlapping_members_1970
overlapping_members_1970 = pd.read_csv(f"{dir}/{file_1970}")

In [15]:
# Inspect the 1970 ranking (members listed in ascending abs_diff)
rank_df_1970

Unnamed: 0.1,Unnamed: 0,ensemble_member,abs_diff,rank
0,9,BCC-CSM2-MR_r3i1p1f1_lag_1,6.828125,20
1,5,BCC-CSM2-MR_r2i1p1f1_lag_1,13.248253,8
2,22,BCC-CSM2-MR_r6i1p1f1_lag_2,28.711838,29
3,15,BCC-CSM2-MR_r4i1p1f1_lag_3,46.729904,0
4,2,BCC-CSM2-MR_r1i1p1f1_lag_2,49.776154,22
5,8,BCC-CSM2-MR_r3i1p1f1_lag_0,56.000854,4
6,21,BCC-CSM2-MR_r6i1p1f1_lag_1,78.600784,14
7,12,BCC-CSM2-MR_r4i1p1f1_lag_0,82.41623,11
8,4,BCC-CSM2-MR_r2i1p1f1_lag_0,84.615486,2
9,20,BCC-CSM2-MR_r6i1p1f1_lag_0,89.57144,6


In [16]:
# Inspect the 1970 overlapping members (compare with rank_df_1970 above)
overlapping_members_1970

Unnamed: 0.1,Unnamed: 0,0
0,0,BCC-CSM2-MR_r3i1p1f1_lag_1
1,1,BCC-CSM2-MR_r2i1p1f1_lag_1
2,2,BCC-CSM2-MR_r6i1p1f1_lag_2
3,3,BCC-CSM2-MR_r4i1p1f1_lag_3
4,4,BCC-CSM2-MR_r1i1p1f1_lag_2
5,5,BCC-CSM2-MR_r3i1p1f1_lag_0
6,6,BCC-CSM2-MR_r6i1p1f1_lag_1
7,7,BCC-CSM2-MR_r4i1p1f1_lag_0
8,8,BCC-CSM2-MR_r2i1p1f1_lag_0
9,9,BCC-CSM2-MR_r6i1p1f1_lag_0


In [4]:
# Load in ens members
# (psl DJFM 2-9 member list; the filename hard-codes one run's timestamp)
ens_members = pd.read_csv(f"{dir}/ens_members_psl_DJFM_2-9_1711021423.4131927.csv")

In [6]:
# Inspect the ensemble-member list
ens_members

Unnamed: 0.1,Unnamed: 0,0
0,0,BCC-CSM2-MR_r2i1p1f1_lag_0
1,1,BCC-CSM2-MR_r5i1p1f1_lag_0
2,2,BCC-CSM2-MR_r8i1p1f1_lag_0
3,3,BCC-CSM2-MR_r1i1p1f1_lag_0
4,4,BCC-CSM2-MR_r3i1p1f1_lag_0
...,...,...
707,707,NorCPM1_r7i1p1f1_lag_3
708,708,NorCPM1_r2i2p1f1_lag_3
709,709,NorCPM1_r5i2p1f1_lag_3
710,710,NorCPM1_r6i1p1f1_lag_3


In [None]:
%%time

# Test the NAO function
# Returns the observed and model NAO indices. Later cells treat obs_nao as a
# 1-D time series and model_nao["psl"] as having "time" and
# "ensemble_member" dimensions (TODO confirm against calculate_nao_index).
obs_nao, model_nao = funcs.calculate_nao_index(
    season="ONDJFM",
    forecast_range="2-9",
    start_year=1961,
    end_year=2014,
    models_list=["BCC-CSM2-MR"],
    plot=False,
) # full 1961-2014 period, single model (BCC-CSM2-MR) to keep the test fast


In [None]:
# Inspect the observed NAO index
obs_nao

In [None]:
# Inspect the model NAO before trimming
model_nao

In [None]:
# remove the first lag - 1 time steps from the model data
# NOTE: not idempotent — re-running this cell trims another lag - 1 steps.
lag = 4

model_nao = model_nao.isel(time=slice(lag - 1, None))

In [None]:
# Check the start-trimmed model data
model_nao

In [None]:
# Remove the final lag - 1 time steps from the end of the model data.
# Use an explicit stop index: `slice(None, -(lag - 1))` becomes
# `slice(None, 0)` when lag == 1 and would drop every time step instead of none.
# NOTE: not idempotent — re-running this cell trims another lag - 1 steps.
model_nao = model_nao.isel(
    time=slice(None, max(model_nao.sizes["time"] - (lag - 1), 0))
)

In [None]:
# Check the end-trimmed model data
model_nao

In [None]:
# Remove the first lag - 1 time steps from the obs data
# NOTE(review): the model was trimmed at both ends but obs only at the start,
# so the two series can differ in length. The shape prints before the
# correlation show whether they line up. Not idempotent on re-run.
obs_nao = obs_nao.isel(time=slice(lag - 1, None))

In [None]:
# Inspect the trimmed obs
obs_nao

In [None]:
# Inspect the trimmed model data
model_nao

In [None]:
# Calculate the correlation between the model and obs NAO index
# NOTE(review): mid-notebook import; consider moving it to the top import cell
from scipy.stats import pearsonr

# pearsonr needs both series to have the same length, so check the shapes
print(obs_nao.shape)
print(model_nao["psl"].shape)

# Set up the ensemble mean of the model
model_nao_ens_mean = model_nao["psl"].mean(dim="ensemble_member")

In [None]:
%%time

# Calculate the correlation
# (pairing is positional: element k of obs_nao is paired with element k of
# the ensemble mean; time coordinates are not aligned here)
corr, _ = pearsonr(obs_nao, model_nao_ens_mean)

In [None]:
# Calculate the standard deviation of the ensemble mean (forecast signal)
sig_f_sig = np.std(model_nao_ens_mean)

# Calculate the standard deviation of the ensemble
# (over every member and time step together, i.e. total forecast spread)
sig_f_tot = np.std(model_nao["psl"])

# Calculate the standard deviation of the observations
sig_o_tot = np.std(obs_nao)

In [None]:
# Calculate the rpc: correlation divided by the forecast signal-to-total
# standard-deviation ratio
rpc = corr / (sig_f_sig / sig_f_tot)

# Calculate the rps: rpc rescaled by the obs-to-forecast total spread ratio
rps = rpc * (sig_o_tot / sig_f_tot)

In [None]:
# Scale the ensemble mean nao by the rps
# NOTE: not idempotent — re-running this cell multiplies by rps again.
model_nao_ens_mean = model_nao_ens_mean * rps

In [None]:
import matplotlib.pyplot as plt

# Compare the observed NAO with the rps-scaled model ensemble mean
fig, ax = plt.subplots()

# Plot the obs NAO
obs_nao.plot(ax=ax, label="Obs")

# Plot the model NAO (ensemble mean after scaling by rps)
model_nao_ens_mean.plot(ax=ax, label="Model")

# Label the figure so it stands alone when the notebook is skimmed
ax.set_title("Observed vs. rps-scaled model ensemble-mean NAO")
ax.set_xlabel("Time")
ax.set_ylabel("NAO index")

# Add a legend (trailing ";" suppresses the Legend object repr)
ax.legend();

The NAO looks approximately correct, although the model and observed years still need to be aligned. Next, for each year, we find the ensemble members whose NAO values are closest to the signal-adjusted ensemble-mean NAO.

In [None]:
%%time

# For each year, calculate the absolute difference between the NAO of each
# of the ensemble members and the signal adjusted ensemble mean
# Then we want to create an ascending list of the ensemble members and their
# differences from the signal adjusted ensemble mean, from smallest to largest
# Then we want to calculate the rank of each ensemble member in this list

# Calculate the absolute difference between the NAO of each of the ensemble
# members and the signal adjusted ensemble mean
# Set up the years
years = model_nao['time.year'].values

# Extract the values for the data
model_nao_values = model_nao['psl'].values

# Extract the model

# limit years for testing
years = years[:10]

# Set up a list to store the ranks
rank_list = {year: [] for year in years}

# Loop over the years
for i, year in tqdm(enumerate(years)):
    # Get the ensemble members for the year
    year_ens = model_nao_values[i, :]

    # print the types
    print(type(year_ens))
    print(type(model_nao_ens_mean[i].values))

    # Calculate the absolute difference
    abs_diff = np.abs(year_ens - model_nao_ens_mean[i].values)

    # Create a dataframe
    df = pd.DataFrame(
        {
            "ensemble_member": model_nao['ensemble_member'].values,
            "abs_diff": abs_diff,
        }
    )

    # Sort the dataframe by the absolute difference
    df = df.sort_values(by="abs_diff")

    # Add the dataframe to the rank list
    rank_list[year] = df


In [None]:
# Inspect the ranked members for 1974 (KeyError if 1974 is not a ranked year)
rank_list[1974]

#### Matching for the variable ####

Next, we pick a variable to match, e.g. sfcWind. For each year, we want the 20 best-ranked NAO members (those with the smallest absolute difference from the signal-adjusted ensemble mean) that are also present in the sfcWind ensemble. First, we assemble the list of ensemble members available for this variable.

In [None]:
# Define a function for finding the ensemble members for a given variable
def find_ens_members(
    variable: str,
    models_list: list,
    season: str,
    forecast_range: str,
    start_year: int = 1961,
    end_year: int = 2014,
    lag: int = 4,
    region: str = "global",
    base_dir: str = "/gws/nopw/j04/canari/users/benhutch/skill-maps-processed-data",
):
    """
    Forms a list of all of the ensemble members for a given variable.

    Members are discovered from the anomaly files initialised in ``start_year``
    found under
    ``{base_dir}/{variable}/{model}/{region}/{forecast_range}/{season}/outputs/``.
    The member ID is read from the fifth "_"-separated token of each filename,
    which is expected to look like ``s{start_year}-{member}``.

    Args:
            variable (str): The variable to find the ensemble members for.
            models_list (list): A list of the models to find the ensemble members for.
            season (str): The season of the processed data (e.g. "ONDJFM").
            forecast_range (str): The forecast range of the processed data (e.g. "2-9").
            start_year (int): The start year of the data.
            end_year (int): The end year of the data.
            lag (int): Number of lag indices. When not None, each member is
                listed once per lag index as ``{model}_{member}_lag_{lag_idx}``
                for lag_idx in 0..lag-1. When None, members are listed as
                ``{model}_{member}``.
            region (str): The region to find the ensemble members for.
            base_dir (str): The base directory of the data.

    Returns:
            ens_members (list): A list of the ensemble members for the given
            variable. Within each model (and lag index) the members are sorted,
            so the result is reproducible across interpreter runs.
    """
    ens_members = []

    for model in tqdm(models_list):
        fstem = (
            f"{base_dir}/{variable}/{model}/{region}/{forecast_range}/{season}/outputs/"
        )
        pattern = f"*s{start_year}-*years_{forecast_range}_start_{start_year}_end_{end_year}_anoms.nc"

        files = glob.glob(f"{fstem}{pattern}")

        # Token 4 of the basename is expected to be "s{start_year}-{member}";
        # the member ID is the part after its first "-".
        members = [
            os.path.basename(path).split("_")[4].split("-")[1] for path in files
        ]

        # Deduplicate and sort: iterating a bare set() yields an order that can
        # change between interpreter runs (string hash randomisation).
        unique_members = sorted(set(members))

        if lag is not None:
            # Lag index is the outer loop, so all lag_0 members come first
            new_members = [
                f"{model}_{member}_lag_{lag_idx}"
                for lag_idx in range(lag)
                for member in unique_members
            ]
        else:
            # Deduplicated too, so a member found in several files appears once
            new_members = [f"{model}_{member}" for member in unique_members]

        ens_members.extend(new_members)

    return ens_members

In [None]:
# Import dictionaries
sys.path.append("/home/users/benhutch/lagging-NAO-test-suite/")

# Import dictionaries
import dictionaries as dicts

In [None]:
# Build the sfcWind ensemble-member list for the ONDJFM 2-9 forecasts,
# drawing the models from the sfcWind model list in the dictionaries module
sfcwind_ens = find_ens_members(
    variable="sfcWind",
    season="ONDJFM",
    forecast_range="2-9",
    region="global",
    models_list=dicts.sfcWind_models,
    start_year=1961,
    end_year=2014,
    base_dir="/gws/nopw/j04/canari/users/benhutch/skill-maps-processed-data",
)

In [None]:
# Number of sfcWind ensemble entries (one per member and lag index, across models)
len(sfcwind_ens)

In [None]:
# Write a function which finds the overlapping members
# between rank_list and an ensemble-member list and returns,
# for each year, the best-ranked overlapping members
def find_overlapping_members(
    rank_list: dict,
    ens_mems: list,
    no_members: int = 20,
):
    """
    Finds the overlapping members between the rank list and the given ensemble members.

    Args:
        rank_list (dict): Maps year -> DataFrame with at least the columns
            "ensemble_member" and "abs_diff" (absolute difference from the
            signal-adjusted ensemble-mean NAO).
        ens_mems (list): The ensemble members available for the variable.
        no_members (int): Maximum number of members to return per year.

    Returns:
        overlapping_members (dict): Maps each year of ``rank_list`` (same order)
        to a list of at most ``no_members`` members that appear in
        ``ens_mems``, ordered from the smallest to the largest ``abs_diff``.
    """
    # Set membership is O(1), versus a linear scan of the list for every member
    available = set(ens_mems)

    overlapping_members = {}
    for year, year_df in rank_list.items():
        # Ascending abs_diff: members closest to the adjusted ensemble mean first
        ranked = year_df.sort_values(by="abs_diff")["ensemble_member"]

        selected = []
        for member in ranked:
            # Stop once the requested number of members has been collected
            if len(selected) >= no_members:
                break
            if member in available:
                selected.append(member)

        overlapping_members[year] = selected

    return overlapping_members

Now that we have the sfcWind ensemble list, we find, for each year, the members of the NAO `rank_list` that also appear in this list. These are the members we then need to extract.

In [None]:
# Test the new function
# NOTE(review): "ovelap_mem" is a typo for overlap_mem, kept because later
# cells refer to this name
ovelap_mem = find_overlapping_members(rank_list, sfcwind_ens)

In [None]:
# Overlapping members selected for 1974
ovelap_mem[1974]

Now that we have the overlapping members for each year, we can locate their anomaly (`anoms`) files and extract the data. The function below does this.

In [None]:
# Define helpers and a function for extracting the NAO-matched members of a variable


def _matched_file_pattern(
    variable: str,
    season: str,
    forecast_range: str,
    init_year: int,
    member_id: str,
    start_year: int,
    end_year: int,
) -> str:
    """Return the glob filename pattern of one member's anomaly file."""
    if season in ["DJF", "DJFM", "ONDJFM"]:
        return (
            f"all-years*_s{init_year}-{member_id}*years_{forecast_range}"
            f"_start_{start_year}_end_{end_year}_anoms.nc"
        )

    # Other seasons are stored under a forecast range shifted forward by one
    # year, e.g. "2-9" -> "3-10". The parts must be ints before adding 1.
    first_digit, last_digit = (int(part) for part in forecast_range.split("-"))
    forecast_range_sum = f"{first_digit + 1}-{last_digit + 1}"
    return (
        f"{variable}_s{init_year}-{member_id}*years_{forecast_range_sum}"
        f"_start_{start_year}_end_{end_year}_anoms.nc"
    )


def _forecast_year_indices(forecast_range: str, model: str) -> tuple:
    """Return (start_idx, end_idx) positions into a file's sorted unique years."""
    if "-" in forecast_range:
        first, last = (int(part) for part in forecast_range.split("-"))
    else:
        first = last = int(forecast_range)

    # BCC-CSM2-MR files need different indexing from the other models
    if model == "BCC-CSM2-MR":
        return first, last + 1  # jan of this year
    return first - 1, last  # (last - 1) + 1: jan of this year


def _member_window_mean(
    fpath: str,
    variable: str,
    forecast_range: str,
    model: str,
    lag_idx: int,
) -> np.ndarray:
    """Open one member's anomaly file and return its time mean over the forecast window."""
    start_idx, end_idx = _forecast_year_indices(forecast_range, model)

    # The context manager closes the underlying file once the mean is computed
    with xr.open_dataset(
        fpath,
        chunks={"time": "auto", "lat": "auto", "lon": "auto"},
        engine="netcdf4",
    ) as ds:
        unique_years = np.unique(ds.time.dt.year.values)
        first_year = int(unique_years[start_idx])
        last_year = int(unique_years[end_idx])

        if forecast_range == "2-9":
            start_date = f"{first_year}-01-01"
            end_date = f"{last_year}-01-01"
        elif forecast_range == "2-5":
            # FIXME: Check this. The window is shifted forward by the lag index;
            # for lag 0 it spans first_year..last_year like every other lag.
            start_date = f"{first_year + lag_idx}-01-01"
            end_date = f"{last_year + lag_idx}-01-01"
        else:
            raise ValueError(f"Forecast range not recognised: {forecast_range}")

        # .values forces the (dask-backed) computation while the file is still open
        return ds[variable].sel(time=slice(start_date, end_date)).mean("time").values


def find_matched_members(
    overlap_mem: dict,
    variable: str,
    forecast_range: str,
    season: str,
    n_matched_mems=20,
    start_year: int = 1961,
    end_year: int = 2014,
    lag: int = 4,
    region: str = "global",
    alt_lag: bool = False,
    base_dir: str = "/gws/nopw/j04/canari/users/benhutch/skill-maps-processed-data",
    verbose: bool = False,
):
    """
    Extracts the fields of the NAO-matched ensemble members of a variable.

    The overlapping members are those closest to the signal adjusted ensemble
    mean NAO that are also present in the matched variable's ensemble. For each
    year key of ``overlap_mem`` and each of its members (in rank order), the
    member's anomaly file is opened and averaged over the forecast window.

    Args:
            overlap_mem (dict): Maps year -> list of "{model}_{member}_lag_{k}"
                names, best-ranked first.
            variable (str): The variable to find the matched member for.
            forecast_range (str): The forecast range of the data ("2-9" or "2-5").
            season (str): The season of the data.
            n_matched_mems (int): Number of member slots per year; members past
                this many are ignored.
            start_year (int): The start year of the data.
            end_year (int): The end year of the data.
            lag (int): Currently unused; kept for interface compatibility.
            region (str): The region of the data.
            alt_lag (bool): Currently unused; the alternate lag suite is not implemented.
            base_dir (str): The base directory of the data.
            verbose (bool): Print per-year / per-member progress details.

    Returns:
            matched_mem_arr (np.ndarray): Shape (n_matched_mems, n_years, n_lat,
            n_lon). Entry [j, i] holds the window-mean field of the j-th
            overlapping member for the i-th year key. Slots without a member
            (years with fewer overlaps) are NaN.

    Raises:
            NotImplementedError: for single-year forecast ranges ("1", "2").
            ValueError: for an unrecognised forecast range, empty input,
                ambiguous files, or a member grid that differs from the first file.
            FileNotFoundError: if a required anomaly file cannot be found.
    """
    # Single-year forecast ranges need the alternate lag suite, which is not implemented
    if forecast_range in ["1", "2"]:
        raise NotImplementedError("Alternate lag suite not yet implemented")

    # Offset from the year keys of overlap_mem back to the lag-0 initialisation year
    if forecast_range == "2-9":
        df_offset = 5
    elif forecast_range == "2-5":
        # FIXME: Check this
        df_offset = 3
    else:
        raise ValueError(f"Forecast range not recognised: {forecast_range}")

    rank_years = list(overlap_mem.keys())
    if not rank_years:
        raise ValueError("overlap_mem is empty: no years to extract")

    def outputs_dir(model: str) -> str:
        return f"{base_dir}/{variable}/{model}/{region}/{forecast_range}/{season}/outputs/"

    # Read the lat/lon grid from one file of the first year (a lag-0 member when
    # available) to size the output array. Assumes every member shares this
    # grid; a mismatch is caught when the member fields are stored below.
    first_members = overlap_mem[rank_years[0]]
    if not first_members:
        raise ValueError(f"No overlapping members for year {rank_years[0]}")
    lag0_members = [m for m in first_members if "lag_0" in m]
    grid_member = (lag0_members or first_members)[0]
    grid_model, grid_member_id = grid_member.split("_")[:2]
    grid_files = sorted(
        glob.glob(
            outputs_dir(grid_model)
            + _matched_file_pattern(
                variable, season, forecast_range, rank_years[0],
                grid_member_id, start_year, end_year,
            )
        )
    )
    if not grid_files:
        raise FileNotFoundError(f"No file found to read the grid for {grid_member}")
    with xr.open_dataset(
        grid_files[0],
        chunks={"time": "auto", "lat": "auto", "lon": "auto"},
        engine="netcdf4",
    ) as first_ds:
        lats = first_ds["lat"].values
        lons = first_ds["lon"].values

    # NaN (not zero) so empty slots cannot masquerade as real anomalies
    matched_mem_arr = np.full(
        (n_matched_mems, len(rank_years), len(lats), len(lons)), np.nan
    )

    for i, rank_year in enumerate(
        tqdm(rank_years, desc="Extracting NAO matched members for year: ")
    ):
        # Lag-0 initialisation year for this year key
        year = rank_year - df_offset
        if verbose:
            print("year: ", year)

        # Members are best-ranked first; keep at most n_matched_mems of them
        members = overlap_mem[rank_year][:n_matched_mems]
        if len(members) < n_matched_mems:
            print(
                f"Warning: only {len(members)} matched members for {rank_year}; "
                "remaining slots are NaN"
            )

        for j, member in enumerate(members):
            parts = member.split("_")
            model, member_id, lag_idx = parts[0], parts[1], int(parts[3])
            if verbose:
                print("model: ", model, "member: ", member_id, "lag_idx: ", lag_idx)

            # E.g. for 1964: lag 0 -> s1964, lag 1 -> s1963
            init_year = year - lag_idx

            pattern = _matched_file_pattern(
                variable, season, forecast_range, init_year,
                member_id, start_year, end_year,
            )
            files = glob.glob(outputs_dir(model) + pattern)
            if not files:
                raise FileNotFoundError(f"No file matching {pattern} for {member}")
            if len(files) > 1:
                raise ValueError(
                    f"Expected 1 file matching {pattern} for {member}, found {len(files)}"
                )

            if verbose:
                print(f"Loading file: {files[0].split('/')[-1]}")

            vals = _member_window_mean(files[0], variable, forecast_range, model, lag_idx)
            if vals.shape != matched_mem_arr.shape[2:]:
                raise ValueError(
                    f"Grid of {member} {vals.shape} differs from the first file "
                    f"{matched_mem_arr.shape[2:]}"
                )
            matched_mem_arr[j, i, :, :] = vals

    return matched_mem_arr

In [None]:
# Extract the NAO-matched sfcWind members for every year in ovelap_mem
matched_sfcwind_arr = find_matched_members(
    overlap_mem=ovelap_mem,
    variable="sfcWind",
    forecast_range="2-9",
    season="ONDJFM",
    n_matched_mems=20,
    start_year=1961,
    end_year=2014,
    lag=4,
    region="global",
    alt_lag=False,
    base_dir="/gws/nopw/j04/canari/users/benhutch/skill-maps-processed-data",
)

In [None]:
# Expected shape: (n_matched_mems, n_years, n_lat, n_lon)
print(matched_sfcwind_arr.shape)

In [None]:
# Print the matched values
print(matched_sfcwind_arr)