## Removing model climatology w/iris/xarray ##

In [1]:
# Import local modules
import os
import sys
import argparse
import re
import glob

# Import third-party modules
from time import time
import numpy as np
import xarray as xr
import iris
import iris.coord_categorisation as icc
from iris.time import PartialDateTime

# import the functions
import remove_model_clim as rmc_func

# import tqdm
from tqdm import tqdm

# Import local modules
sys.path.append("/home/users/benhutch/lagging-NAO-test-suite/")
import dictionaries as dicts

# import modules for reading in the data
sys.path.append("/home/users/benhutch/unseen_functions/")
import functions as unseen_funcs

  _set_context_ca_bundle_path(ca_bundle_path)


In [2]:
# Set up the list of ua models
ua_models = [
    "BCC-CSM2-MR",
    "MPI-ESM1-2-HR",
    "CanESM5",
    "CMCC-CM2-SR5",
    "HadGEM3-GC31-MM",
    "EC-Earth3",
    "FGOALS-f3-L",
    "MIROC6",
    "IPSL-CM6A-LR",
]

# sensible models (those which don't prduce nan values)
sensible_models = [
    "BCC-CSM2-MR",
    "MPI-ESM1-2-HR",
    "CanESM5",
    "EC-Earth3",
]

# Set up the variables
saved_dir = "/gws/nopw/j04/canari/users/benhutch/saved_DePre/skill-maps-arrays/saved_data/"
model_season = "AYULGS"
model_variable = "ua"

# initialize the no members
no_members = 0

# Loop over the models 
for model in sensible_models:
    # Set up the file where it is stored
    members_path = f"{saved_dir}{model}/{model_season}/{model_variable}/*members.npy"

    # if this path does not exist then exit
    if not glob.glob(members_path):
        print(f"no members for {model}")
        continue

    # Get the members
    members_file = glob.glob(members_path)

    # load the members
    members_this = np.load(members_file[0])

    # Get the number of members
    no_members += members_this.shape[0]

# Print the number of members
print(f"Number of members: {no_members}")

# import the first anoms file
anoms_file = f"{saved_dir}{ua_models[0]}/{model_season}/{model_variable}/*anoms.npy"

# if the file doesn't exist, exit
if not glob.glob(anoms_file):
    print(f"no anoms file for {ua_models[0]}")
    sys.exit()

# Get the anoms file
anoms_file = glob.glob(anoms_file)

# Set up an empty array to store the anoms
anoms = np.load(anoms_file[0])

# Set up the array to store the data
all_models_array = np.zeros([
    np.shape(anoms)[0],
    no_members,
    np.shape(anoms)[2],
    np.shape(anoms)[3],
])

# print the shape of the all models array
print(f"Shape of all models array: {np.shape(all_models_array)}")

# Set up the ticker
ticker = 0

# Loop over the members
for i, model in tqdm(enumerate(sensible_models)):
    # Find the anomalies file
    model_file_path = f"{saved_dir}{model}/{model_season}/{model_variable}/*anoms.npy"
    members_path = f"{saved_dir}{model}/{model_season}/{model_variable}/*members.npy"

    # if the file doesn't exist, exit
    if not glob.glob(model_file_path):
        print(f"no anoms file for {model}")
        sys.exit()

    # Get the anoms file
    model_file = glob.glob(model_file_path)

    # Get the members file
    members_file = glob.glob(members_path)

    # Load the data
    model_data = np.load(model_file[0])

    # Load the members
    model_members = np.load(members_file[0])

    # Get the number of members
    no_members = model_members.shape[0]

    # print the model name
    print(f"Model: {model}")

    # print the model mean
    print(f"Model mean: {np.mean(model_data)}")

    # print the model spread
    print(f"Model spread: {np.std(model_data)}")

    # print the model mean nan
    print(f"Model mean nan: {np.nanmean(model_data)}")

    # print the model spread nan
    print(f"Model spread nan: {np.nanstd(model_data)}")


    if i == 0:
        # Append to the model array
        all_models_array[:, :no_members, :, :] = model_data
    else:
        # Append to the model array
        all_models_array[:, ticker : ticker + no_members, :, :] = model_data

    # Update the ticker
    ticker += no_members

