In [2]:
import xarray as xr
import pandas as pd
import numpy as np
import proplot as pplot
import matplotlib.pyplot as plt
import cmaps
from sklearn.linear_model import LinearRegression

import cartopy.crs as ccrs
import cartopy.feature as cfeature
from pyproj import transform
import cartopy.io.shapereader as shpreader
from cartopy.feature import ShapelyFeature

from scipy.spatial import distance
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity

import rioxarray
import geopandas
from shapely.geometry import mapping

import os
# Run relative to the project directory so the relative 'data/...' paths used below resolve.
# Set the HW_TRACK_DIR environment variable to run on another machine; the default is the
# original author's project location.
PROJECT_DIR = os.environ.get('HW_TRACK_DIR', '/Users/zeqinhuang/Documents/paper/HW_track')
os.chdir(PROJECT_DIR)

import warnings
# Silence all warnings globally to keep cell output short (this also hides deprecation warnings).
warnings.filterwarnings('ignore')

In [3]:
# Root directories of the daily reanalysis files, keyed by dataset name.
reanalyses_dir = {
    'era5': '/Volumes/Seagate_HZQ/reanalyses/era5/',
    'jra55': '/Volumes/Seagate_HZQ/reanalyses/jra55/',
}
datasets = list(reanalyses_dir)

# Nested analysis domains, each [lon_min, lon_max, lat_min, lat_max] in degrees east / north:
#   D1: 80-160E, 10-50N;  D2: 90-150E, 15-45N;  D3: 100-140E, 20-40N
domains = {
    'D1': [80, 160, 10, 50],
    'D2': [90, 150, 15, 45],
    'D3': [100, 140, 20, 40],
}
domain_name = list(domains)

# 1-degree grids spanning each domain, bounds inclusive.
target_griddes = {
    name: {'lat': np.arange(lat_lo, lat_hi + 1, 1), 'lon': np.arange(lon_lo, lon_hi + 1, 1)}
    for name, (lon_lo, lon_hi, lat_lo, lat_hi) in domains.items()
}

# Similarity measures used to rank circulation analogues.
dist_funcs = ['E_Dist', 'P_Corr', 'S_Corr', 'Cosine']

In [4]:
def sel_domain(dataarray,dom):
    """Subset a gridded field to one of the study domains.

    Parameters
    ----------
    dataarray : xarray.DataArray or xarray.Dataset
        Field with ``lat``/``lon`` coordinates, or ERA5-style ``latitude``/``longitude``
        coordinates, which are renamed to ``lat``/``lon``.
    dom : str
        Key into the module-level ``domains`` dict, whose value is
        ``[lon_min, lon_max, lat_min, lat_max]``.

    Returns
    -------
    Same type as ``dataarray``, restricted to the domain box (label bounds inclusive).
    """
    # Rename only the ERA5-style names that are actually present, instead of a bare
    # try/except that would also swallow unrelated errors.
    rename_map = {
        old: new
        for old, new in (('longitude', 'lon'), ('latitude', 'lat'))
        if old in dataarray.coords or old in dataarray.dims
    }
    if rename_map:
        dataarray = dataarray.rename(rename_map)
    lon_min, lon_max, lat_min, lat_max = domains[dom]
    # Label-based slices must follow the axis order. A north-to-south latitude axis needs
    # slice(max, min); using that on an ascending axis would silently select nothing.
    lat_values = dataarray['lat'].values
    if lat_values.size > 1 and lat_values[0] > lat_values[-1]:
        lat_slice = slice(lat_max, lat_min)
    else:
        lat_slice = slice(lat_min, lat_max)
    data_dom = dataarray.sel(lat = lat_slice, lon = slice(lon_min, lon_max))
    return data_dom

# High-pass the Z500 and T2m data to create an un-skewed distribution of CCAs.
This operation avoids falsely giving the impression that a particular circulation pattern induced warmer conditions (Smoliak et al. 2015).
This matters because, for surface air temperature (SAT), it is problematic to use CCAs drawn from a climate significantly warmer or colder than the climate one is trying to reconstruct.

A published study found that different methods of detrending or high-pass filtering the raw data to remove the anthropogenic signal generally produce similar results, so the choice of detrending method is not critical for
this study (see Lehner, Flavio, Clara Deser, Isla R. Simpson, and Laurent Terray. "Attributing the US Southwest's recent shift into drier conditions." Geophysical Research Letters 45, no. 12 (2018): 6251-6261).

The high-pass procedure and code used here are adapted from https://github.com/russellhz/extreme_heat_CCA (see also: Horowitz, Russell L., Karen A. McKinnon, and Isla R. Simpson. “Circulation and Soil Moisture Contributions to Heatwaves in the United States.” Journal of Climate 35, no. 24 (December 15, 2022): 4431–48. https://doi.org/10.1175/JCLI-D-21-0156.1).

In [5]:
def fourier(period, k, n):
    """Build a Fourier design matrix for an ``n``-step series.

    Column ``2*j`` holds sin and column ``2*j + 1`` holds cos of harmonic ``j + 1`` of
    ``period``, evaluated at steps 1..n. The final column (index ``2*k``) is all ones and
    acts as the regression intercept.

    Returns
    -------
    numpy.ndarray of shape (n, 2*k + 1)
    """
    harmonics = [h / period for h in range(1, k + 1)]
    design = np.ones((n, 2 * k + 1))
    steps = range(1, n + 1)
    for col, freq in enumerate(harmonics):
        design[:, 2 * col] = [np.sin(2 * np.pi * t * freq) for t in steps]
        design[:, 2 * col + 1] = [np.cos(2 * np.pi * t * freq) for t in steps]
    return design

################################################################################

def seasonality_removal_worker(data, I, X):
    """Return ``data`` minus its least-squares fit on the design matrix ``X``.

    ``I`` is the precomputed projection matrix (X^T X)^-1 X^T, so ``I.dot(data)`` gives the
    regression coefficients and ``X.dot(coefs)`` the fitted values.
    """
    coefs = I.dot(data)
    fitted = X.dot(coefs)
    residual = data - fitted
    return residual

################################################################################

def seasonality_removal(data, k):
    """xarray variant: remove a fitted ``k``-harmonic, 365-step seasonal cycle per grid cell.

    ``data`` is expected to have time as its first dimension plus ``lat`` and ``lon``
    dimensions. The grid is stacked into a ``gridcell`` dimension, each cell's series is
    detrended with ``seasonality_removal_worker``, and the result is unstacked again.
    """
    design = fourier(365, k=k, n=len(data))
    # Least-squares projection (X^T X)^-1 X^T, shared by every grid cell.
    projector = np.linalg.inv(design.T.dot(design)).dot(design.T)
    per_cell = data.stack(gridcell=('lat', 'lon')).groupby('gridcell')
    return per_cell.map(seasonality_removal_worker, I=projector, X=design).unstack('gridcell')

################################################################################

def seasonality_removal_vals(data, k, period=365):
    """Remove a fitted seasonal cycle from a daily numpy array, independently per grid cell.

    The seasonal cycle is modelled as ``k`` sine/cosine harmonics of ``period`` steps plus a
    constant (see ``fourier``) and fitted by ordinary least squares along axis 0.

    Parameters
    ----------
    data : numpy.ndarray
        Values with time on axis 0 and any number of trailing (e.g. lat, lon) axes. For the
        harmonics to line up with the calendar, the series should be continuous with
        ``period`` steps per year (e.g. 29 February removed for the default 365).
    k : int
        Number of harmonics.
    period : int or float, optional
        Length of the seasonal cycle in time steps (default 365).

    Returns
    -------
    numpy.ndarray
        Residuals (anomalies) with the same shape as ``data``.
    """
    n = data.shape[0]
    fX = fourier(period, k=k, n=n)
    # OLS projection matrix (X^T X)^-1 X^T, identical for every grid cell.
    B = np.linalg.inv(fX.T.dot(fX)).dot(fX.T)
    # Each column is an independent regression, so all grid cells are fitted in a single
    # matrix product instead of a per-column Python loop (np.apply_along_axis).
    data_2d = data.reshape((n, -1))
    residual_2d = data_2d - fX.dot(B.dot(data_2d))
    return residual_2d.reshape(data.shape)

In [None]:
# Generate the MJJAS (May-September) GPH500 for each dataset and each domain, 1979-2022.
# Yearly slices are collected in a list and concatenated once; concatenating inside the loop
# re-copies the growing array every year (quadratic in the number of years).
for ds in datasets:
    for dom in domains:
        GPH500_years = []
        for i in range(1979,2023):
            GPH500_i = xr.open_dataarray(reanalyses_dir[ds] + 'GPH500_daily_' + ds + '_' + str(i) + '.nc')
            MJJAS = GPH500_i.time.dt.month.isin(range(5,10))
            GPH500_i = GPH500_i.sel(time=MJJAS)
            GPH500_i = sel_domain(GPH500_i,dom)
            GPH500_years.append(GPH500_i)
        GPH500_all = xr.concat(GPH500_years,dim='time')
        if ds == 'era5':
            # Drop the ERA5 'expver' coordinate if present (drop_vars replaces the deprecated drop).
            GPH500_all = GPH500_all.drop_vars('expver', errors='ignore')
        GPH500_all.to_netcdf('data/GPH500_' + ds + '_' + dom + '_' + '1979_2022_MJJAS_raw.nc')

In [None]:
# Generate the MJJAS (May-September) T2m for each dataset and each domain, 1979-2022.
# ERA5 and JRA-55 files follow different naming; the JRA-55 variable is 'var11'.
# Yearly slices are collected and concatenated once to avoid quadratic re-copying.
for ds in datasets:
    for dom in domains:
        T2m_years = []
        for i in range(1979,2023):
            if ds == 'era5':
                T2m_i = xr.open_dataarray(reanalyses_dir[ds] + '2m_temperature_daily_' + ds + '_' + str(i) + '.nc')
            else:
                T2m_i = xr.open_dataset(reanalyses_dir[ds] + 'jra_daily_t2m_' + str(i) + '.nc')['var11']
            MJJAS = T2m_i.time.dt.month.isin(range(5,10))
            T2m_i = T2m_i.sel(time=MJJAS)
            T2m_i = sel_domain(T2m_i,dom)
            T2m_years.append(T2m_i)
        T2m_all = xr.concat(T2m_years,dim='time')
        T2m_all.to_netcdf('data/T2m_' + ds + '_' + dom + '_' + '1979_2022_MJJAS_raw.nc')
  

In [7]:
# Generate the full-year GPH500 for each dataset and each domain, 1979-2022, with 29 February
# removed so every year has 365 days, matching the 365-step Fourier high-pass applied later.
# Yearly slices are collected and concatenated once to avoid quadratic re-copying.
for ds in datasets:
    for dom in domains:
        GPH500_years = []
        for i in range(1979,2023):
            GPH500_i = xr.open_dataarray(reanalyses_dir[ds] + 'GPH500_daily_' + ds + '_' + str(i) + '.nc')
            GPH500_i = GPH500_i.sel(time=~((GPH500_i.time.dt.month == 2) & (GPH500_i.time.dt.day == 29)))
            GPH500_i = sel_domain(GPH500_i,dom)
            GPH500_years.append(GPH500_i)
        GPH500_all = xr.concat(GPH500_years,dim='time')
        if ds == 'era5':
            # Drop the ERA5 'expver' coordinate if present (drop_vars replaces the deprecated drop).
            GPH500_all = GPH500_all.drop_vars('expver', errors='ignore')
        GPH500_all.to_netcdf('data/GPH500_' + ds + '_' + dom + '_' + '1979_2022_raw.nc')
  

In [8]:
# Generate the full-year T2m (not only MJJAS) for each dataset and each domain, 1979-2022,
# with 29 February removed so every year has 365 days, matching the 365-step Fourier
# high-pass applied later. Yearly slices are concatenated once to avoid quadratic re-copying.
for ds in datasets:
    for dom in domains:
        T2m_years = []
        for i in range(1979,2023):
            if ds == 'era5':
                T2m_i = xr.open_dataarray(reanalyses_dir[ds] + '2m_temperature_daily_' + ds + '_' + str(i) + '.nc')
            else:
                T2m_i = xr.open_dataset(reanalyses_dir[ds] + 'jra_daily_t2m_' + str(i) + '.nc')['var11']
            T2m_i = T2m_i.sel(time=~((T2m_i.time.dt.month == 2) & (T2m_i.time.dt.day == 29)))
            T2m_i = sel_domain(T2m_i,dom)
            T2m_years.append(T2m_i)
        T2m_all = xr.concat(T2m_years,dim='time')
        T2m_all.to_netcdf('data/T2m_' + ds + '_' + dom + '_' + '1979_2022_raw.nc')

In [9]:
# High-pass the full-year raw GPH500: remove the fitted seasonal cycle (3 harmonics plus a
# constant) at every grid cell, then keep only May-September.
gph_name = {'era5':'z','jra55':'hgt'}
for ds in datasets:
    gph_var = gph_name[ds]
    for dom in domains:
        GPH500 = xr.open_dataset('data/GPH500_' + ds + '_' + dom + '_' + '1979_2022_raw.nc')
        anom_values = seasonality_removal_vals(GPH500[gph_var].values, k=3)
        GPH500_ano = GPH500.assign(anom = (('time', 'lat', 'lon'), anom_values))
        in_mjjas = GPH500_ano.time.dt.month.isin(range(5,10))
        GPH500_ano = GPH500_ano.anom.sel(time = in_mjjas)
        GPH500_ano.to_netcdf('data/GPH500_anomalies_' + ds + '_' + dom + '_' + '1979_2022_MJJAS.nc')

In [10]:
# High-pass the full-year raw T2m: remove the fitted seasonal cycle (3 harmonics plus a
# constant) at every grid cell, then keep only May-September.
t2m_name = {'era5':'t2m','jra55':'var11'}
for ds in datasets:
    t2m_var = t2m_name[ds]
    for dom in domains:
        T2m = xr.open_dataset('data/T2m_' + ds + '_' + dom + '_' + '1979_2022_raw.nc')
        if ds == 'jra55':
            T2m = T2m.squeeze()  # drop length-1 dimensions carried by the JRA-55 file
        anom_values = seasonality_removal_vals(T2m[t2m_var].values, k=3)
        T2m_ano = T2m.assign(anom = (('time', 'lat', 'lon'), anom_values))
        in_mjjas = T2m_ano.time.dt.month.isin(range(5,10))
        T2m_ano = T2m_ano.anom.sel(time = in_mjjas)
        T2m_ano.to_netcdf('data/T2m_anomalies_' + ds + '_' + dom + '_' + '1979_2022_MJJAS.nc')

In [27]:
# High-pass the 5-day running-mean GPH500 to damp day-to-day synoptic fluctuations.
# The window is trailing (center=False): the value labelled day t averages days t-4..t, and
# min_periods=1 lets the first days of the record average fewer values.
gph_name = {'era5':'z','jra55':'hgt'}
for ds in datasets:
    gph_var = gph_name[ds]
    for dom in domains:
        raw_path = 'data/GPH500_' + ds + '_' + dom + '_' + '1979_2022_raw.nc'
        GPH500 = xr.open_dataset(raw_path).rolling(time=5, center=False, min_periods=1).mean()
        anom_values = seasonality_removal_vals(GPH500[gph_var].values, k=3)
        GPH500_ano = GPH500.assign(anom = (('time', 'lat', 'lon'), anom_values))
        in_mjjas = GPH500_ano.time.dt.month.isin(range(5,10))
        GPH500_ano = GPH500_ano.anom.sel(time = in_mjjas)
        GPH500_ano.to_netcdf('data/GPH500_5day_running_anomalies_' + ds + '_' + dom + '_' + '1979_2022_MJJAS.nc')

In [30]:
# High-pass the 5-day running-mean T2m to damp day-to-day synoptic fluctuations.
# The window is trailing (center=False): the value labelled day t averages days t-4..t.
t2m_name = {'era5':'t2m','jra55':'var11'}
for ds in datasets:
    t2m_var = t2m_name[ds]
    for dom in domains:
        T2m = xr.open_dataset('data/T2m_' + ds + '_' + dom + '_' + '1979_2022_raw.nc')
        if ds == 'jra55':
            T2m = T2m.squeeze()  # drop length-1 dimensions carried by the JRA-55 file
        T2m = T2m.rolling(time=5, center=False, min_periods=1).mean()
        anom_values = seasonality_removal_vals(T2m[t2m_var].values, k=3)
        T2m_ano = T2m.assign(anom = (('time', 'lat', 'lon'), anom_values))
        in_mjjas = T2m_ano.time.dt.month.isin(range(5,10))
        T2m_ano = T2m_ano.anom.sel(time = in_mjjas)
        T2m_ano.to_netcdf('data/T2m_5day_running_anomalies_' + ds + '_' + dom + '_' + '1979_2022_MJJAS.nc')

In [98]:
## Generate ERA5 5-day running GPH500 anomalies over the wider plotting window (10-60N, 70-160E).
# Same pipeline as the domain files: full-year series with 29 February removed, then a 5-day
# trailing mean, then the 365-step Fourier high-pass, then the MJJAS subset. Filtering an
# MJJAS-only concatenation instead would misalign the 365-day harmonics with the calendar,
# and the rolling mean would mix late September into early May of the following year.
ds = 'era5'
gph_name = {'era5':'z','jra55':'hgt'}
GPH500_years = []
for i in range(1979,2023):
    GPH500_i = xr.open_dataarray(reanalyses_dir[ds] + 'GPH500_daily_' + ds + '_' + str(i) + '.nc')
    GPH500_i = GPH500_i.sel(time=~((GPH500_i.time.dt.month == 2) & (GPH500_i.time.dt.day == 29)))
    GPH500_i = GPH500_i.sel(latitude = slice(60,10),longitude = slice(70,160))
    GPH500_years.append(GPH500_i)
GPH500_full = xr.concat(GPH500_years,dim='time')
if ds == 'era5':
    GPH500_full = GPH500_full.drop_vars('expver', errors='ignore')

GPH500_full = GPH500_full.rolling(time=5, center=False, min_periods=1).mean().to_dataset()
GPH500_ano = GPH500_full.assign(anom = (('time', 'latitude', 'longitude'), seasonality_removal_vals(GPH500_full[gph_name[ds]].values, k=3)))
GPH500_ano = GPH500_ano.anom.sel(time = GPH500_ano.time.dt.month.isin(range(5,10))) ## Select MJJAS
GPH500_ano.to_netcdf('data/GPH500_5day_running_anomalies_' + ds + '_' + '1979_2022_MJJAS_lat_10_60_lon_70_160.nc')

# Keep GPH500_all as the MJJAS 5-day running-mean Dataset, as before; a later cell saves it.
GPH500_all = GPH500_full.sel(time = GPH500_full.time.dt.month.isin(range(5,10)))

In [146]:
# NOTE(review): the plotting cell above leaves GPH500_all as the 5-day trailing running mean
# (a Dataset), not the raw field, despite this file name.
GPH500_all.to_netcdf('data/GPH500_{}_1979_2022_MJJAS_lat_10_60_lon_70_160.nc'.format(ds))

In [111]:
## Load ERA5 MJJAS 850-hPa zonal wind over the plotting window (10-60N, 70-160E), 1979-2022.
# Yearly slices are collected and concatenated once to avoid quadratic re-copying.
ds = 'era5'
u850_years = []
for i in range(1979,2023):
    u_i = xr.open_dataarray(reanalyses_dir[ds] + 'u_component_of_wind_daily_' + ds + '_' + str(i) + '.nc')
    u850_i = u_i.sel(level=850)
    MJJAS = u850_i.time.dt.month.isin(range(5,10))
    u850_i = u850_i.sel(time=MJJAS)
    u850_i = u850_i.sel(latitude = slice(60,10),longitude = slice(70,160))
    u850_years.append(u850_i)
u850_all = xr.concat(u850_years,dim='time')
if ds == 'era5':
    # Drop the ERA5 'expver' coordinate if present (drop_vars replaces the deprecated drop).
    u850_all = u850_all.drop_vars('expver', errors='ignore')

In [143]:
# Daily u850 anomalies relative to the 1979-2010 MJJAS daily climatology.
# Every MJJAS season has the same number of days, so reshaping to (year, day, lat, lon) lines
# up calendar days across years; the climatology is the mean over the baseline years.
# Sizes are derived from the data instead of hard-coding 32/44/153/51/91.
clim_years = range(1979, 2011)
n_days, n_lat, n_lon = u850_all.shape
n_years = np.unique(u850_all.time.dt.year.values).size
days_per_year = n_days // n_years
assert days_per_year * n_years == n_days, 'every year must contribute the same number of MJJAS days'
u850_1979_2010 = u850_all.sel(time=u850_all.time.dt.year.isin(clim_years))
u850_1979_2010_shape = u850_1979_2010.values.shape
u850_1979_2010_reshape = u850_1979_2010.values.reshape((len(clim_years), days_per_year, n_lat, n_lon))
u850_1979_2010_reshape_clim_mean = u850_1979_2010_reshape.mean(axis = 0)
u850_ano = u850_all.values.reshape((n_years, days_per_year, n_lat, n_lon)) - u850_1979_2010_reshape_clim_mean
u850_ano = u850_ano.reshape((n_days, n_lat, n_lon))
u850_ano = xr.DataArray(data=u850_ano, coords=[u850_all.time, u850_all.latitude, u850_all.longitude], dims = ['time','lat','lon'])
u850_ano.to_netcdf('data/u850_anomalies_' + ds + '_' + '1979_2022_MJJAS_lat_10_60_lon_70_160.nc')

In [112]:
# Load ERA5 MJJAS 850-hPa meridional wind over the plotting window (10-60N, 70-160E), 1979-2022.
# Yearly slices are collected and concatenated once to avoid quadratic re-copying.
ds = 'era5'
v850_years = []
for i in range(1979,2023):
    v_i = xr.open_dataarray(reanalyses_dir[ds] + 'v_component_of_wind_daily_' + ds + '_' + str(i) + '.nc')
    v850_i = v_i.sel(level=850)
    MJJAS = v850_i.time.dt.month.isin(range(5,10))
    v850_i = v850_i.sel(time=MJJAS)
    v850_i = v850_i.sel(latitude = slice(60,10),longitude = slice(70,160))
    v850_years.append(v850_i)
v850_all = xr.concat(v850_years,dim='time')
if ds == 'era5':
    # Drop the ERA5 'expver' coordinate if present (drop_vars replaces the deprecated drop).
    v850_all = v850_all.drop_vars('expver', errors='ignore')

In [144]:
# Daily v850 anomalies relative to the 1979-2010 MJJAS daily climatology.
# Every MJJAS season has the same number of days, so reshaping to (year, day, lat, lon) lines
# up calendar days across years; the climatology is the mean over the baseline years.
# Sizes are derived from the data instead of hard-coding 32/44/153/51/91.
clim_years = range(1979, 2011)
n_days, n_lat, n_lon = v850_all.shape
n_years = np.unique(v850_all.time.dt.year.values).size
days_per_year = n_days // n_years
assert days_per_year * n_years == n_days, 'every year must contribute the same number of MJJAS days'
v850_1979_2010 = v850_all.sel(time=v850_all.time.dt.year.isin(clim_years))
v850_1979_2010_shape = v850_1979_2010.values.shape
v850_1979_2010_reshape = v850_1979_2010.values.reshape((len(clim_years), days_per_year, n_lat, n_lon))
v850_1979_2010_reshape_clim_mean = v850_1979_2010_reshape.mean(axis = 0)
v850_ano = v850_all.values.reshape((n_years, days_per_year, n_lat, n_lon)) - v850_1979_2010_reshape_clim_mean
v850_ano = v850_ano.reshape((n_days, n_lat, n_lon))
v850_ano = xr.DataArray(data=v850_ano, coords=[v850_all.time, v850_all.latitude, v850_all.longitude], dims = ['time','lat','lon'])
v850_ano.to_netcdf('data/v850_anomalies_' + ds + '_' + '1979_2022_MJJAS_lat_10_60_lon_70_160.nc')

# Construct circulation analogues for each day in JJA (we mainly focus on summer)
Analogues are identified by minimizing the distance between the target day in a specific year and each day within a 61-day window (30 days before and after the same calendar day) in all other years. Four distance functions are applied in this study, namely **Euclidean distance**, **Pearson correlation**, **Spearman correlation**, and **cosine similarity**.

The pool of potential analogues for each target day contains 61 × (44 − 1) = 2623 days.

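i) From the pool <span style="color:red">(**2623 days**)</span>, identify the *Na* closest analogues (<span style="color:red">**20 days**</span>, following Zhuang et al., who found that increasing the number of analogues did not significantly improve performance). In practice, the pool is split into five subsamples, each taking every 5th day, and the *Na* analogues are selected within each subsample. The target-day Z500 anomaly pattern $S_t$ is then regressed onto the analogue Z500 patterns $S_a$ by ordinary least squares:

$$S_t=S_a\beta+e$$
The temperature anomaly (the dynamical component) for the target day is then estimated by applying the same weights to the analogues' T2m anomalies $T_s$:
$$T_{dc}=T_s\beta$$
ii) Repeat *Nr* times (<span style="color:red">**60 times**</span>: 3 domains × 4 distance functions × 5 every-5th-day subsamples), and take the mean of all *Nr* samples as the final dynamical component.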


In [31]:
# CCA analysis for 2022
# For each JJA day of the target year, the analogue pool is every other year's 61-day window
# centred on the same calendar day. The pool is split into five every-5th-day subsamples so
# the retained analogues are not consecutive days of one weather event. In each subsample:
#   1. keep the 20 Z500 anomaly patterns most similar to the target day;
#   2. regress the target pattern on them by OLS;
#   3. apply the same weights to their T2m anomalies to get the dynamically induced T2m.
n_analogues = 20      # Na: analogues kept per subsample
n_subsamples = 5      # every-5th-day subsamples of the pool
half_window = 30      # days either side of the target calendar day (61-day window)
season_len = 153      # days per MJJAS season (31 + 30 + 31 + 31 + 30)
jja_offset = 31       # MJJAS index of 1 June (May has 31 days)
every_other_5day = ['one','two','three','four','five']
target_year = 2022
all_years = list(range(1979,2023))


def _analogue_scores(target, candidates, dist_func):
    """Score each row of ``candidates`` (n, ngrid) against ``target`` (1, ngrid).

    For 'E_Dist' lower is more similar; for the correlations and cosine similarity higher is
    more similar.
    """
    if dist_func == 'E_Dist':
        return distance.cdist(target, candidates, metric='euclidean')[0]
    if dist_func == 'P_Corr':
        return np.corrcoef(np.vstack([target, candidates]))[0, 1:]
    if dist_func == 'S_Corr':
        return spearmanr(np.vstack([target, candidates]), axis=1)[0][0, 1:]
    if dist_func == 'Cosine':
        return cosine_similarity(target, candidates)[0]
    # Fail loudly rather than silently reusing scores from a previous iteration.
    raise ValueError('unknown distance function: ' + str(dist_func))


def _best_analogues(scores, dist_func, n):
    """Indices of the ``n`` most similar candidates, in ``argsort`` order."""
    order = scores.argsort()
    return order[:n] if dist_func == 'E_Dist' else order[-n:]


for ds in datasets:
    for dom in domain_name:
        GPH500_ano_mjjas = xr.open_dataarray('data/GPH500_5day_running_anomalies_' + ds + '_' + dom + '_1979_2022_MJJAS.nc')
        # These files are written as 'T2m_...'; a lowercase 't2m_' path only resolves on
        # case-insensitive filesystems.
        t2m_ano_mjjas = xr.open_dataarray('data/T2m_5day_running_anomalies_' + ds + '_' + dom + '_1979_2022_MJJAS.nc')
        all_times = GPH500_ano_mjjas.time.values
        # Positional window indexing below requires 153 MJJAS days per year from 1979 on, and
        # the T2m rows must line up with the GPH500 rows.
        assert all_times.size == len(all_years) * season_len, 'expected 153 MJJAS days per year, 1979-2022'
        assert np.array_equal(t2m_ano_mjjas.time.values, all_times), 'GPH500 and T2m time axes differ'

        GPH500_ano_mjjas_2022 = GPH500_ano_mjjas.sel(time=GPH500_ano_mjjas.time.dt.year==target_year)
        GPH500_ano_jja_2022 = GPH500_ano_mjjas_2022.sel(time=GPH500_ano_mjjas_2022.time.dt.month.isin(range(6,9))) # focus on JJA
        lon = GPH500_ano_jja_2022.lon
        lat = GPH500_ano_jja_2022.lat
        nlat = len(lat)
        nlon = len(lon)
        n_target = GPH500_ano_jja_2022.sizes['time']   # 92 days in JJA
        assert jja_offset + n_target - 1 + half_window < season_len, 'window would run past 30 September'

        # Flatten the grids once: rows are days, columns are grid cells.
        gph_flat = GPH500_ano_mjjas.values.reshape((all_times.size, nlat*nlon))
        t2m_flat = t2m_ano_mjjas.values.reshape((all_times.size, nlat*nlon))
        target_flat = GPH500_ano_jja_2022.values.reshape((n_target, nlat*nlon))

        # (distance function, subsample, target day, grid cell)
        temp_OLS = np.empty((len(dist_funcs), n_subsamples, n_target, nlat*nlon))

        for i in range(n_target):
            # JJA day i sits at MJJAS index jja_offset + i; take 30 days either side of it in
            # every non-target year.
            centre = jja_offset + i
            pool_idx = np.concatenate([
                np.arange(y*season_len + centre - half_window, y*season_len + centre + half_window + 1)
                for y, year in enumerate(all_years) if year != target_year
            ])
            pool_gph = gph_flat[pool_idx]
            pool_t2m = t2m_flat[pool_idx]
            pool_time = all_times[pool_idx]
            target_i = target_flat[i:i+1]

            for f, dist_func in enumerate(dist_funcs):
                for s, label in enumerate(every_other_5day):
                    sub_gph = pool_gph[s::n_subsamples]
                    sub_t2m = pool_t2m[s::n_subsamples]
                    sub_time = pool_time[s::n_subsamples]

                    scores = _analogue_scores(target_i, sub_gph, dist_func)
                    best = _best_analogues(scores, dist_func, n_analogues)
                    best_gph = sub_gph[best]
                    best_t2m = sub_t2m[best]

                    pattern_da = xr.DataArray(
                        best_gph.reshape(n_analogues, nlat, nlon),
                        coords=[sub_time[best], lat, lon],
                        name='Analogue circulation pattern',
                        dims=['analogue_time','lat','lon'])
                    pattern_da.to_netcdf('data/similarity_pattern_2022/similarity_pattern_' + ds + '_' + dom + '_' + dist_func + '_' + label + '_' + str(i).zfill(2) + '_r.nc')

                    ## dynamic adjustment: OLS weights of the analogues that best reproduce
                    ## the target Z500, applied to the analogues' T2m anomalies.
                    beta = np.linalg.lstsq(best_gph.T, target_i[0], rcond = None)[0]
                    temp_OLS[f, s, i, :] = best_t2m.T @ beta

        temp_OLS_reshape = temp_OLS.reshape(len(dist_funcs), n_subsamples, n_target, nlat, nlon)
        temp_OLS_reshape = xr.DataArray(
            temp_OLS_reshape,coords=[dist_funcs,every_other_5day,GPH500_ano_jja_2022.time.values,lat,lon],
            name='dynamic_t2m',dims=['dist_func','every_other_5day','target_time','lat','lon'])
        temp_OLS_reshape.to_netcdf('data/dynamic_t2m_2022/dynamic_t2m_' + ds + '_' + dom + '_r.nc')


In [10]:
# CCA analysis for 2022
for ds in datasets:    
    for dom in domain_name:
        GPH500_ano_mjjas = xr.open_dataarray('data/GPH500_5day_running_anomalies_' + ds + '_' + dom + '_1979_2022_MJJAS.nc')
        t2m_ano_mjjas = xr.open_dataarray('data/t2m_5day_running_anomalies_' + ds + '_' + dom + '_1979_2022_MJJAS.nc')
        all_times = GPH500_ano_mjjas.time
        all_years = list(range(1979,2023))
        n_year = len(all_years) # each year has 153 days in MJJAS
        target_year = 2022
        GPH500_ano_mjjas_2022 = GPH500_ano_mjjas.sel(time=GPH500_ano_mjjas.time.dt.year==2022)
        GPH500_ano_jja_2022 = GPH500_ano_mjjas_2022.sel(time=GPH500_ano_mjjas_2022.time.dt.month.isin(range(6,9))) # focus on JJA
        t2m_ano_mjjas_2022 = t2m_ano_mjjas.sel(time=t2m_ano_mjjas.time.dt.year==2022)
        t2m_ano_jja_2022 = t2m_ano_mjjas_2022.sel(time=t2m_ano_mjjas_2022.time.dt.month.isin(range(6,9))) # focus on JJA
        lon = GPH500_ano_jja_2022.lon
        lat = GPH500_ano_jja_2022.lat

        nlat = len(lat)
        nlon = len(lon)
        temp_OLS = np.empty((4, 5, 92,nlat*nlon)) # 4 for 4 distance functions, 5 for every other 5 days

        for dist_func in dist_funcs:
            for i in range(92):       # total 92 days in JJA
                if i != 85:
                    continue
                else:
                    pass
                GPH500_ano_jja_2022_i = GPH500_ano_jja_2022[i]
                t2m_ano_jja_2022_i = t2m_ano_jja_2022[i]
                pool_i_time = []
                # for y in range(len(all_years)):
                for y in range(len(all_years)):
                    if all_years[y] == target_year: # skip index for the target year
                        pass
                    else:
                        pool_i_time = pool_i_time + list(all_times[(y*153+i):(61+y*153+i)].values) # 24 and 39 represent 24th May and 7th June when i is for 1st June
                GPH500_ano_pool_i = GPH500_ano_mjjas[GPH500_ano_mjjas.time.isin(pool_i_time)]
                t2m_ano_pool_i = t2m_ano_mjjas[t2m_ano_mjjas.time.isin(pool_i_time)]
                # NOTE(review): 'tws' is not one of dist_funcs, so the else-branch below is never taken.
                if dist_func != 'tws':
                    ndays = GPH500_ano_pool_i.shape[0]
                    nlat  = GPH500_ano_pool_i.shape[1]
                    nlon  = GPH500_ano_pool_i.shape[2]
                    GPH500_ano_pool_i = GPH500_ano_pool_i.values.reshape((ndays,nlat*nlon))
                    GPH500_ano_pool_i_1 = GPH500_ano_pool_i[0::5]  # select every other 5 days to avoid selecting from consecutive days from the same weather event
                    GPH500_ano_pool_i_2 = GPH500_ano_pool_i[1::5]
                    GPH500_ano_pool_i_3 = GPH500_ano_pool_i[2::5]
                    GPH500_ano_pool_i_4 = GPH500_ano_pool_i[3::5]
                    GPH500_ano_pool_i_5 = GPH500_ano_pool_i[4::5]
                    t2m_ano_pool_i = t2m_ano_pool_i.values.reshape((ndays,nlat*nlon))
                    t2m_ano_pool_i_1 = t2m_ano_pool_i[0::5]  # select every other 5 days to avoid selecting from consecutive days from the same weather event
                    t2m_ano_pool_i_2 = t2m_ano_pool_i[1::5]
                    t2m_ano_pool_i_3 = t2m_ano_pool_i[2::5]
                    t2m_ano_pool_i_4 = t2m_ano_pool_i[3::5]
                    t2m_ano_pool_i_5 = t2m_ano_pool_i[4::5]
                    pool_i_time_1 = pool_i_time[0::5]
                    pool_i_time_2 = pool_i_time[1::5]
                    pool_i_time_3 = pool_i_time[2::5]
                    pool_i_time_4 = pool_i_time[3::5]
                    pool_i_time_5 = pool_i_time[4::5]
                    GPH500_ano_jja_2022_i = GPH500_ano_jja_2022_i.values.reshape((1,nlat*nlon))
                    if dist_func == dist_funcs[0]: # Euclidean distance
                        dist_pool_1 = distance.cdist(GPH500_ano_jja_2022_i, GPH500_ano_pool_i_1, metric='euclidean')[0]
                        dist_pool_2 = distance.cdist(GPH500_ano_jja_2022_i, GPH500_ano_pool_i_2, metric='euclidean')[0]
                        dist_pool_3 = distance.cdist(GPH500_ano_jja_2022_i, GPH500_ano_pool_i_3, metric='euclidean')[0]
                        dist_pool_4 = distance.cdist(GPH500_ano_jja_2022_i, GPH500_ano_pool_i_4, metric='euclidean')[0]
                        dist_pool_5 = distance.cdist(GPH500_ano_jja_2022_i, GPH500_ano_pool_i_5, metric='euclidean')[0]
                    elif dist_func == dist_funcs[1]: # Pearson correlation
                        # np.vstack is the non-deprecated equivalent of np.row_stack
                        dist_pool_1 = np.corrcoef(np.vstack([GPH500_ano_jja_2022_i,GPH500_ano_pool_i_1]))[0,1:]
                        dist_pool_2 = np.corrcoef(np.vstack([GPH500_ano_jja_2022_i,GPH500_ano_pool_i_2]))[0,1:]
                        dist_pool_3 = np.corrcoef(np.vstack([GPH500_ano_jja_2022_i,GPH500_ano_pool_i_3]))[0,1:]
                        dist_pool_4 = np.corrcoef(np.vstack([GPH500_ano_jja_2022_i,GPH500_ano_pool_i_4]))[0,1:]
                        dist_pool_5 = np.corrcoef(np.vstack([GPH500_ano_jja_2022_i,GPH500_ano_pool_i_5]))[0,1:]
                    elif dist_func == dist_funcs[2]: # Spearman correlation
                        dist_pool_1 = spearmanr(np.vstack([GPH500_ano_jja_2022_i,GPH500_ano_pool_i_1]),axis=1)[0][0,1:]
                        dist_pool_2 = spearmanr(np.vstack([GPH500_ano_jja_2022_i,GPH500_ano_pool_i_2]),axis=1)[0][0,1:]
                        dist_pool_3 = spearmanr(np.vstack([GPH500_ano_jja_2022_i,GPH500_ano_pool_i_3]),axis=1)[0][0,1:]
                        dist_pool_4 = spearmanr(np.vstack([GPH500_ano_jja_2022_i,GPH500_ano_pool_i_4]),axis=1)[0][0,1:]
                        dist_pool_5 = spearmanr(np.vstack([GPH500_ano_jja_2022_i,GPH500_ano_pool_i_5]),axis=1)[0][0,1:]
                    elif dist_func == dist_funcs[3]: # Cosine similarity
                        dist_pool_1 = cosine_similarity(GPH500_ano_jja_2022_i, GPH500_ano_pool_i_1)[0]
                        dist_pool_2 = cosine_similarity(GPH500_ano_jja_2022_i, GPH500_ano_pool_i_2)[0]
                        dist_pool_3 = cosine_similarity(GPH500_ano_jja_2022_i, GPH500_ano_pool_i_3)[0]
                        dist_pool_4 = cosine_similarity(GPH500_ano_jja_2022_i, GPH500_ano_pool_i_4)[0]
                        dist_pool_5 = cosine_similarity(GPH500_ano_jja_2022_i, GPH500_ano_pool_i_5)[0]
                    # Euclidean distance: the 20 smallest values are the best analogues;
                    # correlations / cosine similarity: the 20 largest values are.
                    if dist_func == dist_funcs[0]:
                        pat_pool_best20_1 = GPH500_ano_pool_i_1[dist_pool_1.argsort()[:20]]
                        pat_pool_best20_2 = GPH500_ano_pool_i_2[dist_pool_2.argsort()[:20]]
                        pat_pool_best20_3 = GPH500_ano_pool_i_3[dist_pool_3.argsort()[:20]]
                        pat_pool_best20_4 = GPH500_ano_pool_i_4[dist_pool_4.argsort()[:20]]
                        pat_pool_best20_5 = GPH500_ano_pool_i_5[dist_pool_5.argsort()[:20]]
                        t2m_pool_best20_1 = t2m_ano_pool_i_1[dist_pool_1.argsort()[:20]]
                        t2m_pool_best20_2 = t2m_ano_pool_i_2[dist_pool_2.argsort()[:20]]
                        t2m_pool_best20_3 = t2m_ano_pool_i_3[dist_pool_3.argsort()[:20]]
                        t2m_pool_best20_4 = t2m_ano_pool_i_4[dist_pool_4.argsort()[:20]]
                        t2m_pool_best20_5 = t2m_ano_pool_i_5[dist_pool_5.argsort()[:20]]
                        pat_pool_best20_time_1 = np.array(pool_i_time_1)[dist_pool_1.argsort()[:20]]
                        pat_pool_best20_time_2 = np.array(pool_i_time_2)[dist_pool_2.argsort()[:20]]
                        pat_pool_best20_time_3 = np.array(pool_i_time_3)[dist_pool_3.argsort()[:20]]
                        pat_pool_best20_time_4 = np.array(pool_i_time_4)[dist_pool_4.argsort()[:20]]
                        pat_pool_best20_time_5 = np.array(pool_i_time_5)[dist_pool_5.argsort()[:20]]
                    else:
                        pat_pool_best20_1 = GPH500_ano_pool_i_1[dist_pool_1.argsort()[-20:]]
                        pat_pool_best20_2 = GPH500_ano_pool_i_2[dist_pool_2.argsort()[-20:]]
                        pat_pool_best20_3 = GPH500_ano_pool_i_3[dist_pool_3.argsort()[-20:]]
                        pat_pool_best20_4 = GPH500_ano_pool_i_4[dist_pool_4.argsort()[-20:]]
                        pat_pool_best20_5 = GPH500_ano_pool_i_5[dist_pool_5.argsort()[-20:]]
                        t2m_pool_best20_1 = t2m_ano_pool_i_1[dist_pool_1.argsort()[-20:]]
                        t2m_pool_best20_2 = t2m_ano_pool_i_2[dist_pool_2.argsort()[-20:]]
                        t2m_pool_best20_3 = t2m_ano_pool_i_3[dist_pool_3.argsort()[-20:]]
                        t2m_pool_best20_4 = t2m_ano_pool_i_4[dist_pool_4.argsort()[-20:]]
                        t2m_pool_best20_5 = t2m_ano_pool_i_5[dist_pool_5.argsort()[-20:]]
                        pat_pool_best20_time_1 = np.array(pool_i_time_1)[dist_pool_1.argsort()[-20:]]
                        pat_pool_best20_time_2 = np.array(pool_i_time_2)[dist_pool_2.argsort()[-20:]]
                        pat_pool_best20_time_3 = np.array(pool_i_time_3)[dist_pool_3.argsort()[-20:]]
                        pat_pool_best20_time_4 = np.array(pool_i_time_4)[dist_pool_4.argsort()[-20:]]
                        pat_pool_best20_time_5 = np.array(pool_i_time_5)[dist_pool_5.argsort()[-20:]]
                else:
                    pass

                ## dynamic adjustment
                X1 = pat_pool_best20_1.transpose()
                X2 = pat_pool_best20_2.transpose()
                X3 = pat_pool_best20_3.transpose()
                X4 = pat_pool_best20_4.transpose()
                X5 = pat_pool_best20_5.transpose()
                y = GPH500_ano_jja_2022_i[0]
                Xt_1 = t2m_pool_best20_1.transpose()
                Xt_2 = t2m_pool_best20_2.transpose()
                Xt_3 = t2m_pool_best20_3.transpose()
                Xt_4 = t2m_pool_best20_4.transpose()
                Xt_5 = t2m_pool_best20_5.transpose()
                ######### OLS #########
                # Construct temp analog: Bk are the least-squares weights that rebuild the target
                # GPH500 pattern from sub-sample k's analogues; they are then applied to the
                # analogues' t2m (Xt_k) to obtain the circulation-induced t2m.
                B1 = np.linalg.lstsq(X1, y, rcond = None)[0]
                B2 = np.linalg.lstsq(X2, y, rcond = None)[0]
                B3 = np.linalg.lstsq(X3, y, rcond = None)[0]
                B4 = np.linalg.lstsq(X4, y, rcond = None)[0]
                B5 = np.linalg.lstsq(X5, y, rcond = None)[0]
                temp_OLS[dist_funcs.index(dist_func),0,i,:] = np.matmul(Xt_1, B1)
                temp_OLS[dist_funcs.index(dist_func),1,i,:] = np.matmul(Xt_2, B2)
                temp_OLS[dist_funcs.index(dist_func),2,i,:] = np.matmul(Xt_3, B3)
                temp_OLS[dist_funcs.index(dist_func),3,i,:] = np.matmul(Xt_4, B4)
                temp_OLS[dist_funcs.index(dist_func),4,i,:] = np.matmul(Xt_5, B5)

                # Save diagnostics for day index 85. Each sub-sample k is reconstructed with its
                # own weights Bk (previously B1 was reused for sub-samples 2-5, so the saved
                # files disagreed with temp_OLS above).
                if i == 85:
                    y_reshape_1 = xr.DataArray(np.matmul(X1, B1).reshape(nlat,nlon),coords=[lat,lon],name='CCA_constructed_pattern',dims=['lat','lon'])
                    y_reshape_2 = xr.DataArray(np.matmul(X2, B2).reshape(nlat,nlon),coords=[lat,lon],name='CCA_constructed_pattern',dims=['lat','lon'])
                    y_reshape_3 = xr.DataArray(np.matmul(X3, B3).reshape(nlat,nlon),coords=[lat,lon],name='CCA_constructed_pattern',dims=['lat','lon'])
                    y_reshape_4 = xr.DataArray(np.matmul(X4, B4).reshape(nlat,nlon),coords=[lat,lon],name='CCA_constructed_pattern',dims=['lat','lon'])
                    y_reshape_5 = xr.DataArray(np.matmul(X5, B5).reshape(nlat,nlon),coords=[lat,lon],name='CCA_constructed_pattern',dims=['lat','lon'])
                    y_original = xr.DataArray(y.reshape(nlat,nlon),coords=[lat,lon],name='original_circulation_pattern',dims=['lat','lon'])
                    temp_OLS_1 = xr.DataArray(np.matmul(Xt_1, B1).reshape(nlat,nlon),coords=[lat,lon],name='CCA_constructed_dynamic_t2m',dims=['lat','lon'])
                    temp_OLS_2 = xr.DataArray(np.matmul(Xt_2, B2).reshape(nlat,nlon),coords=[lat,lon],name='CCA_constructed_dynamic_t2m',dims=['lat','lon'])
                    temp_OLS_3 = xr.DataArray(np.matmul(Xt_3, B3).reshape(nlat,nlon),coords=[lat,lon],name='CCA_constructed_dynamic_t2m',dims=['lat','lon'])
                    temp_OLS_4 = xr.DataArray(np.matmul(Xt_4, B4).reshape(nlat,nlon),coords=[lat,lon],name='CCA_constructed_dynamic_t2m',dims=['lat','lon'])
                    temp_OLS_5 = xr.DataArray(np.matmul(Xt_5, B5).reshape(nlat,nlon),coords=[lat,lon],name='CCA_constructed_dynamic_t2m',dims=['lat','lon'])
                    x_t2m_original = xr.DataArray(t2m_ano_jja_2022_i,coords=[lat,lon],name='original_t2m',dims=['lat','lon'])
                    # NOTE(review): the original_* file names below have no '_' before the day index,
                    # unlike the other files; kept unchanged so downstream readers still find them.
                    y_reshape_1.to_netcdf('data/dynamic_t2m_2022/constructed_pattern_' + ds + '_' + dom + '_' + dist_func + '_one_' + str(i).zfill(2) + '_r.nc')
                    y_reshape_2.to_netcdf('data/dynamic_t2m_2022/constructed_pattern_' + ds + '_' + dom + '_' + dist_func + '_two_' + str(i).zfill(2) + '_r.nc')
                    y_reshape_3.to_netcdf('data/dynamic_t2m_2022/constructed_pattern_' + ds + '_' + dom + '_' + dist_func + '_three_' + str(i).zfill(2) + '_r.nc')
                    y_reshape_4.to_netcdf('data/dynamic_t2m_2022/constructed_pattern_' + ds + '_' + dom + '_' + dist_func + '_four_' + str(i).zfill(2) + '_r.nc')
                    y_reshape_5.to_netcdf('data/dynamic_t2m_2022/constructed_pattern_' + ds + '_' + dom + '_' + dist_func + '_five_' + str(i).zfill(2) + '_r.nc')
                    y_original.to_netcdf('data/dynamic_t2m_2022/original_pattern_' + ds + '_' + dom + '_' + dist_func + str(i).zfill(2) + '_r.nc')
                    temp_OLS_1.to_netcdf('data/dynamic_t2m_2022/dynamic_t2m_' + ds + '_' + dom + '_' + dist_func + '_one_' + str(i).zfill(2) + '_r.nc')
                    temp_OLS_2.to_netcdf('data/dynamic_t2m_2022/dynamic_t2m_' + ds + '_' + dom + '_' + dist_func + '_two_' + str(i).zfill(2) + '_r.nc')
                    temp_OLS_3.to_netcdf('data/dynamic_t2m_2022/dynamic_t2m_' + ds + '_' + dom + '_' + dist_func + '_three_' + str(i).zfill(2) + '_r.nc')
                    temp_OLS_4.to_netcdf('data/dynamic_t2m_2022/dynamic_t2m_' + ds + '_' + dom + '_' + dist_func + '_four_' + str(i).zfill(2) + '_r.nc')
                    temp_OLS_5.to_netcdf('data/dynamic_t2m_2022/dynamic_t2m_' + ds + '_' + dom + '_' + dist_func + '_five_' + str(i).zfill(2) + '_r.nc')
                    x_t2m_original.to_netcdf('data/dynamic_t2m_2022/original_t2m_' + ds + '_' + dom + '_' + dist_func + str(i).zfill(2) + '_r.nc')

                # NOTE(review): despite the `pat_` prefix these wrap the selected analogues' t2m fields.
                pat_pool_best20_da_1 = xr.DataArray(t2m_pool_best20_1.reshape(20,nlat,nlon),coords=[pat_pool_best20_time_1,lat,lon],name='Dynamic adjusted T2m',dims=['analogue_time','lat','lon'])
                pat_pool_best20_da_2 = xr.DataArray(t2m_pool_best20_2.reshape(20,nlat,nlon),coords=[pat_pool_best20_time_2,lat,lon],name='Dynamic adjusted T2m',dims=['analogue_time','lat','lon'])
                pat_pool_best20_da_3 = xr.DataArray(t2m_pool_best20_3.reshape(20,nlat,nlon),coords=[pat_pool_best20_time_3,lat,lon],name='Dynamic adjusted T2m',dims=['analogue_time','lat','lon'])
                pat_pool_best20_da_4 = xr.DataArray(t2m_pool_best20_4.reshape(20,nlat,nlon),coords=[pat_pool_best20_time_4,lat,lon],name='Dynamic adjusted T2m',dims=['analogue_time','lat','lon'])
                pat_pool_best20_da_5 = xr.DataArray(t2m_pool_best20_5.reshape(20,nlat,nlon),coords=[pat_pool_best20_time_5,lat,lon],name='Dynamic adjusted T2m',dims=['analogue_time','lat','lon'])

            print(ds + dom + dist_func)

era5D1E_Dist
era5D1P_Corr
era5D1S_Corr
era5D1Cosine
era5D2E_Dist
era5D2P_Corr
era5D2S_Corr
era5D2Cosine
era5D3E_Dist
era5D3P_Corr
era5D3S_Corr
era5D3Cosine
jra55D1E_Dist
jra55D1P_Corr
jra55D1S_Corr
jra55D1Cosine
jra55D2E_Dist
jra55D2P_Corr
jra55D2S_Corr
jra55D2Cosine
jra55D3E_Dist
jra55D3P_Corr
jra55D3S_Corr
jra55D3Cosine


In [710]:
# Constructed circulation analogues (CCA) for every JJA day of 2022, for one
# dataset / domain / distance function. Produces:
#   temp_OLS_1 ... temp_OLS_5            dynamic t2m, one per interleaved sub-sample of the
#                                        analogue pool, each of shape (92, nlat*nlon)
#   pat_pool_best20_da_1 ... _5          selected analogue GPH500 patterns of the last JJA day
#
# NOTE(review): `ds` and `dom` used to leak from the loop of the preceding cell, whose last
# iteration is jra55 / D3. They are set explicitly so the cell does not depend on execution order.
ds = datasets[-1]
dom = domain_name[-1]
dist_func = dist_funcs[0]   # 'E_Dist' (Euclidean distance)
target_year = 2022
n_best = 20                 # analogues kept per sub-sample
n_sub = 5                   # interleaved sub-samples of the pool (every 5th day)
days_per_year = 153         # MJJAS = 31 + 30 + 31 + 31 + 30 days
window = 61                 # candidate days per pool year: target-31 ... target+29
n_jja = 92                  # days in JJA

# load_dataarray reads the file into memory and closes it; the pool selection below
# indexes these arrays thousands of times.
GPH500_ano_mjjas = xr.load_dataarray('data/GPH500_anomalies_' + ds + '_' + dom + '_1979_2022_MJJAS.nc')
t2m_ano_mjjas = xr.load_dataarray('data/t2m_anomalies_' + ds + '_' + dom + '_1979_2022_MJJAS.nc')
all_times = GPH500_ano_mjjas.time
all_years = list(range(1979, 2023))
# The positional pool slicing below assumes exactly 153 consecutive MJJAS days per year.
assert len(all_times) == len(all_years) * days_per_year, 'expected 153 MJJAS days per year'

GPH500_ano_mjjas_2022 = GPH500_ano_mjjas.sel(time=GPH500_ano_mjjas.time.dt.year == target_year)
GPH500_ano_jja_2022 = GPH500_ano_mjjas_2022.sel(time=GPH500_ano_mjjas_2022.time.dt.month.isin(range(6, 9)))  # focus on JJA
# Previously taken from leftover kernel state; now derived from the t2m file loaded above.
t2m_ano_mjjas_2022 = t2m_ano_mjjas.sel(time=t2m_ano_mjjas.time.dt.year == target_year)
t2m_ano_jja_2022 = t2m_ano_mjjas_2022.sel(time=t2m_ano_mjjas_2022.time.dt.month.isin(range(6, 9)))
lon = GPH500_ano_jja_2022.lon
lat = GPH500_ano_jja_2022.lat
nlat = len(lat)
nlon = len(lon)
npix = nlat * nlon

temp_OLS_all = np.empty((n_sub, n_jja, npix))
for i in range(n_jja):
    GPH500_ano_jja_2022_i = GPH500_ano_jja_2022[i].values.reshape((1, npix))
    t2m_ano_jja_2022_i = t2m_ano_jja_2022[i].values.reshape((1, npix))

    # Candidate analogue dates: JJA day i sits at MJJAS position 31 + i, so the slice
    # [i, i + 61) of each year spans 31 days before to 29 days after that calendar day.
    # The target year itself is left out of its own pool.
    pool_i_time = []
    for iy, year in enumerate(all_years):
        if year != target_year:
            start = iy * days_per_year + i
            pool_i_time.extend(all_times[start:start + window].values)
    GPH500_ano_pool_i = GPH500_ano_mjjas[GPH500_ano_mjjas.time.isin(pool_i_time)].values.reshape((-1, npix))
    t2m_ano_pool_i = t2m_ano_mjjas[t2m_ano_mjjas.time.isin(pool_i_time)].values.reshape((-1, npix))
    assert GPH500_ano_pool_i.shape == t2m_ano_pool_i.shape, 'GPH500 and t2m pools are misaligned'

    pat_pool_best20_da = []
    for k in range(n_sub):
        # Every 5th day, to avoid picking analogues from consecutive days of one weather event.
        gph_sub = GPH500_ano_pool_i[k::n_sub]
        t2m_sub = t2m_ano_pool_i[k::n_sub]
        time_sub = np.array(pool_i_time[k::n_sub])

        if dist_func == 'E_Dist':
            scores = distance.cdist(GPH500_ano_jja_2022_i, gph_sub, metric='euclidean')[0]
        elif dist_func == 'P_Corr':
            scores = np.corrcoef(np.vstack([GPH500_ano_jja_2022_i, gph_sub]))[0, 1:]
        elif dist_func == 'S_Corr':
            scores = spearmanr(np.vstack([GPH500_ano_jja_2022_i, gph_sub]), axis=1)[0][0, 1:]
        elif dist_func == 'Cosine':
            scores = cosine_similarity(GPH500_ano_jja_2022_i, gph_sub)[0]
        else:
            raise ValueError('unknown dist_func: ' + repr(dist_func))
        order = scores.argsort()
        # Euclidean distance: smallest is best; correlations / cosine similarity: largest is best.
        best = order[:n_best] if dist_func == 'E_Dist' else order[-n_best:]

        # Dynamic adjustment (OLS): weights that best rebuild the target GPH500 pattern from
        # the analogues, applied to the analogues' t2m.
        weights = np.linalg.lstsq(gph_sub[best].T, GPH500_ano_jja_2022_i[0], rcond=None)[0]
        temp_OLS_all[k, i, :] = t2m_sub[best].T @ weights

        pat_pool_best20_da.append(xr.DataArray(
            gph_sub[best].reshape(n_best, nlat, nlon), coords=[time_sub[best], lat, lon],
            name='Analogue circulation pattern', dims=['analogue_time', 'lat', 'lon']))

temp_OLS_1, temp_OLS_2, temp_OLS_3, temp_OLS_4, temp_OLS_5 = temp_OLS_all
pat_pool_best20_da_1, pat_pool_best20_da_2, pat_pool_best20_da_3, pat_pool_best20_da_4, pat_pool_best20_da_5 = pat_pool_best20_da
    

In [86]:
# CCA analysis for historical period: for every dataset, domain and target year 1979-2022,
# reconstruct the circulation-induced ("dynamic") JJA t2m from constructed circulation
# analogues, using each of the four similarity measures and five interleaved pool
# sub-samples. One netCDF file is written per dataset/domain.
every_other_5day = ['one', 'two', 'three', 'four', 'five']
N_SUBSAMPLES = len(every_other_5day)   # pool split into every-5th-day sub-samples
N_ANALOGUES = 20                       # best analogues kept per sub-sample
DAYS_PER_YEAR = 153                    # MJJAS = 31 + 30 + 31 + 31 + 30 days
POOL_WINDOW = 61                       # candidate days per pool year: target-31 ... target+29
N_JJA_DAYS = 92                        # days in JJA


def _pool_times(all_times, all_years, target_year, day, days_per_year=DAYS_PER_YEAR, window=POOL_WINDOW):
    """Candidate analogue dates for JJA day ``day`` (0 = 1 June) of ``target_year``.

    JJA day ``day`` sits at MJJAS position 31 + day, so taking ``window`` consecutive
    positions starting at ``day`` in every year except ``target_year`` gives (with the
    default window) 31 days before to 29 days after the same calendar day.
    Assumes ``all_times`` holds exactly ``days_per_year`` consecutive days per year.
    """
    times = []
    for iy, year in enumerate(all_years):
        if year == target_year:   # leave the target year out of its own pool
            continue
        start = iy * days_per_year + day
        times.extend(all_times[start:start + window].values)
    return times


def _similarity(target, pool, dist_func):
    """Score every row of ``pool`` (n, npix) against ``target`` (1, npix).

    Returns a 1-D array of length n. For 'E_Dist' lower means more similar; for
    'P_Corr', 'S_Corr' and 'Cosine' higher means more similar.
    Raises ValueError for any other ``dist_func``.
    """
    if dist_func == 'E_Dist':
        return distance.cdist(target, pool, metric='euclidean')[0]
    if dist_func == 'P_Corr':
        return np.corrcoef(np.vstack([target, pool]))[0, 1:]
    if dist_func == 'S_Corr':
        # axis=1: each row is a variable; row 0 (target) against all pool rows
        return spearmanr(np.vstack([target, pool]), axis=1)[0][0, 1:]
    if dist_func == 'Cosine':
        return cosine_similarity(target, pool)[0]
    raise ValueError('unknown dist_func: ' + repr(dist_func))


def _cca_t2m(target, pool_gph, pool_t2m, dist_func, n_best=N_ANALOGUES):
    """Dynamic t2m (shape (npix,)) for one target GPH500 pattern of shape (1, npix).

    Keeps the ``n_best`` pool patterns most similar to ``target``, solves the least-squares
    weights that rebuild ``target`` from them, and applies those weights to the matching
    rows of ``pool_t2m``.
    """
    order = _similarity(target, pool_gph, dist_func).argsort()
    best = order[:n_best] if dist_func == 'E_Dist' else order[-n_best:]
    weights = np.linalg.lstsq(pool_gph[best].T, target[0], rcond=None)[0]
    return pool_t2m[best].T @ weights


for ds in datasets:
    for dom in domain_name:
        # load_dataarray reads into memory and closes the file; the pools below index these
        # arrays many thousands of times, which is very slow against a lazily opened file.
        GPH500_ano_mjjas = xr.load_dataarray('data/GPH500_5day_running_anomalies_' + ds + '_' + dom + '_1979_2022_MJJAS.nc')
        t2m_ano_mjjas = xr.load_dataarray('data/t2m_5day_running_anomalies_' + ds + '_' + dom + '_1979_2022_MJJAS.nc')
        all_times = GPH500_ano_mjjas.time
        all_years = list(range(1979, 2023))
        assert len(all_times) == len(all_years) * DAYS_PER_YEAR, 'expected 153 MJJAS days per year'

        lon = GPH500_ano_mjjas.lon
        lat = GPH500_ano_mjjas.lat
        nlat = len(lat)
        nlon = len(lon)
        npix = nlat * nlon
        # dims: target year, distance function, pool sub-sample, JJA day, grid point
        temp_OLS = np.empty((len(all_years), len(dist_funcs), N_SUBSAMPLES, N_JJA_DAYS, npix))

        for iy, ty in enumerate(all_years):
            GPH500_ano_mjjas_ty = GPH500_ano_mjjas.sel(time=GPH500_ano_mjjas.time.dt.year == ty)
            GPH500_ano_jja_ty = GPH500_ano_mjjas_ty.sel(time=GPH500_ano_mjjas_ty.time.dt.month.isin(range(6, 9)))  # focus on JJA
            target_jja = GPH500_ano_jja_ty.values.reshape((N_JJA_DAYS, npix))

            for i in range(N_JJA_DAYS):
                # The pool depends only on (target year, day), not on the distance function,
                # so it is built once and shared by all four metrics.
                pool_i_time = _pool_times(all_times, all_years, ty, i)
                gph_pool = GPH500_ano_mjjas[GPH500_ano_mjjas.time.isin(pool_i_time)].values.reshape((-1, npix))
                t2m_pool = t2m_ano_mjjas[t2m_ano_mjjas.time.isin(pool_i_time)].values.reshape((-1, npix))
                assert gph_pool.shape == t2m_pool.shape, 'GPH500 and t2m pools are misaligned'
                target = target_jja[i:i + 1]

                for idf, dist_func in enumerate(dist_funcs):
                    for k in range(N_SUBSAMPLES):
                        temp_OLS[iy, idf, k, i, :] = _cca_t2m(
                            target, gph_pool[k::N_SUBSAMPLES], t2m_pool[k::N_SUBSAMPLES], dist_func)

        temp_OLS_reshape = xr.DataArray(
            temp_OLS.reshape(len(all_years), len(dist_funcs), N_SUBSAMPLES, N_JJA_DAYS, nlat, nlon),
            # NOTE(review): target_time carries the JJA dates of the last target year (2022)
            # as a day-of-season label shared by every year.
            coords=[all_years, dist_funcs, every_other_5day, GPH500_ano_jja_ty.time.values, lat, lon],
            name='dynamic_t2m', dims=['year', 'dist_func', 'every_other_5day', 'target_time', 'lat', 'lon'])
        temp_OLS_reshape.to_netcdf('data/dynamic_t2m_historical/dynamic_t2m_' + ds + '_' + dom + '.nc')