# Lagged Correlation Analysis of Model

In [1]:
import os
import warnings
warnings.filterwarnings("ignore", message="invalid value encountered in true_divide")
warnings.filterwarnings("ignore",message="Unable to decode time axis into full numpy.datetime64 objects, continuing using cftime.datetime objects instead, reason: dates out of range")
warnings.filterwarnings("ignore", message="invalid value encountered in reduce")

import xarray as xr
import numpy as np
import scipy.stats as stats
import pickle as pkl

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib as mpl

import cmocean.cm as cm

In [3]:
## render figures at 120 dpi for the rest of the notebook
mpl.rcParams.update({'figure.dpi': 120})

In [4]:
def open_metric(var, reg, metric, timescale='monthly', ens_type='',
                writedir='/home/bbuchovecky/storage/so_predict_derived/'):
    """Open a derived time-series metric file as an xarray Dataset.

    The file path is ``writedir/<SUBDIR>/<VAR>/<var>_ts_<reg>_<timescale>_<suffix>``:

    * metric 'anom' or 'mean' -> ``CTRL/<VAR>/<var>_ts_<reg>_<timescale>_<metric>.nc``
    * metric 'ppp'            -> ``PPP/<VAR>/<var>_ts_<reg>_<timescale>_[<ens_type>_]ppp.nc``

    Parameters
    ----------
    var : str
        Variable name. Upper-cased for the subdirectory, lower-cased for the filename.
    reg : str
        Region tag embedded in the filename.
    metric : {'anom', 'mean', 'ppp'}
        Which derived metric to open.
    timescale : str, optional
        Timescale tag embedded in the filename (default 'monthly').
    ens_type : str, optional
        Ensemble tag, only used when ``metric == 'ppp'``. Omitted from the filename when empty.
    writedir : str, optional
        Root directory of the derived data.

    Returns
    -------
    xarray.Dataset
        Result of ``xr.open_dataset`` on the constructed path.

    Raises
    ------
    ValueError
        If ``metric`` is not one of 'anom', 'mean' or 'ppp'.
    """
    if metric in ('anom', 'mean'):
        subdir = os.path.join('CTRL', var.upper())
        filename = var.lower()+'_ts_'+reg+'_'+timescale+'_'+metric+'.nc'
    elif metric == 'ppp':
        subdir = os.path.join('PPP', var.upper())
        # e.g. ens_type='x' -> '..._x_ppp.nc'; no tag when ens_type is empty
        ens_tag = ens_type+'_' if ens_type else ''
        filename = var.lower()+'_ts_'+reg+'_'+timescale+'_'+ens_tag+'ppp.nc'
    else:
        # Without this branch subdir/filename would be unbound below.
        raise ValueError("metric must be 'anom', 'mean' or 'ppp', got " + repr(metric))

    return xr.open_dataset(os.path.join(writedir, subdir, filename))

def get_plotting_labels(path='/home/bbuchovecky/storage/so_predict_derived/plotting_dicts.pkl'):
    """Load the label lookup tables used to annotate plots.

    Parameters
    ----------
    path : str, optional
        Pickle file containing a dict with (at least) the keys 'reg_names',
        'var_su_names' and 'abbrv_month_names'.

    Returns
    -------
    tuple
        ``(reg_names, var_su_names, abbrv_month_names)`` taken from the pickled dict.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code, so only point this at trusted files.
    """
    # `with` guarantees the file is closed even if unpickling raises.
    with open(path, 'rb') as f:
        plotting_dicts = pkl.load(f)

    return (plotting_dicts['reg_names'],
            plotting_dicts['var_su_names'],
            plotting_dicts['abbrv_month_names'])

# Autocorrelation

## Test Case
The goal here is to compute, for a single variable, the lagged correlations for every initialization month over a symmetric range of lags (±2 months in the examples below).

### Single init month

In [16]:
years = 10
N = 12 * years

## index series of length N: a[i] == i, so the printed slices show exactly
## which months get paired (0 is Jan, 1 is Feb, ..., 11 is Dec)
a = np.arange(N)
print('a', a, '\n')

## initialization month
init = 2

## max absolute value lag (in months)
lag = 2