Number of members: 48
Shape of all models array: (54, 48, 72, 144)


0it [00:00, ?it/s]

1it [00:00,  2.58it/s]

Model: BCC-CSM2-MR
Model mean: 4.56505611090563e-10
Model spread: 0.3356282413005829
Model mean nan: 4.56505611090563e-10
Model spread nan: 0.3356282413005829
Model: MPI-ESM1-2-HR
Model mean: 2.8071689417430434e-10
Model spread: 0.32524484395980835
Model mean nan: 2.8071689417430434e-10
Model spread nan: 0.32524484395980835
Model: CanESM5
Model mean: -6.513721784173754e-10
Model spread: 0.3120632469654083
Model mean nan: -6.513721784173754e-10


4it [00:00,  6.05it/s]

Model spread nan: 0.3120632469654083
Model: EC-Earth3
Model mean: 5.464440566704809e-10
Model spread: 0.32632938027381897
Model mean nan: 5.464440566704809e-10
Model spread nan: 0.32632938027381897





In [3]:
# print the all models array
print(all_models_array[:, :, :, :].shape)

(54, 48, 72, 144)


In [4]:
# print the shape of the all models array
print(f"Shape of all models array: {np.shape(all_models_array)}")

Shape of all models array: (54, 68, 72, 144)


In [5]:
# print the mean of the all models array
print(f"Mean of all models array: {np.mean(all_models_array)}")

Mean of all models array: -4.2444973295250995e-10


In [6]:
# print the spread of the all models array
print(f"Spread of all models array: {np.std(all_models_array)}")

Spread of all models array: 0.32339902043627533


In [7]:
# print the min and max values of the all models array
print(f"Min value of all models array: {np.min(all_models_array)}")
print(f"Max value of all models array: {np.max(all_models_array)}")

Min value of all models array: -2.8628416061401367
Max value of all models array: 3.5465149879455566


In [8]:
# check whether there are any nan values
nan_values = np.isnan(all_models_array)

# print the number of nan values
print(f"Number of nan values: {np.sum(nan_values)}")

Number of nan values: 0


In [3]:
print(all_models_array.shape)

(54, 68, 72, 144)


In [6]:
import time

current_time = time.time()
print(current_time)

1736764803.4214742


In [4]:
# # set up the save dir
save_dir = "/gws/nopw/j04/canari/users/benhutch/alternate-lag-processed-data/test-sfcWind/"

# test_fname = "tas_DJFM_global_1964_2014_2-9_4_1730914143.9212036_alternate_lag.npy"

# # load the file
# test_data = np.load(f"{save_dir}{test_fname}")

# # print the shape of the test data
# print(test_data.shape)

In [7]:
# Set up the current time
current_time = time.time()

# set up the fname
fpath = os.path.join(save_dir, f"{model_variable}_{model_season}_1961_2014_2-9_4_{current_time}.npy")

# if the fpath does not exist, save the data
if not os.path.exists(fpath):
    np.save(fpath, all_models_array)

10

In [19]:
# Process the lagging for the AYULGS season
lag0_all_models_array = all_models_array[:-3, :, :, :]
lag1_all_models_array = all_models_array[1:-2, :, :, :]
lag2_all_models_array = all_models_array[2:-1, :, :, :]
lag3_all_models_array = all_models_array[3:, :, :, :]

# Set up the new shape for the data
lagged_all_models_array = np.zeros(
    [
        np.shape(lag0_all_models_array)[0],
        np.shape(all_models_array)[1] * 4,
        np.shape(all_models_array)[2],
        np.shape(all_models_array)[3],
    ]
)

# Form a list of the lagged arrays
lagged_arr_list = [
    lag0_all_models_array,
    lag1_all_models_array,
    lag2_all_models_array,
    lag3_all_models_array,
]

# Set up a ticker
ticker = 0

# Set uyp raw members
raw_members = 107

# Loop over the lagged arrays
for i, lagged_arr in enumerate(lagged_arr_list):
    if i == 0:
        lagged_all_models_array[:, :raw_members, :, :] = lagged_arr
    else:
        lagged_all_models_array[:, ticker : ticker + raw_members, :, :] = lagged_arr

    # Update the ticker
    ticker += raw_members

