## Alternate lag testing ##

Here we test a new methodology for the alternate lag. Having calculated the anomalies for all of the models, ensemble members, start dates and forecast years in the '*calc_anoms_suite*', we want to load them into Python to form arrays with shapes like:

(178, 60, 11, 72, 144)

where the dimensions are as follows:

* 178 total ensemble members (from all of the models)
* 60 start dates (~1960-2020)
* 11 forecast years (this may differ between models, *something to watch out for*)
* 72 latitude bands
* 144 longitude bands
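Before building the full multi-model array, it is worth checking how much memory it would need. The back-of-the-envelope sketch below uses only NumPy; `array_nbytes` is a hypothetical helper, not part of any library.

```python
import numpy as np


def array_nbytes(shape, dtype=np.float32):
    """Bytes needed by a dense array with this shape and dtype (hypothetical helper)."""
    # An explicit int64 product avoids overflow on platforms with 32-bit default ints
    n_elements = int(np.prod(shape, dtype=np.int64))
    return n_elements * np.dtype(dtype).itemsize


full_shape = (178, 60, 11, 72, 144)  # members, start dates, forecast years, lat, lon
for dtype in (np.float32, np.float64):
    gb = array_nbytes(full_shape, dtype) / 1e9
    print(f"{np.dtype(dtype).name}: {gb:.2f} GB")
# float32: 4.87 GB
# float64: 9.74 GB
```

`np.zeros` defaults to float64, while the anomaly files store `psl` as float32, so allocating with `dtype=np.float32` would halve the footprint.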

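As a first exercise, it would be useful to load the array for a single model, for example BCC-CSM2-MR (8 ensemble members), whose shape would look something like:

(8, 60, 11, 72, 144)

The cells below actually use HadGEM3-GC31-MM (10 ensemble members, start dates 1961 to 2014) and put the start-date dimension first, giving an array of shape (54, 10, 11, 72, 144).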
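Once each model has its own array, the multi-model array could be formed by concatenating along the ensemble-member axis. This sketch assumes every model shares the same start dates and number of forecast years (the caveat noted in the dimension list) and uses a shrunken placeholder grid instead of 72 x 144; the per-model arrays here are zero-filled placeholders, not real data.

```python
import numpy as np

# Placeholder sizes: (start dates, forecast years, lat, lon), grid shrunk for illustration
n_start, n_fcst, n_lat, n_lon = 60, 11, 4, 8

# Hypothetical per-model arrays with shape (members, start dates, forecast years, lat, lon)
per_model = {
    "BCC-CSM2-MR": np.zeros((8, n_start, n_fcst, n_lat, n_lon), dtype=np.float32),
    "HadGEM3-GC31-MM": np.zeros((10, n_start, n_fcst, n_lat, n_lon), dtype=np.float32),
}

# All non-member dimensions must agree before concatenating along axis 0
tails = {arr.shape[1:] for arr in per_model.values()}
assert len(tails) == 1, f"models disagree on non-member dimensions: {tails}"

multi_model = np.concatenate(list(per_model.values()), axis=0)
print(multi_model.shape)  # (18, 60, 11, 4, 8)

# The notebook's loading code stores start dates first, i.e. (start dates, members, ...);
# np.moveaxis switches between the two orderings and returns a view, not a copy
start_first = np.moveaxis(multi_model, 0, 1)
print(start_first.shape)  # (60, 18, 11, 4, 8)
```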

In [1]:
# Import standard library modules
import sys
import os
import argparse

# Import 3rd party modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr

In [2]:
# Set up the arguments
base_dir = "/gws/nopw/j04/canari/users/benhutch/skill-maps-processed-data"
variable = "psl"
model = "HadGEM3-GC31-MM"
region = "global"
forecast_range = "all_forecast_years"
season = "DJFM"

In [4]:
# Form the directory path
dir_path = os.path.join(base_dir, variable, model, region, forecast_range,
                        season, "outputs", "anoms")

print(dir_path)

# List the files ending with *.nc in this directory
file_list = [f for f in os.listdir(dir_path) if f.endswith(".nc")]

# Print the list of files
print(file_list)

# Find the file containing "s1961" and "r1i1"
test_file = [f for f in file_list if "s1961" in f and "r1i1" in f][0]

# Print the test file
print(test_file)

/gws/nopw/j04/canari/users/benhutch/skill-maps-processed-data/psl/HadGEM3-GC31-MM/global/all_forecast_years/DJFM/outputs/anoms
['all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-hindcast_s1960-r10i1_gn_196011-197103-anoms.nc', 'all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-hindcast_s1960-r1i1_gn_196011-197103-anoms.nc', 'all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-hindcast_s1960-r2i1_gn_196011-197103-anoms.nc', 'all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-hindcast_s1960-r3i1_gn_196011-197103-anoms.nc', 'all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-hindcast_s1960-r4i1_gn_196011-197103-anoms.nc', 'all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-hindcast_s1960-r5i1_gn_196011-197103-anoms.nc', 'all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-hindcast_s1960-r6i1_gn_196011-197103-anoms.nc', 'all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-hindcast_s1960-r7i1_gn_196011-197103-anoms.nc', 'all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-h
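The substring checks used here rescan the whole file list for every (start year, member) pair and silently depend on the exact naming. A possibly more robust alternative is to parse each filename once into a lookup table. The regular expression below is an assumption inferred from the filenames listed above (`..._s<year>-r<member>i<init>_...`), and `index_files` is a hypothetical helper.

```python
import re

# Pattern assumed from the listed filenames, e.g. "..._s1960-r10i1_gn_..."
MEMBER_RE = re.compile(r"_s(?P<year>\d{4})-r(?P<real>\d+)i\d+_")


def index_files(file_list):
    """Map (start year, realisation number) -> filename; names that do not match are skipped."""
    lookup = {}
    for fname in file_list:
        match = MEMBER_RE.search(fname)
        if match is None:
            continue
        key = (int(match["year"]), int(match["real"]))
        if key in lookup:
            raise ValueError(f"duplicate files for {key}: {lookup[key]}, {fname}")
        lookup[key] = fname
    return lookup


example = [
    "all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-hindcast_s1960-r10i1_gn_196011-197103-anoms.nc",
    "all-years-DJFM-global-psl_Amon_HadGEM3-GC31-MM_dcppA-hindcast_s1960-r1i1_gn_196011-197103-anoms.nc",
]
lookup = index_files(example)
print(sorted(lookup))  # [(1960, 1), (1960, 10)]
```

In the loading loop, `lookup[(year, j + 1)]` could then replace the list comprehension, and a missing file would raise a `KeyError` naming the exact pair.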

In [5]:
# Load in the test file using xarray
test_ds = xr.open_dataset(os.path.join(dir_path, test_file))

# Extract the data for the variable
test_ds_psl = test_ds[variable]

# Print the data
print(test_ds_psl)

# Extract the years from the time dimension
test_ds_psl_years = test_ds_psl.time.dt.year

<xarray.DataArray 'psl' (time: 11, lat: 72, lon: 144)>
[114048 values with dtype=float32]
Coordinates:
  * time     (time) object 1961-11-01 00:00:00 ... 1971-11-01 00:00:00
  * lon      (lon) float64 -180.0 -177.5 -175.0 -172.5 ... 172.5 175.0 177.5
  * lat      (lat) float64 -90.0 -87.5 -85.0 -82.5 -80.0 ... 80.0 82.5 85.0 87.5
Attributes:
    standard_name:  air_pressure_at_mean_sea_level
    long_name:      Sea Level Pressure
    units:          Pa
    cell_methods:   area: time: mean
    comment:        Sea Level Pressure
    original_name:  mo: (stash: m01s16i222, lbproc: 128)
    cell_measures:  area: areacella


In [9]:
# Set up the initialisation years
years = np.arange(1961, 2015) # 1961 to 2014 inclusive

# Set up the number of years
num_years = len(years)

# Set up the number of ensemble members
nens = 10 # HadGEM3-GC31-MM has 10 ensemble members (r1i1 to r10i1)

# Extract the number of forecast years
no_forecast_years = test_ds_psl.shape[0]

# Set up the no of lats
no_lats = test_ds_psl.shape[1]

# Set up the no of lons
no_lons = test_ds_psl.shape[2]

# Set up an empty array to store the data
test_array = np.zeros([num_years, nens, no_forecast_years, no_lats, no_lons])

# Print the shape of the array
print(test_array.shape)

(54, 10, 11, 72, 144)
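`np.zeros` allocates float64 and fills with zeros, so a (start date, member) slot that is never filled cannot be told apart from a genuine zero anomaly. One possible alternative, sketched with a shrunken placeholder grid, is to allocate float32 (the dtype of the input files) filled with NaN and check afterwards which slots stayed empty. `find_unfilled` is a hypothetical helper, and the "missing file" is simulated.

```python
import numpy as np

years = np.arange(1961, 2015)  # 1961 to 2014 inclusive
nens, n_fcst, n_lat, n_lon = 10, 11, 3, 4  # lat/lon shrunk for illustration

data = np.full((len(years), nens, n_fcst, n_lat, n_lon), np.nan, dtype=np.float32)

# Simulate loading every slot except member 3 (index 2) of the first start date
for i in range(len(years)):
    for j in range(nens):
        if (i, j) == (0, 2):
            continue  # pretend this file was missing
        data[i, j] = 0.0  # stands in for the loaded anomaly field


def find_unfilled(arr, years):
    """Return (start year, realisation number) pairs whose whole field is NaN."""
    empty = np.isnan(arr).all(axis=(2, 3, 4))
    return [(int(years[i]), int(j) + 1) for i, j in zip(*np.nonzero(empty))]


print(find_unfilled(data, years))  # [(1961, 3)]
```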


In [10]:
# Loop over the years
for i, year in enumerate(years):
    # Loop over the ensemble members
    for j in range(nens):
        # logging to know where we are
        print(" year index: ", i, " ensemble member index: ", j)
        
        # Find the file for this start year (s{year}) and ensemble member (r{j+1}i1)
        test_file = [f for f in file_list if f"s{year}" in f and f"r{j+1}i1" in f][0]
        # Load in the test file using xarray
        test_ds = xr.open_dataset(os.path.join(dir_path, test_file))
        # Extract the data for the variable
        test_ds_psl = test_ds[variable]
        # Store the data in the array
        test_array[i, j, :, :, :] = test_ds_psl

 year index:  0  ensemble member index:  0
 year index:  0  ensemble member index:  1
 year index:  0  ensemble member index:  2
 year index:  0  ensemble member index:  3
 year index:  0  ensemble member index:  4
 year index:  0  ensemble member index:  5
 year index:  0  ensemble member index:  6
 year index:  0  ensemble member index:  7
 year index:  0  ensemble member index:  8
 year index:  0  ensemble member index:  9
 year index:  1  ensemble member index:  0
 year index:  1  ensemble member index:  1
 year index:  1  ensemble member index:  2
 year index:  1  ensemble member index:  3
 year index:  1  ensemble member index:  4
 year index:  1  ensemble member index:  5
 year index:  1  ensemble member index:  6
 year index:  1  ensemble member index:  7
 year index:  1  ensemble member index:  8
 year index:  1  ensemble member index:  9
 year index:  2  ensemble member index:  0
 year index:  2  ensemble member index:  1
 year index:  2  ensemble member index:  2
 year index

In [11]:
print(test_array.shape)

# Print the array
print(test_array)

(54, 10, 11, 72, 144)
[[[[[-2.29538345e+02 -2.29538345e+02 -2.29538345e+02 ...
     -2.29538345e+02 -2.29538345e+02 -2.29538345e+02]
    [-2.44167618e+02 -2.45504974e+02 -2.46731537e+02 ...
     -2.40617188e+02 -2.42375000e+02 -2.43585220e+02]
    [-2.71251434e+02 -2.74858673e+02 -2.77549713e+02 ...
     -2.60718048e+02 -2.65893463e+02 -2.69452423e+02]
    ...
    [-3.77497864e+02 -3.82338074e+02 -3.87235809e+02 ...
     -3.63231537e+02 -3.67918335e+02 -3.72695312e+02]
    [-4.13723724e+02 -4.17587372e+02 -4.21272736e+02 ...
     -4.01080963e+02 -4.05515625e+02 -4.09715210e+02]
    [-4.02450287e+02 -4.04583099e+02 -4.06536926e+02 ...
     -3.95260651e+02 -3.97759949e+02 -4.00160522e+02]]

   [[-4.18053986e+02 -4.18053986e+02 -4.18053986e+02 ...
     -4.18053986e+02 -4.18053986e+02 -4.18053986e+02]
    [-3.79526978e+02 -3.79004974e+02 -3.77387787e+02 ...
     -3.85093750e+02 -3.82593750e+02 -3.79343048e+02]
    [-3.32001434e+02 -3.32069611e+02 -3.28471588e+02 ...
     -3.42444611e+02 -3

Now that we have loaded the data into an array, we want to test how the alternate lag would work.
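The alternate lag combines hindcasts from several consecutive start dates: for each lagged year, the member taken from the start date j years before the newest one has its forecast window shifted forward by j years, so that all lagged members cover the same period. The sketch below only prints that bookkeeping for the first lagged year, using the notebook's `years` with a "2-5" forecast range and a four-year lag. `lag_sources` is a hypothetical helper, and "init year + forecast year" is used purely as a consistency label, not as a statement about which calendar winter each forecast year is.

```python
import numpy as np

years = np.arange(1961, 2015)  # initialisation years, as in the notebook
start_year, end_year, lag = 2, 5, 4  # forecast range "2-5" with a four-year lag


def lag_sources(i, years, start_year, end_year, lag):
    """For lagged output index i, list (init year, first fcst year, last fcst year) per lag j."""
    sources = []
    for j in range(lag):
        init_year = int(years[i + (lag - 1) - j])  # j years before the newest start date
        sources.append((init_year, start_year + j, end_year + j))
    return sources


for init_year, first, last in lag_sources(0, years, start_year, end_year, lag):
    # init year + forecast year is identical for every lag, so all members line up
    print(init_year, f"years {first}-{last}", "->", init_year + first, "to", init_year + last)
# 1964 years 2-5 -> 1966 to 1969
# 1963 years 3-6 -> 1966 to 1969
# 1962 years 4-7 -> 1966 to 1969
# 1961 years 5-8 -> 1966 to 1969
```

If index 0 on the forecast-year axis is forecast year 1 (an assumption about the file layout), the inclusive range `first`-`last` is the Python slice `[first - 1:last]`, and the oldest member needs forecast years up to `end_year + lag - 1` = 8, within the 11 available.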

In [16]:
# Set up the parameters for the alternate lag calculation
forecast_range = "2-5"

# Write a function to form the alternate-lagged ensemble
def alternate_lag(data: np.ndarray,
                  forecast_range: str,
                  years: np.ndarray,
                  lag: int = 4) -> np.ndarray:
    """
    Form a lagged ensemble of forecast-range means using the alternate lag.

    For each lagged year, the hindcast initialised j years earlier (j = 0 .. lag - 1)
    has its forecast window shifted forward by j years, so that every lagged member
    covers the same period.

    Parameters
    ----------
    data : np.ndarray
        Array of anomalies with dimensions (num_years, nens, no_forecast_years, no_lats, no_lons).
        Index 0 on the forecast-year axis is assumed to be forecast year 1.
    forecast_range : str
        The forecast range to average over, in the format "x-y" where x and y are
        integers. The range is inclusive: "2-5" means forecast years 2, 3, 4 and 5.
    years : np.ndarray
        Array of initialisation years with dimensions (num_years,).
    lag : int
        The number of initialisation years to lag over.
        The default is 4.

    Returns
    -------
    lagged_ensemble : np.ndarray
        Array of lagged ensemble members with dimensions
        (num_years - lag + 1, nens * lag, no_lats, no_lons).
    """

    # Assert that the forecast range is in the correct format
    assert "-" in forecast_range, "forecast_range should be in the format 'x-y' where x and y are integers"

    # Extract the start and end forecast years
    start_year, end_year = (int(x) for x in forecast_range.split("-"))

    # Assert that end year is greater than start year
    assert end_year > start_year, "end_year should be greater than start_year"

    # The oldest lagged member needs forecast years up to end_year + lag - 1,
    # which must not run past the last forecast year in the data
    assert end_year + lag - 1 <= data.shape[2], \
        "end_year + lag - 1 should not exceed the number of forecast years"

    # Set up the number of lagged years
    no_lagged_years = data.shape[0] - lag + 1

    print("no_lagged_years: ", no_lagged_years)

    # Create an empty array to store the lagged ensemble
    lagged_ensemble = np.zeros([no_lagged_years, data.shape[1] * lag, data.shape[3], data.shape[4]])

    # Loop over the years
    for i in range(no_lagged_years):
        print("Processing data for lag year index: ", i)
        # Loop over the lag
        for j in range(lag):
            print("Processing data for lag index: ", j)
            # Take the hindcast initialised j years before the newest start date in this window
            year_index = i + (lag - 1) - j
            lagged_year_data = data[year_index, :, :, :, :]

            # Print which data we are extracting
            print("Extracting data for year index: ", year_index)
            print("Extracting data for year: ", years[year_index])
            print("For lag index: ", j)

            # Loop over the ensemble members
            for k in range(data.shape[1]):
                # Extract the data for the ensemble member
                ensemble_member_data = lagged_year_data[k, :, :, :]

                # Print the (inclusive) forecast years which we are taking the mean over
                print("start year: ", start_year + j, " end year: ", end_year + j)

                # Mean over forecast years start_year + j to end_year + j inclusive;
                # with forecast year 1 at index 0 this is the slice [start_year - 1 + j:end_year + j]
                ensemble_member_data_mean = np.mean(ensemble_member_data[start_year - 1 + j:end_year + j, :, :], axis=0)

                # Print the year index, ensemble member index and lag index
                print("year index: ", i, " ensemble member index: ", k, " lag index: ", j)

                # Print which slot we are writing to
                print("Appending to: year index: ", i, " ensemble member index: ", j + k * lag)

                # Store the data (members of lag j go to slots j, j + lag, j + 2 * lag, ...)
                lagged_ensemble[i, j + k * lag, :, :] = ensemble_member_data_mean

    # Return the lagged ensemble
    return lagged_ensemble
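Because the members from lag j always land in slots j, j + lag, j + 2 * lag, ..., the ensemble-member loop can be replaced by a strided slice assignment. The sketch below is an alternative under stated assumptions (index 0 on the forecast-year axis is forecast year 1, and the forecast range is inclusive); `alternate_lag_vectorised` is a hypothetical name, and it drops the diagnostic prints.

```python
import numpy as np


def alternate_lag_vectorised(data, start_year, end_year, lag=4):
    """Alternate-lagged ensemble of forecast-range means without the member loop.

    data has shape (n_start, n_members, n_fcst_years, n_lat, n_lon). Assumes index 0
    on the forecast-year axis is forecast year 1 and that start_year-end_year is inclusive.
    """
    n_start, nens, n_fcst = data.shape[:3]
    n_out = n_start - lag + 1
    if start_year < 1 or end_year + lag - 1 > n_fcst:
        raise ValueError("forecast window falls outside the available forecast years")
    if n_out < 1:
        raise ValueError("need at least `lag` start dates")
    out = np.empty((n_out, nens * lag) + data.shape[3:],
                   dtype=np.result_type(data.dtype, np.float32))
    for j in range(lag):
        # Start dates j years older than the newest one, with the window shifted by j
        src = data[lag - 1 - j : lag - 1 - j + n_out, :, start_year - 1 + j : end_year + j]
        # Members of lag j go to slots j, j + lag, j + 2 * lag, ...
        out[:, j::lag] = src.mean(axis=2)
    return out
```

With the notebook's array, `alternate_lag_vectorised(test_array, 2, 5)` should return shape (51, 40, 72, 144): 54 - 4 + 1 lagged years and 10 x 4 members.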

In [17]:
# Test the function
lagged_correlation = alternate_lag(test_array, forecast_range, years)

# Print the shape of the lagged correlation
print(lagged_correlation.shape)

# Print the lagged correlation
print(lagged_correlation)

no_lagged_years:  51
Processing data for lag year index:  0
Processing data for lag index:  0
Extracting data for year index:  3
Extracting data for year:  1964
For lag index:  0
start year:  2  end year:  5
year index:  0  ensemble member index:  0  lag index:  0
Appending to: year index:  0  ensemble member index:  0
start year:  2  end year:  5
year index:  0  ensemble member index:  1  lag index:  0
Appending to: year index:  0  ensemble member index:  4
start year:  2  end year:  5
year index:  0  ensemble member index:  2  lag index:  0
Appending to: year index:  0  ensemble member index:  8
start year:  2  end year:  5
year index:  0  ensemble member index:  3  lag index:  0
Appending to: year index:  0  ensemble member index:  12
start year:  2  end year:  5
year index:  0  ensemble member index:  4  lag index:  0
Appending to: year index:  0  ensemble member index:  16
start year:  2  end year:  5
year index:  0  ensemble member index:  5  lag index:  0
Appending to: year inde