In [None]:
# import the relevant modules
import os
from netCDF4 import Dataset
import xarray as xr
import numpy as np

In [None]:
# Observations: ERA5 sea-level pressure stored in a single netCDF file.
# The file is large, so it is opened lazily in dask chunks further below.
# Set the ERA5_PSL_PATH environment variable to use a different copy of the
# file without editing the notebook; otherwise the original location is used.
path_obs_nc = os.environ.get(
    "ERA5_PSL_PATH", "/home/users/benhutch/ERA5_psl/ERA5-full-data.nc"
)

In [None]:
# Helpers for building the NAO index from gridded sea-level pressure:
# remap the observations onto a regular 2.5 x 2.5 degree grid, then select
# the Azores and Iceland boxes.

# build the regular target lon/lat grid
def create_new_grid(resolution=2.5):
    """Return an empty Dataset describing a regular global lon/lat grid.

    Parameters
    ----------
    resolution : float, optional
        Grid spacing in degrees, used for both axes (default 2.5).

    Returns
    -------
    xarray.Dataset
        Dataset whose only content is the ascending 1-D axes ``lon``
        (from -180 up to, but excluding, 180) and ``lat`` (from -90 to +90
        inclusive). It is used as the template for ``interp_like``.
    """
    # -180 and 180 are the same meridian, so excluding the stop is correct here.
    lon = np.arange(-180, 180, resolution)
    # np.arange excludes its stop value. Extending the stop by half a step
    # includes the +90 row (73 latitudes at 2.5 degrees instead of 72), while
    # staying robust to floating-point round-off.
    lat = np.arange(-90, 90 + resolution / 2, resolution)

    return xr.Dataset({'lon': (['lon'], lon),
                       'lat': (['lat'], lat)})

# lon/lat bounds (degrees) of the two pressure boxes used for the NAO index
def get_grid_spec(location):
    """Look up the bounding box for a named NAO pressure centre.

    Parameters
    ----------
    location : str
        Either 'azores' or 'iceland'.

    Returns
    -------
    dict
        A fresh dict with keys 'lon1', 'lon2' (west/east longitude) and
        'lat1', 'lat2' (south/north latitude).

    Raises
    ------
    ValueError
        If ``location`` is neither 'azores' nor 'iceland'.
    """
    known_boxes = (
        ('azores', (-28, -20, 36, 40)),
        ('iceland', (-25, -16, 63, 70)),
    )
    for box_name, (west, east, south, north) in known_boxes:
        if location == box_name:
            return {'lon1': west, 'lon2': east, 'lat1': south, 'lat2': north}
    raise ValueError('Location must be azores or iceland')

# remap onto the regular grid, then cut out one lon/lat box
def select_data_within_grid(dataset, grid):
    """Interpolate ``dataset`` onto the regular grid and return one box.

    Parameters
    ----------
    dataset : xarray.Dataset
        Gridded data with ``lon`` and ``lat`` coordinates.
        NOTE(review): ERA5 downloads often name these ``longitude`` and
        ``latitude`` instead; confirm the file uses ``lon``/``lat``.
    grid : dict
        Box bounds with keys 'lon1', 'lon2', 'lat1', 'lat2', as returned by
        ``get_grid_spec``.

    Returns
    -------
    xarray.Dataset
        The interpolated data restricted to the box (bounds inclusive).
    """
    # The target grid and the box bounds use longitudes in [-180, 180).
    # Data on a [0, 360) axis (the usual ERA5 convention) would otherwise
    # interpolate to all-NaN for the western-hemisphere Azores and Iceland
    # boxes. Wrapping is skipped when the data is already in [-180, 180].
    if 'lon' in dataset.coords and float(dataset['lon'].max()) > 180:
        wrapped_lon = ((dataset['lon'] + 180) % 360) - 180
        dataset = dataset.assign_coords(lon=wrapped_lon).sortby('lon')
    # Make latitude ascending (ERA5 is commonly stored north-to-south) so the
    # source axis runs the same way as the target grid.
    if 'lat' in dataset.coords:
        dataset = dataset.sortby('lat')

    new_grid = create_new_grid()
    # interpolate dataset to new grid
    ds_interp = dataset.interp_like(new_grid, method='linear')  # Change 'linear' to 'nearest' for nearest neighbour interpolation
    # Both axes of the new grid are ascending, so the slices run low -> high.
    return ds_interp.sel(lon=slice(grid['lon1'], grid['lon2']),
                         lat=slice(grid['lat1'], grid['lat2']))

# keep only the extended-winter months December-March (DJFM)
def select_months(dataset, months=(12, 1, 2, 3)):
    """Return the time steps of ``dataset`` whose calendar month is in ``months``.

    xarray's ``time.season`` accessor only produces the three-month labels
    'DJF', 'MAM', 'JJA' and 'SON', so a comparison with 'DJFM' can never be
    true and would select no time steps at all. Selecting on ``time.month``
    keeps exactly the requested months.

    Parameters
    ----------
    dataset : xarray.Dataset
        Data with a datetime-like ``time`` coordinate.
    months : iterable of int, optional
        Calendar months to keep (1 = January). Defaults to DJFM.

    Returns
    -------
    xarray.Dataset
        ``dataset`` restricted to the requested months.
    """
    in_window = dataset['time.month'].isin(list(months))
    return dataset.sel(time=in_window)

