### Testing functions for NAO matching ###

In [None]:
# Load autoreload extension
%load_ext autoreload
%autoreload 2

# Import standard library modules
import sys
import os
import glob
import re

# Importing third party modules
import pandas as pd
import numpy as np
import xarray as xr
from tqdm import tqdm

In [None]:
# Make the local alternate_lag_suite directory importable
sys.path.append('/home/users/benhutch/lagging-NAO-test-suite/alternate_lag_suite')

# Import alt lag functions
import alternate_lag_functions as funcs

In [None]:
# Call the NAO index function for season ONDJFM and forecast years 2-9.
# The return value is kept as `paths`; the cells below slice it and
# iterate over it.
paths = funcs.calculate_nao_index(
    season="ONDJFM",
    forecast_range="2-9",
)


In [None]:
## Define a function for preprocessing the model data
def preprocess(ds: "xr.Dataset",
               forecast_range: str,
               filenames: list,
               lag: int,) -> "xr.Dataset":
    """
    Preprocess one model dataset; used as the ``preprocess`` callback of
    ``xr.open_mfdataset``.

    The dataset is given a length-one ``ensemble_member`` dimension labelled
    ``"<model>_<variant>_lag_<lag>"`` (parsed from ``filenames``), averaged
    over a window of years selected by ``forecast_range`` and ``lag``, and
    stamped with a scalar ``time`` equal to the centre of that window shifted
    forward by ``lag`` years.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset whose ``time`` coordinate supports ``.dt.year``.
    forecast_range : str
        Either ``"2-9"`` or ``"2-5"``.
    filenames : list
        File paths whose basename, split on ``"_"``, holds the model name at
        index 2 and a ``"s<year>-<variant>"`` token at index 4. A single path
        string is also accepted. When several distinct model names or variant
        labels are present, the first in sorted order is used.
    lag : int
        Lag in years. For ``"2-9"`` it only shifts the output time stamp; for
        ``"2-5"`` it also shifts the averaging window.

    Returns
    -------
    xr.Dataset
        The time-averaged dataset with ``ensemble_member`` and ``time`` set.

    Raises
    ------
    ValueError
        If ``forecast_range`` is not ``"2-9"`` or ``"2-5"``.
    """

    # Validate up front with an explicit raise: an `assert` is stripped under
    # `python -O`, which would leave the date window below undefined.
    if forecast_range not in ("2-9", "2-5"):
        raise ValueError("Forecast range not recognised")

    # A bare string would otherwise be iterated character by character
    if isinstance(filenames, str):
        filenames = [filenames]

    # Example of the expected path layout:
    # /gws/nopw/j04/canari/users/benhutch/skill-maps-processed-data/psl/BCC-CSM2-MR/global/2-9/ONDJFM/outputs/all-years-ONDJFM-global-psl_Amon_BCC-CSM2-MR_dcppA-hindcast_s1961-r1i1p1f1_gn_196101-197012_years_2-9_start_1961_end_2014_anoms.nc

    # Add the ensemble member dimension
    ds = ds.expand_dims('ensemble_member')

    # Strip the directories, then pick the model name and variant label out
    # of the underscore-separated basename
    basenames = [file.split("/")[-1] for file in filenames]
    model_names = [name.split("_")[2] for name in basenames]
    variant_labels = [name.split("_")[4].split("-")[1] for name in basenames]

    # np.unique sorts, so index 0 is the first name/label in sorted order
    model_name = np.unique(model_names)[0]
    variant_label = np.unique(variant_labels)[0]

    # Label the single ensemble member
    ds['ensemble_member'] = [f"{model_name}_{variant_label}_lag_{lag}"]

    # Distinct calendar years present in the data, in ascending order
    unique_years = np.unique(ds.time.dt.year.values)

    # Forecast year N maps to unique_years[N - 2], so year 2 is the first
    # calendar year in the data
    bounds = forecast_range.split("-")
    start_year_idx = int(bounds[0])
    end_year_idx = int(bounds[-1])
    first_year = int(unique_years[start_year_idx - 2])
    last_year = int(unique_years[end_year_idx - 2])

    # Form the strings for the start and end dates of the averaging window
    if forecast_range == "2-9":
        start_date = f"{first_year}-01-01"
        end_date = f"{last_year + 1}-01-01"
    elif lag == 0:
        # NOTE(review): for lag 0 the window ends at first_year + 1 (a single
        # year), whereas non-zero lags span first_year..last_year. Kept as in
        # the original - confirm whether last_year + 1 was intended.
        start_date = f"{first_year}-01-01"
        end_date = f"{first_year + 1}-01-01"
    else:
        start_date = f"{first_year + lag}-01-01"
        end_date = f"{last_year + lag + 1}-01-01"

    # Find the centre of the period between start and end date
    window_start = pd.to_datetime(start_date)
    window_end = pd.to_datetime(end_date)
    mid_date = window_start + (window_end - window_start) / 2

    # Take the mean over the time dimension within the window
    ds = ds.sel(time=slice(start_date, end_date)).mean(dim='time')

    # Stamp the result with the window centre shifted by the lag; a zero-year
    # offset leaves the date unchanged, so lag 0 needs no special case
    ds['time'] = mid_date + pd.DateOffset(years=lag)

    # Return the dataset
    return ds

In [None]:
# Keep only the first 50 entries of paths
paths = paths[:50]

In [None]:
# Collect one dataset per (entry of paths, lag) combination
bcc_test = []

# Keyword arguments shared by every open_mfdataset call below
open_options = dict(
    combine='nested',
    concat_dim='time',
    join='override',
    coords='minimal',
    engine='netcdf4',
    parallel=True,
)

# Outer loop: entries of paths (with a progress bar); inner loop: lags 0..3
for path in tqdm(paths):
    for k in range(4):
        # Open the files for this entry, running preprocess on each file
        # with the current lag and forecast range "2-9"
        ds = xr.open_mfdataset(
            path,
            preprocess=lambda member_ds: preprocess(
                member_ds,
                forecast_range="2-9",
                filenames=path,
                lag=k,
            ),
            **open_options,
        )

        # Store the opened dataset
        bcc_test.append(ds)

# Join all collected datasets along the ensemble member dimension
bcc_test = xr.concat(bcc_test, dim='ensemble_member')

In [None]:
# Display the concatenated dataset as the cell output
bcc_test