# print the shape of the lagged all models array
print(lagged_all_models_array.shape)

(51, 428, 72, 144)


In [37]:
# check whether there are any NaNs in the data
# any zeros 

True


In [26]:
# set up the current time
current_time = time.time()

# set up the lagged fpath
lagged_fpath = os.path.join(save_dir, f"{model_variable}_{model_season}_global_1964_2014_2-9_4_20_{current_time}_alternate_lag.npy")

# if the fpath does not exist, save the data
if not os.path.exists(lagged_fpath):
    np.save(lagged_fpath, lagged_all_models_array)

In [25]:
lagged_fpath

'/gws/nopw/j04/canari/users/benhutch/alternate-lag-processed-data/test-sfcWind/ua_AYULGS_global_1964_2014_2-9_4_20_1736523627.9138052_alternate_lag.npy'

In [2]:
# Set up the variables
model = "IPSL-CM6A-LR" # all files in this folder for some reason
variable = "ua"
season = "AYULGS"
start_year = 1961
end_year = 2014
region = "global"
forecast_range = "2-9"
frequency = "Amon"

# test init year
init_years = np.arange(1961, 2014 + 1, 1)

In [3]:
# Find the files
files_dir = "/work/scratch-nopw2/benhutch/ua/CanESM5/global/all_forecast_years/AYULGS/outputs/"

# # Find the files in the directory
# # for the model, season, and init year
# files = glob.glob(f"{files_dir}/*{season}*s{init_years[0]}*.nc")

# # print the files
# print(files)

files_list = []

# loop over the init years
for init_year in tqdm(init_years):
    # Find the files in the directory
    # for the model, season, and init year
    files_this = glob.glob(f"{files_dir}/*{season}*{model}*s{init_year}*.nc")

    # append the files this to the files list
    files_list.extend(files_this)

# print the files list
print(files_list[0])
print(files_list[-1])

# Regular expression pattern for the desired format
pattern = re.compile(r".*s\d{4}-(r\d+i\d+p\d+f\d+).*")

# Extract the 'r*i?p?f?' part from each directory name
extracted_parts = [
    pattern.match(file).group(1) for file in files_list if pattern.match(file)
]

# if the model is CanESM5
# limit to those containing "p2"
# if model == "CanESM5":
#     extracted_parts = [
#         part for part in extracted_parts if "p2" in part
#     ]

# Print the extracted parts
print("Extracted parts:", extracted_parts)

# Find the unique combinations of r*i?p?f?
unique_combinations = np.unique(extracted_parts)
print("Unique combinations:", unique_combinations)
print("Number of unique combinations:", len(unique_combinations))
# sys.exit()

# Set up a list for the init years
init_year_list = []

# loop over the init years
for init_year in tqdm(init_years):
    # Empty member list
    member_list = []
    for variant_label in tqdm(unique_combinations):
        # Find the files with the variant label
        files_this = [
                file for file in files_list if f"s{init_year}-{variant_label}" in file
            ]
        
        # # print the files this
        # print(files_this)

        # Open all leads for specified variant label
        # and init_year
        member_ds = xr.open_mfdataset(
            files_this,
            combine="nested",
            concat_dim="time",
            preprocess=lambda ds: unseen_funcs.preprocess_boilerplate(ds),
            parallel=False,
            engine="netcdf4",
            coords="minimal",  # expecting identical coords
            data_vars="minimal",  # expecting identical vars
            compat="override",  # speed up
        ).squeeze()

        # init_year = start_year and variant_label is unique_variant_labels[0]
        if init_year == start_year and variant_label == unique_combinations[0]:
            # Set new int time
            member_ds = unseen_funcs.set_integer_time_axis(
                xro=member_ds, frequency=frequency, first_month_attr=True
            )
        else:
            # Set new integer time
            member_ds = unseen_funcs.set_integer_time_axis(member_ds, frequency=frequency)

        # Append the member dataset to the member list
        member_list.append(member_ds)
    # Concatenate the member list along the ensemble_member dimension
    member_ds = xr.concat(member_list, "member")
    # Append the member dataset to the init_year list
    init_year_list.append(member_ds)