for m in range(-lag, lag+1):
    ## whole years skipped for negative lags so the lagged index never falls
    ## before the start of `a`: 0 for m=0, 12 for |m| in 1..12, 24 for 13..24, ...
    trim = 12*((abs(m)-1)//12+1)

    if m < 0:
        tmp_a_init = a[init+trim:N:12]
        ## stop at N+m so the lagged slice is exactly tmp_a_init shifted by m,
        ## whatever N is (N-trim+init only matched when N was a multiple of 12)
        tmp_a_lag = a[init+trim+m:N+m:12]
    else:
        tmp_a_init = a[init:N-m:12]
        tmp_a_lag = a[init+m:N:12]

    if tmp_a_init.size != tmp_a_lag.size:
        print('### lag', m)
        print('### trim', trim)
        print('### tmp_a_init ', tmp_a_init.size)
        print('### tmp_a_lag ', tmp_a_lag.size)
        print('tmp_a_init', tmp_a_init)
        print('tmp_a_lag ', tmp_a_lag, '\n')

    else:
        print('lag', m)
        print('trim', trim)
        print('tmp_a_init', tmp_a_init)
        print('tmp_a_lag ', tmp_a_lag, '\n')

a [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119] 

lag -2
trim 12
tmp_a_init [ 14  26  38  50  62  74  86  98 110]
tmp_a_lag  [ 12  24  36  48  60  72  84  96 108] 

lag -1
trim 12
tmp_a_init [ 14  26  38  50  62  74  86  98 110]
tmp_a_lag  [ 13  25  37  49  61  73  85  97 109] 

lag 0
trim 0
tmp_a_init [  2  14  26  38  50  62  74  86  98 110]
tmp_a_lag  [  2  14  26  38  50  62  74  86  98 110] 

lag 1
trim 12
tmp_a_init [  2  14  26  38  50  62  74  86  98 110]
tmp_a_lag  [  3  15  27  39  51  63  75  87  99 111] 

lag 2
trim 12
tmp_a_init [  2  14  

### All (12) init months

In [5]:
maxlag = 2

## generate a random (seeded, so reproducible) integer time series of length N
years = 10
N = 12 * years
rng = np.random.default_rng(0)
a = rng.integers(0, 100, N)

## rows are different init months, cols are different lags
r_matrix = np.zeros((12, 2*maxlag+1))
p_matrix = np.zeros((12, 2*maxlag+1))

lag_matrix = np.zeros((12, 2*maxlag+1))
init_matrix = np.zeros((12, 2*maxlag+1))

for (it,init) in enumerate(range(0,12)):
    for (ig,lag) in enumerate(range(-maxlag,maxlag+1)):
        ## whole years skipped for negative lags (0 for lag 0, 12 for |lag| in 1..12, ...)
        trim = 12*((abs(lag)-1)//12+1)
        init_matrix[it, ig] = init
        lag_matrix[it, ig] = lag

        if lag < 0:
            tmp_a_init = a[init+trim:N:12]
            ## stop at N+lag so both slices have equal length for any N
            tmp_a_lagged = a[init+trim+lag:N+lag:12]
        else:
            tmp_a_init = a[init:N-lag:12]
            tmp_a_lagged = a[init+lag:N:12]

        r_matrix[it, ig], p_matrix[it, ig] = stats.pearsonr(tmp_a_init, tmp_a_lagged)

## Create function

In [5]:
def compute_lagged_autocorr(a, maxlag):
    """Lagged autocorrelation of a monthly series, stratified by initialization month.

    For each init month ``init`` in 0..11 and each lag in ``[-maxlag, maxlag]``,
    the samples ``a[init], a[init+12], ...`` are Pearson-correlated with the
    samples ``lag`` months later (earlier for negative lags).

    For negative lags the first ``trim = 12*ceil(|lag|/12)`` months are skipped,
    so the lagged index never falls below 0. When ``a.size`` is a multiple of 12,
    this gives every init month the same sample count at a given negative lag.

    Parameters
    ----------
    a : xarray.DataArray or array-like
        1-D monthly time series. Index 0 is treated as init month 0. Whether that
        is January depends on the data (TODO confirm).
    maxlag : int
        Largest absolute lag, in months.

    Returns
    -------
    r_matrix, p_matrix : numpy.ndarray, shape (2*maxlag+1, 12)
        Pearson r and p-value. Row k is lag ``k - maxlag``; column j is init month j.
    init_matrix, lag_matrix : numpy.ndarray, shape (2*maxlag+1, 12)
        The init month and lag that each entry corresponds to (as floats).
    """
    # np.asarray yields the same array as DataArray.values and also accepts plain arrays
    a = np.asarray(a)
    N = a.size

    lags = np.arange(-maxlag, maxlag+1)
    inits = np.arange(12)

    # rows are different lags, cols are different init months
    r_matrix = np.zeros((lags.size, inits.size))
    p_matrix = np.zeros((lags.size, inits.size))
    init_matrix = np.zeros((lags.size, inits.size)) + inits[np.newaxis, :]
    lag_matrix = np.zeros((lags.size, inits.size)) + lags[:, np.newaxis]

    for ig, lag in enumerate(lags):
        # whole years skipped for negative lags: 0 for lag 0, 12 for |lag| in 1..12, ...
        trim = 12*((abs(lag)-1)//12+1)
        for it, init in enumerate(inits):
            if lag < 0:
                tmp_a_init = a[init+trim:N:12]
                # Stop at N+lag so the lagged slice is tmp_a_init shifted by lag.
                # The previous stop (N-trim+init) only gave equal lengths when
                # N was a multiple of 12; otherwise pearsonr raised.
                tmp_a_lagged = a[init+trim+lag:N+lag:12]
            else:
                tmp_a_init = a[init:N-lag:12]
                tmp_a_lagged = a[init+lag:N:12]

            r_matrix[ig, it], p_matrix[ig, it] = stats.pearsonr(tmp_a_init, tmp_a_lagged)

    return r_matrix, p_matrix, init_matrix, lag_matrix

In [6]:
## Cells with |r| above a fixed `threshold` (default 0.75) are hatched.
## NOTE(review): the threshold is a plain parameter, not computed from a t-test here,
## and the p-values from compute_lagged_autocorr are unused. Confirm how 0.75 was chosen.
def plot_single_lagged_autocorr(variable, region, maxlag=24, threshold=0.75, hatch='//', figsize=(6,7)):
    """Plot the lagged-autocorrelation heatmap of one variable in one region.

    Parameters
    ----------
    variable : str
        Variable passed to ``open_metric`` and used as the key into the
        variable-label lookup.
    region : str
        Name of the data variable selected from the 'anom' file opened with
        region tag 'so'. Also the key into the region-label lookup.
    maxlag : int, optional
        Largest absolute lag in months (default 24).
    threshold : float, optional
        |r| above which cells are hatched. Also marked on the colorbar.
    hatch : str or False, optional
        Matplotlib hatch pattern, or a falsy value to disable hatching.
    figsize : tuple, optional
        Figure size in inches.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    # Load the series into memory, then close the underlying file.
    with open_metric(variable, 'so', 'anom') as ds:
        ts = ds[region].load()

    r = compute_lagged_autocorr(ts, maxlag=maxlag)[0]

    reg_names, var_su_names, abbrv_month_names = get_plotting_labels()

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.pcolormesh(r, vmin=-1.0, vmax=1.0, cmap=cm.balance, edgecolor='k', linewidth=0.01)

    # Row k of r is lag k - maxlag. Label every 4th row with its lag; the labels
    # were previously hard-coded to -24..24 whatever maxlag was.
    ax.set_yticks((np.arange(r.shape[0]) + 0.5)[::4])
    ax.set_yticklabels(np.arange(-maxlag, maxlag+1)[::4])
    ax.set_ylabel('LAG')

    ax.set_xticks(np.arange(r.shape[1]) + 0.5)
    ax.set_xticklabels(abbrv_month_names)
    ax.set_xlabel('INIT MONTH')

    ax.set_title(reg_names[region]+' '+var_su_names[variable]+' autocorrelation')

    cb = fig.colorbar(im, ax=ax, label='Correlation coefficient')

    if hatch:
        # Mask every cell except |r| > threshold and overlay a hatched pcolor.
        # NOTE(review): alpha=0 is meant to leave only the hatch visible. Whether
        # hatches survive alpha=0 depends on the matplotlib version.
        masked = np.ma.masked_array(r, mask=np.where(abs(r) > threshold, 0, 1))
        ax.pcolor(masked, hatch=hatch, alpha=0)

        # Mirror the hatching on the colorbar: lines at +/-threshold and hatched tails.
        cb.ax.plot([-1,1], [threshold,threshold], color='k')
        cb.ax.plot([-1,1], [-threshold,-threshold], color='k')
        cb.ax.fill_between([-1,1], [1,1], [threshold,threshold], hatch=hatch, alpha=0)
        cb.ax.fill_between([-1,1], [-threshold,-threshold], [-1,-1], hatch=hatch, alpha=0)

    return fig, ax

In [16]:
# plot_single_lagged_autocorr('mld', 'SouthernOcean', hatch=False)

In [78]:
def plot_regional_lagged_autocorr(variable, maxlag=24, threshold=0.75, hatch='//', figsize=(8,7)):
    """Plot lagged-autocorrelation heatmaps of one variable for six regions on a 2x3 grid.

    Each panel shows Pearson r from ``compute_lagged_autocorr`` (rows = lags,
    columns = init months) for one region variable of the 'anom' file opened
    with region tag 'so'.

    Parameters
    ----------
    variable : str
        Variable passed to ``open_metric`` and used as the key into the
        variable-label lookup.
    maxlag : int, optional
        Largest absolute lag in months (default 24).
    threshold : float, optional
        |r| above which cells are hatched. Also marked on the shared colorbar.
    hatch : str or False, optional
        Matplotlib hatch pattern, or a falsy value to disable hatching.
    figsize : tuple, optional
        Figure size in inches.

    Returns
    -------
    fig, ax : matplotlib Figure and 2x3 array of Axes
    """
    reg_names, var_su_names, abbrv_month_names = get_plotting_labels()
    region_list = ['Weddell', 'Indian', 'WestPacific', 'SouthernOcean', 'Ross', 'AmundBell']

    # Every region is a variable of the same file, so open it once (instead of
    # once per region), load the series and close it.
    with open_metric(variable, 'so', 'anom') as ds:
        series = {reg: ds[reg].load() for reg in region_list}

    # Row k of each heatmap is lag k - maxlag; label every 4th row.
    lag_labels = np.arange(-maxlag, maxlag+1)[::4]

    fig, ax = plt.subplots(2, 3, figsize=figsize)

    for ireg, reg in enumerate(region_list):
        row, col = divmod(ireg, 3)
        panel = ax[row, col]

        r = compute_lagged_autocorr(series[reg], maxlag=maxlag)[0]

        im = panel.pcolormesh(r, vmin=-1.0, vmax=1.0, cmap=cm.balance, edgecolor='k', linewidth=0.01)

        if hatch:
            # Hatch only cells with |r| > threshold (all other cells are masked).
            masked = np.ma.masked_array(r, mask=np.where(abs(r) > threshold, 0, 1))
            panel.pcolor(masked, hatch=hatch, alpha=0)

        panel.set_yticks((np.arange(r.shape[0]) + 0.5)[::4])
        panel.set_yticklabels(lag_labels)
        panel.set_xticks(np.arange(r.shape[1]) + 0.5)
        panel.set_xticklabels(abbrv_month_names, rotation=45, ha="right", rotation_mode="anchor")
        panel.set_title(reg_names[reg], fontweight='bold', fontsize='large')

        # axis labels only on the outer panels
        if col == 0:
            panel.set_ylabel('LAG')
        if row == 1:
            panel.set_xlabel('INIT MONTH')

    # ha='left' so a title anchored at x=0 is not cut in half (suptitle centres by default)
    fig.suptitle(var_su_names[variable], x=0, ha='left', fontweight='bold', fontsize='x-large')
    fig.tight_layout()

    # All panels share vmin/vmax, so the last mappable serves the shared colorbar.
    cb = fig.colorbar(im, ax=ax.ravel().tolist(), label='correlation coefficient')

    if hatch:
        cb.ax.plot([-1,1], [threshold,threshold], color='k')
        cb.ax.plot([-1,1], [-threshold,-threshold], color='k')
        cb.ax.fill_between([-1,1], [1,1], [threshold,threshold], hatch=hatch, alpha=0)
        cb.ax.fill_between([-1,1], [-threshold,-threshold], [-1,-1], hatch=hatch, alpha=0)

    return fig, ax

In [90]:
# plot_regional_lagged_autocorr('sfc_chl', hatch=False, figsize=(14,12));