# time-mean of the selected data, used as the reference state for anomalies
def calculate_model_mean_state(dataset):
    """Return ``dataset`` averaged over its ``time`` dimension.

    In this notebook it is applied to observations, not model output; the
    name is kept so that callers keep working.
    """
    return dataset.mean(dim='time')

# anomalies = data minus its reference (mean) state
def calculate_model_anomalies(dataset, model_mean_state):
    """Subtract ``model_mean_state`` from ``dataset`` and return the difference.

    With xarray inputs, the subtraction broadcasts the time-mean state
    across every time step of ``dataset``.
    """
    return dataset - model_mean_state

# collapse the DJFM steps into one mean value per calendar-year label
def calculate_annual_mean_anomalies(dataset):
    """Realign the DJFM data by three time steps, then average per year.

    ``shift(time=-3)`` moves the *values* three positions earlier along
    ``time`` while the time labels stay fixed, and the last three positions
    become NaN. ``resample(time='Y')`` then averages everything sharing a
    calendar-year label (xarray's mean skips NaN by default).

    NOTE(review): the shift counts positions, not calendar months. Whole
    winters are only grouped correctly if the input is monthly and holds
    exactly the four DJFM months of every year, starting in January. Then
    the values labelled year Y are Dec Y plus Jan-Mar of Y+1. For any other
    time frequency this is not a 3-month shift — confirm the ERA5 file is
    monthly.
    """
    realigned = dataset.shift(time=-3)
    yearly_groups = realigned.resample(time='Y')
    return yearly_groups.mean(dim='time')

# NAO index = box-mean Azores anomaly minus box-mean Iceland anomaly
def calculate_NAO_index(azores_anomalies, iceland_anomalies):
    """Combine the two pressure-box anomalies into an NAO index.

    Each input is first averaged over its ``lat`` and ``lon`` dimensions,
    without area weighting. The Iceland mean is then subtracted from the
    Azores mean.
    """
    spatial_dims = ['lat', 'lon']
    azores_box_mean = azores_anomalies.mean(dim=spatial_dims)
    iceland_box_mean = iceland_anomalies.mean(dim=spatial_dims)
    return azores_box_mean - iceland_box_mean

# smooth the NAO index with a running mean along time (8 steps by default)
def calculate_NAO_index_running_mean(NAO_index, window=8):
    """Return the running mean of ``NAO_index`` along ``time``.

    This uses xarray's default, non-centred rolling window. The value at each
    time step is the mean of that step and the ``window - 1`` steps before it,
    so it is a trailing mean, not a forward one. The first ``window - 1``
    steps are NaN because a full window is required by default. With one
    value per year and the default ``window``, this is an 8-year running
    mean.

    Parameters
    ----------
    NAO_index : xarray.DataArray
        Index with a ``time`` dimension.
    window : int, optional
        Number of time steps per window (default 8).

    Returns
    -------
    xarray.DataArray
        The smoothed index, on the same time axis as the input.
    """
    return NAO_index.rolling(time=window).mean()


In [None]:
# pipeline for one pressure box: load -> box -> DJFM -> anomalies -> yearly
def process_observations(path_obs_nc, location):
    """Build the annual-mean DJFM anomalies for one NAO pressure box.

    Parameters
    ----------
    path_obs_nc : str
        Path to the observational netCDF file.
    location : str
        'azores' or 'iceland'; anything else makes ``get_grid_spec`` raise
        ValueError.

    Returns
    -------
    xarray.Dataset
        The box's anomalies averaged per year. Because the file is opened
        with dask chunks, this is presumably still lazy until computed.
    """
    # chunk along time so the full file does not have to be loaded at once
    observations = xr.open_dataset(path_obs_nc, chunks={'time': 50})

    box_data = select_data_within_grid(observations, get_grid_spec(location))
    winter_data = select_months(box_data)

    # anomalies are taken relative to the time-mean of the selected months
    reference_state = calculate_model_mean_state(winter_data)
    winter_anomalies = calculate_model_anomalies(winter_data, reference_state)

    return calculate_annual_mean_anomalies(winter_anomalies)

# build both pressure boxes, combine them into the NAO index, and smooth it
def main(path_obs_nc):
    """Return the running-mean NAO index derived from the file at ``path_obs_nc``.

    The Azores and Iceland boxes go through the same processing pipeline
    (Azores first). Their difference is then smoothed with
    ``calculate_NAO_index_running_mean``.
    """
    azores_anoms, iceland_anoms = [
        process_observations(path_obs_nc, box_name)
        for box_name in ('azores', 'iceland')
    ]
    raw_index = calculate_NAO_index(azores_anoms, iceland_anoms)
    return calculate_NAO_index_running_mean(raw_index)

In [None]:
# Run the full pipeline on the observations file configured above.
# NOTE(review): the result is likely still dask-backed; call .compute() to evaluate it.
NAO_index_running_mean = main(path_obs_nc=path_obs_nc)