# Concatenate the init_year list along the init dimension
# and rename as lead time
ds = xr.concat(init_year_list, "init").rename({"time": "lead"})

# filter the items in unique variaant labels to extract just the ints
# Apply the function to extract the numeric parts
numeric_labels = [unseen_funcs.extract_numeric(label) for label in unique_combinations]

# Set up the members
ds["member"] = numeric_labels
ds["init"] = np.arange(init_years[0], init_years[-1] + 1)

  0%|          | 0/54 [00:00<?, ?it/s]

100%|██████████| 54/54 [00:00<00:00, 149.56it/s]


/work/scratch-nopw2/benhutch/ua/CanESM5/global/all_forecast_years/AYULGS/outputs/all-years-AYULGS-global-ua_Amon_IPSL-CM6A-LR_dcppA-hindcast_s1961-r2i1p1f1_gr_196201-197112.nc
/work/scratch-nopw2/benhutch/ua/CanESM5/global/all_forecast_years/AYULGS/outputs/all-years-AYULGS-global-ua_Amon_IPSL-CM6A-LR_dcppA-hindcast_s2014-r8i1p1f1_gr_201501-202412.nc
Extracted parts: ['r2i1p1f1', 'r1i1p1f1', 'r10i1p1f1', 'r6i1p1f1', 'r4i1p1f1', 'r3i1p1f1', 'r5i1p1f1', 'r9i1p1f1', 'r7i1p1f1', 'r8i1p1f1', 'r2i1p1f1', 'r1i1p1f1', 'r10i1p1f1', 'r6i1p1f1', 'r4i1p1f1', 'r3i1p1f1', 'r5i1p1f1', 'r9i1p1f1', 'r7i1p1f1', 'r8i1p1f1', 'r2i1p1f1', 'r1i1p1f1', 'r10i1p1f1', 'r6i1p1f1', 'r4i1p1f1', 'r3i1p1f1', 'r5i1p1f1', 'r9i1p1f1', 'r7i1p1f1', 'r8i1p1f1', 'r2i1p1f1', 'r1i1p1f1', 'r10i1p1f1', 'r6i1p1f1', 'r4i1p1f1', 'r3i1p1f1', 'r5i1p1f1', 'r9i1p1f1', 'r7i1p1f1', 'r8i1p1f1', 'r2i1p1f1', 'r1i1p1f1', 'r10i1p1f1', 'r6i1p1f1', 'r4i1p1f1', 'r3i1p1f1', 'r5i1p1f1', 'r9i1p1f1', 'r7i1p1f1', 'r8i1p1f1', 'r2i1p1f1', 'r1i1p1f1', '

100%|██████████| 10/10 [00:00<00:00, 30.71it/s]
100%|██████████| 10/10 [00:00<00:00, 37.65it/s]
100%|██████████| 10/10 [00:00<00:00, 37.45it/s]
100%|██████████| 10/10 [00:00<00:00, 35.74it/s]
100%|██████████| 10/10 [00:00<00:00, 34.59it/s]
100%|██████████| 10/10 [00:00<00:00, 30.77it/s]
100%|██████████| 10/10 [00:00<00:00, 35.87it/s]
100%|██████████| 10/10 [00:00<00:00, 37.31it/s]
100%|██████████| 10/10 [00:00<00:00, 35.63it/s]
100%|██████████| 10/10 [00:00<00:00, 37.12it/s]
100%|██████████| 10/10 [00:00<00:00, 37.36it/s]
100%|██████████| 10/10 [00:00<00:00, 33.76it/s]
100%|██████████| 10/10 [00:00<00:00, 36.16it/s]
100%|██████████| 10/10 [00:00<00:00, 34.84it/s]
100%|██████████| 10/10 [00:00<00:00, 36.93it/s]
100%|██████████| 10/10 [00:00<00:00, 36.40it/s]
100%|██████████| 10/10 [00:00<00:00, 36.47it/s]
100%|██████████| 10/10 [00:00<00:00, 37.19it/s]
100%|██████████| 10/10 [00:00<00:00, 32.56it/s]
100%|██████████| 10/10 [00:00<00:00, 36.92it/s]
100%|██████████| 10/10 [00:00<00:00, 36.

In [4]:
level = 85000

# print the plevs
print(ds["plev"])

# extrcat the level
ds_level = ds.sel(plev=level)

<xarray.DataArray 'plev' (plev: 19)> Size: 76B
array([100000.,  92500.,  85000.,  70000.,  60000.,  50000.,  40000.,  30000.,
        25000.,  20000.,  15000.,  10000.,   7000.,   5000.,   3000.,   2000.,
         1000.,    500.,    100.], dtype=float32)
Coordinates:
  * plev     (plev) float32 76B 1e+05 9.25e+04 8.5e+04 ... 1e+03 500.0 100.0
Attributes:
    axis:           Z
    units:          Pa
    standard_name:  air_pressure
    long_name:      pressure
    positive:       down


In [6]:
ds_level

# # print the size of the dataset in mb
# print(ds_level.nbytes / 1e6)

Unnamed: 0,Array,Chunk
Bytes,213.57 MiB,405.00 kiB
Shape,"(54, 10, 10, 72, 144)","(1, 1, 10, 72, 144)"
Dask graph,540 chunks in 1730 graph layers,540 chunks in 1730 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 213.57 MiB 405.00 kiB Shape (54, 10, 10, 72, 144) (1, 1, 10, 72, 144) Dask graph 540 chunks in 1730 graph layers Data type float32 numpy.ndarray",10  54  144  72  10,

Unnamed: 0,Array,Chunk
Bytes,213.57 MiB,405.00 kiB
Shape,"(54, 10, 10, 72, 144)","(1, 1, 10, 72, 144)"
Dask graph,540 chunks in 1730 graph layers,540 chunks in 1730 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [7]:
model

'IPSL-CM6A-LR'

In [8]:
# if the model is CanESM5
if model == "CanESM5":
    # Selet the appropriate lead winters
    lead_start_idx = 1
    lead_end_idx = 8
elif model == "BCC-CSM2-MR":
    # Select the appropriate lead winters
    lead_start_idx = 2 # because s1961 is initialized in jan 1961
    lead_end_idx = 9 # equivalent to 1970 for s1961
else:
    print("nov init models")

    # Selet the appropriate lead winters
    lead_start_idx = 1
    lead_end_idx = 8

# extract the lead times in the range 1:8
ds_level_lead = ds_level.isel(lead=slice(lead_start_idx - 1, lead_end_idx))


nov init models


In [9]:
ds_level_lead

Unnamed: 0,Array,Chunk
Bytes,170.86 MiB,324.00 kiB
Shape,"(54, 10, 8, 72, 144)","(1, 1, 8, 72, 144)"
Dask graph,540 chunks in 1731 graph layers,540 chunks in 1731 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 170.86 MiB 324.00 kiB Shape (54, 10, 8, 72, 144) (1, 1, 8, 72, 144) Dask graph 540 chunks in 1731 graph layers Data type float32 numpy.ndarray",10  54  144  72  8,

Unnamed: 0,Array,Chunk
Bytes,170.86 MiB,324.00 kiB
Shape,"(54, 10, 8, 72, 144)","(1, 1, 8, 72, 144)"
Dask graph,540 chunks in 1731 graph layers,540 chunks in 1731 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [10]:
%%time

# calculate the anoms
ds_level_lead_clim =ds_level_lead.groupby("init").mean("lead").mean("member")

# print the climatology
print(ds_level_lead_clim["ua"].values)

[[[9.9692100e+36 9.9692100e+36 9.9692100e+36 ... 9.9692100e+36
   9.9692100e+36 9.9692100e+36]
  [9.9692100e+36 9.9692100e+36 9.9692100e+36 ... 9.9692100e+36
   9.9692100e+36 9.9692100e+36]
  [9.9692100e+36 9.9692100e+36 9.9692100e+36 ... 9.9692100e+36
   9.9692100e+36 9.9692100e+36]
  ...
  [1.5286477e+00 1.5664822e+00 1.6013638e+00 ... 1.4086828e+00
   1.4492314e+00 1.4889749e+00]
  [1.5111630e+00 1.5322226e+00 1.5478671e+00 ... 1.4322027e+00
   1.4609129e+00 1.4866323e+00]
  [1.0627997e+00 1.0734419e+00 1.0788283e+00 ... 1.0206864e+00
   1.0364944e+00 1.0495173e+00]]

 [[9.9692100e+36 9.9692100e+36 9.9692100e+36 ... 9.9692100e+36
   9.9692100e+36 9.9692100e+36]
  [9.9692100e+36 9.9692100e+36 9.9692100e+36 ... 9.9692100e+36
   9.9692100e+36 9.9692100e+36]
  [9.9692100e+36 9.9692100e+36 9.9692100e+36 ... 9.9692100e+36
   9.9692100e+36 9.9692100e+36]
  ...
  [1.4173652e+00 1.4572709e+00 1.4937508e+00 ... 1.2920731e+00
   1.3342743e+00 1.3757168e+00]
  [1.4566053e+00 1.4794064e+00 1.497

In [11]:
# remove the climatology
ds_level_lead_anom = ds_level_lead - ds_level_lead_clim

In [12]:
# take the mean over leads
ds_level_lead_anom_mean = ds_level_lead_anom.mean("lead")

In [13]:
ds_level_lead_anom_mean["ua"].shape

(54, 10, 72, 144)

In [14]:
ds_level_lead_anom_mean

Unnamed: 0,Array,Chunk
Bytes,21.36 MiB,40.50 kiB
Shape,"(54, 10, 72, 144)","(1, 1, 72, 144)"
Dask graph,540 chunks in 1903 graph layers,540 chunks in 1903 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 21.36 MiB 40.50 kiB Shape (54, 10, 72, 144) (1, 1, 72, 144) Dask graph 540 chunks in 1903 graph layers Data type float32 numpy.ndarray",54  1  144  72  10,

Unnamed: 0,Array,Chunk
Bytes,21.36 MiB,40.50 kiB
Shape,"(54, 10, 72, 144)","(1, 1, 72, 144)"
Dask graph,540 chunks in 1903 graph layers,540 chunks in 1903 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [15]:
%%time

# save_dir
save_dir = f"/gws/nopw/j04/canari/users/benhutch/saved_DePre/skill-maps-arrays/saved_data/{model}/{season}/{variable}/"

# if the save directory does not exist
if not os.path.exists(save_dir):
    # make the directory
    os.makedirs(save_dir)

# set up fnames for the init years file
fname_init_years = f"{save_dir}{model}_{variable}_{season}_{start_year}-{end_year}_init_years.npy"

# Set up a file name for the members 
fname_members = f"{save_dir}{model}_{variable}_{season}_{start_year}-{end_year}_members.npy"

# Set up a file name for the lats
fname_lats = f"{save_dir}{model}_{variable}_{season}_{start_year}-{end_year}_lats.npy"

# Set up a file name for the lons
fname_lons = f"{save_dir}{model}_{variable}_{season}_{start_year}-{end_year}_lons.npy"

# set iup a file name for the plev
fname_plev = f"{save_dir}{model}_{variable}_{season}_{start_year}-{end_year}_plev.txt"

# set up a file name for the anoms
fname_anoms = f"{save_dir}{model}_{variable}_{season}_{start_year}-{end_year}_anoms.npy"

# save the init years
np.save(fname_init_years, ds_level_lead_anom_mean["init"].values)

# save the members
np.save(fname_members, ds_level_lead_anom_mean["member"].values)

# save the lats
np.save(fname_lats, ds_level_lead_anom_mean["lat"].values)

# save the lons
np.save(fname_lons, ds_level_lead_anom_mean["lon"].values)

# save the plev
np.savetxt(fname_plev, [level])

# save the anoms
np.save(fname_anoms, ds_level_lead_anom_mean["ua"].values)

CPU times: user 2.46 s, sys: 6.1 s, total: 8.56 s
Wall time: 18.7 s


In [16]:
save_dir

'/gws/nopw/j04/canari/users/benhutch/saved_DePre/skill-maps-arrays/saved_data/IPSL-CM6A-LR/AYULGS/ua/'

In [17]:
# import the files
init_years = np.load(fname_init_years)
members = np.load(fname_members)
lats = np.load(fname_lats)
lons = np.load(fname_lons)
plevs = np.loadtxt(fname_plev)
anoms = np.load(fname_anoms)

# print the init years
print(init_years)
print(members)
print(lats)
print(lons)
print(plevs)
print(anoms)

[1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974
 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988
 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002
 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014]
[10  1  2  3  4  5  6  7  8  9]
[-90.  -87.5 -85.  -82.5 -80.  -77.5 -75.  -72.5 -70.  -67.5 -65.  -62.5
 -60.  -57.5 -55.  -52.5 -50.  -47.5 -45.  -42.5 -40.  -37.5 -35.  -32.5
 -30.  -27.5 -25.  -22.5 -20.  -17.5 -15.  -12.5 -10.   -7.5  -5.   -2.5
   0.    2.5   5.    7.5  10.   12.5  15.   17.5  20.   22.5  25.   27.5
  30.   32.5  35.   37.5  40.   42.5  45.   47.5  50.   52.5  55.   57.5
  60.   62.5  65.   67.5  70.   72.5  75.   77.5  80.   82.5  85.   87.5]
[-180.  -177.5 -175.  -172.5 -170.  -167.5 -165.  -162.5 -160.  -157.5
 -155.  -152.5 -150.  -147.5 -145.  -142.5 -140.  -137.5 -135.  -132.5
 -130.  -127.5 -125.  -122.5 -120.  -117.5 -115.  -112.5 -110.  -107.5
 -105.  -102.5 -100.   -97.5  -95.   -92.

In [18]:
print(anoms.shape)

(54, 10, 72, 144)


In [18]:
# perform the ensemble lagging
lag0_anoms = anoms[:, :, :, :]
lag1_anoms = anoms[1:, :, :, :]
lag2_anoms = anoms[2:, :, :, :]
lag3_anoms = anoms[3:, :, :, :]

# print the shapes
print(lag0_anoms.shape)
print(lag1_anoms.shape)
print(lag2_anoms.shape)
print(lag3_anoms.shape)

# print the first value of each
print(lag0_anoms[0, 0, 0, 0])
print(lag1_anoms[0, 0, 0, 0])
print(lag2_anoms[0, 0, 0, 0])
print(lag3_anoms[0, 0, 0, 0])

# cut off the ends
lag0_anoms_cut = lag0_anoms[:-3, :, :, :]
lag1_anoms_cut = lag1_anoms[:-2 :, :, :]
lag2_anoms_cut = lag2_anoms[:-1, :, :, :]
lag3_anoms_cut = lag3_anoms[:, :, :, :]

# print the shapes
print(lag0_anoms_cut.shape)
print(lag1_anoms_cut.shape)
print(lag2_anoms_cut.shape)
print(lag3_anoms_cut.shape)

# assert that none of the values are nan
assert not np.any(np.isnan(lag0_anoms_cut))
assert not np.any(np.isnan(lag1_anoms_cut))
assert not np.any(np.isnan(lag2_anoms_cut))
assert not np.any(np.isnan(lag3_anoms_cut))

# stack these via the 1th axis
stacked_anoms = np.stack([lag0_anoms_cut, lag1_anoms_cut, lag2_anoms_cut, lag3_anoms_cut], axis=1)

(54, 20, 72, 144)
(53, 20, 72, 144)
(52, 20, 72, 144)
(51, 20, 72, 144)
-0.12641631
-0.014120072
-0.2626549
-0.073126465
(51, 20, 72, 144)
(51, 20, 72, 144)
(51, 20, 72, 144)
(51, 20, 72, 144)


In [19]:
print(stacked_anoms.shape)

# flatten the 1th and 2th axes
flattened_anoms = stacked_anoms.reshape(
    np.shape(stacked_anoms)[0], 
    np.shape(stacked_anoms)[1] * np.shape(stacked_anoms)[2],
    np.shape(stacked_anoms)[3],
    np.shape(stacked_anoms)[4]
)

(51, 4, 20, 72, 144)


In [20]:
# assert that there are no nans in the flattened anoms
assert not np.any(np.isnan(flattened_anoms))