### Testing functions for NAO matching ###

In [16]:
# Load autoreload extension
%load_ext autoreload
%autoreload 2

# Import local modules
import sys
import os
import glob
import re

# Importing third party modules
import pandas as pd
import numpy as np
import xarray as xr

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Make the local alternate_lag_suite modules importable. The directory can be
# overridden with the ALT_LAG_SUITE_DIR environment variable; by default it is
# the original hard-coded location.
ALT_LAG_SUITE_DIR = os.environ.get(
    "ALT_LAG_SUITE_DIR",
    "/home/users/benhutch/lagging-NAO-test-suite/alternate_lag_suite",
)

# Append only once so that re-running this cell does not keep growing sys.path
if ALT_LAG_SUITE_DIR not in sys.path:
    sys.path.append(ALT_LAG_SUITE_DIR)

# Import alt lag functions
import alternate_lag_functions as funcs

In [10]:
# Exercise the NAO index function for the ONDJFM season, forecast years 2-9.
# NOTE(review): the result is used below as a list whose entries are passed
# to open_mfdataset / preprocess as per-member file lists — confirm against
# calculate_nao_index.
paths = funcs.calculate_nao_index(
    season="ONDJFM",
    forecast_range="2-9",
)




Time dimension of obs: ['1960-12-31T00:00:00.000000000' '1961-12-31T00:00:00.000000000'
 '1962-12-31T00:00:00.000000000' '1963-12-31T00:00:00.000000000'
 '1964-12-31T00:00:00.000000000' '1965-12-31T00:00:00.000000000'
 '1966-12-31T00:00:00.000000000' '1967-12-31T00:00:00.000000000'
 '1968-12-31T00:00:00.000000000' '1969-12-31T00:00:00.000000000'
 '1970-12-31T00:00:00.000000000' '1971-12-31T00:00:00.000000000'
 '1972-12-31T00:00:00.000000000' '1973-12-31T00:00:00.000000000'
 '1974-12-31T00:00:00.000000000' '1975-12-31T00:00:00.000000000'
 '1976-12-31T00:00:00.000000000' '1977-12-31T00:00:00.000000000'
 '1978-12-31T00:00:00.000000000' '1979-12-31T00:00:00.000000000'
 '1980-12-31T00:00:00.000000000' '1981-12-31T00:00:00.000000000'
 '1982-12-31T00:00:00.000000000' '1983-12-31T00:00:00.000000000'
 '1984-12-31T00:00:00.000000000' '1985-12-31T00:00:00.000000000'
 '1986-12-31T00:00:00.000000000' '1987-12-31T00:00:00.000000000'
 '1988-12-31T00:00:00.000000000' '1989-12-31T00:00:00.000000000'
 '

In [22]:
## Define a function for preprocessing the model data
def _ensemble_member_label(filenames) -> str:
    """Return "<model>_<variant>" parsed from the basenames in ``filenames``.

    Each basename is split on "_": field 2 is taken as the model name, and
    the part of field 4 after its first "-" as the variant label. For example,
    ``..._BCC-CSM2-MR_dcppA-hindcast_s1961-r1i1p1f1_...`` gives
    ``"BCC-CSM2-MR_r1i1p1f1"``.

    Raises
    ------
    ValueError
        If ``filenames`` is empty, or the files disagree on model name or
        variant label.
    """
    # A bare string would otherwise be iterated character by character
    if isinstance(filenames, str):
        filenames = [filenames]

    basenames = [file.split("/")[-1] for file in filenames]
    if not basenames:
        raise ValueError("preprocess: no filenames given to build the "
                         "ensemble member label from")

    model_names = sorted({name.split("_")[2] for name in basenames})
    variant_labels = sorted({name.split("_")[4].split("-")[1]
                             for name in basenames})

    # All files of one member must share one label; reject mixed inputs
    # rather than silently picking one of the labels.
    if len(model_names) != 1 or len(variant_labels) != 1:
        raise ValueError(f"preprocess: filenames mix model names {model_names} "
                         f"or variant labels {variant_labels}")

    return f"{model_names[0]}_{variant_labels[0]}"


def _forecast_year_bounds(unique_years, forecast_range: str):
    """Map ``forecast_range`` ("A-B" or "A") to ``(first_year, last_year)``.

    Forecast year N is looked up at index N - 2 of ``unique_years``, i.e. the
    first year present in the data is treated as forecast year 2.

    Raises
    ------
    ValueError
        If the requested forecast years are not all present, or start > end.
    """
    parts = forecast_range.split("-")
    start_fy = int(parts[0])
    end_fy = int(parts[1]) if len(parts) > 1 else start_fy

    first_idx, last_idx = start_fy - 2, end_fy - 2
    # Check explicitly: a negative index (forecast year < 2) would otherwise
    # wrap to the end of the array and silently select the wrong years.
    if not 0 <= first_idx <= last_idx < len(unique_years):
        raise ValueError(
            f"preprocess: forecast_range {forecast_range!r} is not within the "
            f"available forecast years 2-{len(unique_years) + 1}, "
            f"or its start is after its end"
        )
    return int(unique_years[first_idx]), int(unique_years[last_idx])


def preprocess(ds: xr.Dataset,
               forecast_range: str,
               filenames: list,) -> xr.Dataset:
    """
    Preprocess a single model file for ``xr.open_mfdataset``.

    The function tags the dataset with one ``ensemble_member`` label built
    from the model name and variant label in ``filenames``. It then averages
    the time steps covering ``forecast_range``. Finally, it stamps the result
    with the mid-point date of that window as a scalar ``time`` coordinate,
    so the per-file results can be concatenated along ``time``.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset for one file, as passed in by ``open_mfdataset``.
    forecast_range : str
        Forecast years to average, e.g. ``"2-9"``, or a single year such as
        ``"2"``. The first year present in ``ds`` is treated as forecast year 2.
    filenames : list of str (a single str is also accepted)
        Paths of the files belonging to this ensemble member. Only the
        basenames are parsed; see the example path below for the layout.

    Returns
    -------
    xr.Dataset
        Time-mean dataset with a length-1 ``ensemble_member`` dimension and a
        scalar ``time`` coordinate.

    Raises
    ------
    ValueError
        If ``filenames`` is empty or mixes model/variant labels, or if
        ``forecast_range`` selects years not present in ``ds``.
    """
    # Example input file:
    # /gws/nopw/j04/canari/users/benhutch/skill-maps-processed-data/psl/BCC-CSM2-MR/global/2-9/ONDJFM/outputs/all-years-ONDJFM-global-psl_Amon_BCC-CSM2-MR_dcppA-hindcast_s1961-r1i1p1f1_gn_196101-197012_years_2-9_start_1961_end_2014_anoms.nc

    # Add a length-1 ensemble_member dimension labelled "<model>_<variant>"
    ds = ds.expand_dims('ensemble_member')
    ds['ensemble_member'] = [_ensemble_member_label(filenames)]

    # Translate the forecast range into calendar years present in this file
    unique_years = np.unique(ds.time.dt.year.values)
    first_year, last_year = _forecast_year_bounds(unique_years, forecast_range)

    # NOTE(review): the window ends on 30 January of the year after last_year.
    # If the time stamps fall in early January, this also captures one extra
    # year — confirm against the files' time stamps.
    start_date = f"{first_year}-01-01"
    end_date = f"{last_year + 1}-01-30"

    # Mid-point of the averaging window, used as this file's time stamp
    start_ts = pd.to_datetime(start_date)
    end_ts = pd.to_datetime(end_date)
    mid_date = start_ts + (end_ts - start_ts) / 2

    # Average over the window, then replace the time axis with a scalar coord
    ds = ds.sel(time=slice(start_date, end_date)).mean(dim='time')
    ds['time'] = mid_date

    return ds

In [18]:
# Limit the test to the first N_TEST_MEMBERS entries of paths
N_TEST_MEMBERS = 7
paths = paths[:N_TEST_MEMBERS]

In [23]:
# Forecast years averaged by preprocess for every file
FORECAST_RANGE = "2-9"

# Open each entry of paths (the files for one ensemble member) as a single
# dataset; preprocess reduces each file to one time step, and the files are
# concatenated along time.
member_datasets = []
for path in paths:
    ds = xr.open_mfdataset(
        path,
        # Bind path as a default argument so the callback never depends on
        # the loop variable's value at call time (Python closures bind late).
        preprocess=lambda file_ds, files=path: preprocess(
            file_ds, forecast_range=FORECAST_RANGE, filenames=files),
        combine='nested',
        concat_dim='time',
        join='override',
        coords='minimal',
        engine='netcdf4',
        parallel=True,
    )
    member_datasets.append(ds)

# Stack the members into one dataset along ensemble_member
bcc_test = xr.concat(member_datasets, dim='ensemble_member')

In [24]:
# Display the concatenated dataset. The saved output shows a
# (54, 7, 72, 144) float32 array backed by dask (lazy, not yet computed).
bcc_test

Unnamed: 0,Array,Chunk
Bytes,14.95 MiB,40.50 kiB
Shape,"(54, 7, 72, 144)","(1, 1, 72, 144)"
Dask graph,378 chunks in 2654 graph layers,378 chunks in 2654 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 14.95 MiB 40.50 kiB Shape (54, 7, 72, 144) (1, 1, 72, 144) Dask graph 378 chunks in 2654 graph layers Data type float32 numpy.ndarray",54  1  144  72  7,

Unnamed: 0,Array,Chunk
Bytes,14.95 MiB,40.50 kiB
Shape,"(54, 7, 72, 144)","(1, 1, 72, 144)"
Dask graph,378 chunks in 2654 graph layers,378 chunks in 2654 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
