In [1]:
# (C) Copyright 1996- ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

In [2]:
import xarray as xr
import xskillscore as xs
import pandas as pd
import numpy as np

from statsmodels.api import OLS, add_constant
from sklearn.feature_selection import f_regression#, mutual_info_regression

from itertools import product, combinations, groupby
import multiprocessing # parallel processing
import tqdm # timing
import warnings
warnings.filterwarnings("ignore")

import tigramite
from tigramite import data_processing as pp
from tigramite import plotting as tp
from tigramite.pcmci import PCMCI
from tigramite.independence_tests import ParCorr#, GPDC, CMIknn

import matplotlib.pyplot as plt
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Normalize
import seaborn as sns

In [3]:
import os # needed only in this cell, for creating the output folder

dir_loc = '' # prefix of the input file paths (empty string: paths are relative to the working directory)
output_loc = dir_loc+'Deliverable/' # folder that receives all figures and NetCDF outputs
# create the folder from Python instead of the shell (`!mkdir`): this works on every platform and
# does not complain when the folder already exists, e.g. when the notebook is re-run
os.makedirs(output_loc, exist_ok=True)
temp_resolution = 'd'
temp_aggr = 5 # combined with temp_resolution into the resampling frequency string (here '5d')
months_analysed = [9, 10, 11, 12, 1, 2] # calendar months kept in the analyses
max_shift = 5 # for Correlation: negative values refer to pattern's forecasting, positive to forecasting indices
t_min, t_max = 1, 5 # for Tigramite: positive values refer to forecasting of patterns
ols_min, ols_max = -2, 6 # for OLS: positive values relate to forecasting of patterns
pat_names_short = ['AtlL', 'BscL', 'IbrL', 'SclL', 'BlkL', 'BlSL', 'MedH', 'MnrL', 'MnrH'] # short names 
pat_colors = [(.02, 0.40, 0.55), (.0, 0.75, 0.7), (.0, 0.63, 0.4), (.85, 0.35, 0.1), (.8, 0.5, 0.75), 
              (.7, 0.5, 0.45), (1, 0.7, 0.75), (.65, 0.65, 0.65), (.95, 0.95, 0.2)]
color_palette = {i: j for i, j in zip(pat_names_short, pat_colors)} # one RGB colour per pattern short name

original_plotting = sns.plotting_context() # original plotting content

In [4]:
# read the two input timeseries (Mediterranean patterns and atmospheric indices) from dir_loc
ts_patterns, ts_indices = (xr.open_dataarray(dir_loc+file_name)
                           for file_name in ('MedIndices.nc', 'AtmIndices.nc'))

In [5]:
# relabel the 'cluster' coordinate: negative ids are kept as they are, all other ids are used as
# positions in pat_names_short and replaced by the corresponding short name
ts_patterns = ts_patterns.assign_coords(
    {'cluster': [c_id if c_id<0 else pat_names_short[c_id] for c_id in ts_patterns.cluster.values]}
)

In [6]:
# Main indices used for the analysis
indices_used_feat_sel = [
    'NAO', 'AO', 'MOI1', 'WeMO', 'PrecipSahel', 'ArcticEurope',
    *(f'SST_{k}' for k in range(1, 7)),
    *(f'SML1_{k}' for k in range(1, 7)),
    'MJO_RMM1', 'MJO_RMM2',
]

# new names for the plots: 'MOI1' becomes 'MO', 'SML1_<k>' becomes 'SML_<k>', all others stay unchanged
indices_used_final_names = {
    name: 'MO' if name=='MOI1' else ('SML'+name[-2:] if 'SML1' in name else name)
    for name in indices_used_feat_sel
}

# Main combinations used for the analysis
main_type_used = 'InterAnnRemov'
main_indicator_used = 'projection_norm'

In [7]:
# time steps present in both the pattern and the index timeseries, in ascending order
# (ts_patterns and atm_var_data (later on) have same dates, as ts_patterns is derived from that data)
common_dates = sorted(set(ts_patterns.time.values).intersection(ts_indices.time.values))

In [8]:
# restrict both datasets to the common dates, then average them over consecutive windows of
# temp_aggr x temp_resolution (the temporal resolution of interest)
ts_patterns_final, ts_indices_final = (
    ts_raw.sel(time=sorted(common_dates)).resample(time=f'{temp_aggr}{temp_resolution}').mean('time')
    for ts_raw in (ts_patterns, ts_indices)
)

In [9]:
# Cross-correlation between the pattern timeseries, for the main Type/Indicator and the analysed months
sns.set_context('notebook', font_scale=1)
cross_cor = ts_patterns_final.sel(Type=main_type_used, Indicator=main_indicator_used, cluster=pat_names_short)
cross_cor = cross_cor.isel(time=cross_cor.time.dt.month.isin(months_analysed))
cross_cor = cross_cor.to_dataframe('S').pivot_table(index='time', columns='cluster', values='S')
cross_cor = cross_cor[pat_names_short] # keep the actual order (before it goes alphabetically)
cross_cor = cross_cor.corr()
# hide the upper triangle and the diagonal with an explicit boolean mask; building the mask from the
# correlation values themselves would leave a cell visible whenever a coefficient is exactly 0
mask_used = np.triu(np.ones(cross_cor.shape, dtype=bool))
ax = sns.heatmap(cross_cor, cmap='RdBu_r', vmin=-1, vmax=1, annot=True, fmt='.2f', mask=mask_used)
ax.set_ylabel('')
ax.set_xlabel('')
plt.savefig(f'{output_loc}CrossCorr{temp_aggr}{temp_resolution}.png', dpi=600)
del(cross_cor, mask_used, ax)

In [10]:
def lagged_correlation(input_data):
    """Correlate the pattern timeseries with a time-shifted copy of another dataset.

    Parameters
    ----------
    input_data : tuple
        (used_data, i_shift): the xarray object to shift along 'time' and the number of steps to shift it.

    Returns
    -------
    Dataset with one variable, 'correlation' (correlation along 'time' between ts_patterns_final and the
    shifted data, for the months in months_analysed), and a scalar coordinate 'lag' equal to -i_shift.
    """
    data_in, n_steps = input_data

    # shifting by position is valid because the data are on continuous time, without gaps (besides the
    # common aggregation); NaNs are dropped, otherwise dask has memory issues
    shifted = data_in.shift(time=n_steps).dropna('time')
    in_season = shifted.time.dt.month.isin(months_analysed)
    shifted = shifted.isel(time=in_season)

    # patterns at exactly the time steps that survived the shift and the month selection
    patterns_matched = ts_patterns_final.sel(time=shifted.time.values)

    corr_da = xr.corr(patterns_matched, shifted, dim='time')
    corr_da.name = 'correlation'

    # shift moves values to the next time steps, so the lag has the opposite sign: e.g. for shift 1 the
    # shifted data carry the information of 1 time step earlier, thus the lag is -1
    return xr.merge([corr_da]).assign_coords({'lag': -n_steps})

In [11]:
# lagged correlations between patterns and indices for every shift in [-max_shift, max_shift]
all_lags = np.arange(-max_shift, max_shift+1).tolist()
combs = list(product([ts_indices_final], all_lags))
# the context manager shuts the worker pool down also when one of the tasks raises, so no worker
# processes are left behind; all results are collected inside the block, before the pool is torn down
with multiprocessing.Pool() as pool: # object for multiprocessing
    correlations_indices = list(tqdm.tqdm(pool.imap(lagged_correlation, combs), total=len(combs), position=0))
correlations_indices = xr.concat(correlations_indices, dim='lag').sortby('lag')
correlations_indices.to_netcdf(f'{output_loc}CorrelationsIndices_{temp_aggr}{temp_resolution}.nc')

del(combs, pool)

In [12]:
# lagged correlation of every selected index with every pattern, one panel per index
plot_ind = correlations_indices['correlation'].sel(index=indices_used_feat_sel)
plot_ind = plot_ind.assign_coords({'index': [indices_used_final_names[i] for i in plot_ind.index.values]})
# select through the configured main Type/Indicator rather than repeating their literal values, so
# that this figure follows the configuration like the rest of the notebook
plot_ind = plot_ind.sel(cluster=pat_names_short).sel(Type=main_type_used, Indicator=main_indicator_used)
plot_ind = plot_ind.to_dataframe('Corr').reset_index().rename(columns={'cluster': 'Med. Pat.'})

sns.set_context('paper', font_scale=2)
plot_ind = sns.relplot(data=plot_ind, hue='Med. Pat.', x='lag', y='Corr', kind='line',
                       col='index', col_wrap=4, linewidth=3, palette=color_palette)
for i_ax in plot_ind.axes.flatten():
    # reference lines at lag 0 and correlation 0, and one tick per lag
    i_ax.axvline(0, color='grey', linestyle='--')
    i_ax.axhline(0, color='grey', linestyle='--')
    i_ax.set_xticks(np.arange(-max_shift, max_shift+1))
    
plot_ind.fig.savefig(f'{output_loc}CorrelationsIndicesPlot_{temp_aggr}{temp_resolution}.png', dpi=600)

del(plot_ind, i_ax)

Causal connections (tigramite)

In [13]:
# wide tables (rows: time) used as input of the causal analysis
# patterns: one column per cluster, for the main Type/Indicator
ts_pat_df = (
    ts_patterns_final
    .sel(Type=main_type_used, Indicator=main_indicator_used)
    .to_dataframe('Value')
    .pivot_table(index='time', columns='cluster', values='Value')
)

# indices: one column per selected index, for the main Type, relabelled with the names used in the plots
ts_ind_df = (
    ts_indices_final
    .sel(Type=main_type_used)
    .sel(index=indices_used_feat_sel)
    .to_dataframe('type_used')
    .pivot_table(index='time', columns='index', values='type_used')
)
ts_ind_df.columns = [indices_used_final_names[col] for col in ts_ind_df.columns]

In [14]:
def causality_tigramite(pat_used):
    """Run PCMCI (partial-correlation test) on one pattern together with all selected indices.

    Parameters
    ----------
    pat_used : column label of ts_pat_df with the pattern to analyse.

    Returns
    -------
    The dictionary returned by PCMCI.run_pcmci, with an extra key 'var_names' holding the column names
    of the analysed dataset (the pattern is the first variable, followed by the indices).
    """
    ts_final = pd.concat([ts_pat_df[pat_used], ts_ind_df], axis=1) # pattern first, then all indices
    n_vars = len(ts_final.columns)

    # mask value 1 for the time steps outside the selected months, 0 inside; same mask for every variable
    out_of_season = ~(ts_final.index.month.isin(months_analysed))*1
    data_mask = np.repeat(out_of_season[:, np.newaxis], n_vars, axis=1)

    dataframe = pp.DataFrame(ts_final.values, var_names=ts_final.columns, mask=data_mask) # tigramite object

    # alternative tests that were tried: GPDC, CMIknn (see the commented-out import)
    used_test = ParCorr(mask_type='y') # mask y

    # candidate links: every variable at every lag in [t_min, t_max], towards every variable
    # (restricting the candidates to the parents of the pattern only would mean an empty list for j != 0)
    lagged_links = [(var, -lag) for var in range(n_vars) for lag in range(t_min, t_max + 1)]
    selected_links = {j: list(lagged_links) for j in range(n_vars)}

    pcmci_cond_test = PCMCI(dataframe=dataframe, cond_ind_test=used_test, verbosity=0) # create the final object

    # get the causal relations
    results = pcmci_cond_test.run_pcmci(tau_min=t_min, tau_max=t_max, pc_alpha=None,
                                        alpha_level = 0.01, selected_links=selected_links)

    results['var_names'] = ts_final.columns

    return results

In [15]:
# causal analysis for every pattern, in parallel; results are stored per pattern name
all_pat = ts_patterns.cluster.values 
# the context manager shuts the worker pool down also when one of the tasks raises, so no worker
# processes are left behind; all results are collected inside the block, before the pool is torn down
with multiprocessing.Pool() as pool: # object for multiprocessing
    caus_tigr = list(tqdm.tqdm(pool.imap(causality_tigramite, all_pat), total=len(all_pat), position=0))
caus_tigr = {i:j for i, j in zip(all_pat, caus_tigr)}

del(all_pat, pool)

In [16]:
# get max value of "autocorrelations" and "crosscorrelations" (used as limits of the colour scales)

# copies of the val_matrix of every pattern with the diagonal (variable with itself) zeroed at every lag
caus_tigr_non_diag = {}
for pat_name, pat_res in caus_tigr.items():
    cross_vals = pat_res['val_matrix'].copy()
    diag_idx = np.arange(min(cross_vals.shape[:2]))
    cross_vals[diag_idx, diag_idx, :] = 0
    caus_tigr_non_diag[pat_name] = cross_vals

# largest absolute off-diagonal value over the named patterns, rounded up to the next multiple of 0.05
max_cross = np.ceil(max(np.abs(caus_tigr_non_diag[p]).max() for p in pat_names_short)*20)/20

# same for the diagonal values, but never lower than 0.5
max_auto = np.ceil(max(np.max(np.abs(np.diagonal(caus_tigr[p]['val_matrix']))) for p in pat_names_short)*20)/20
max_auto = max(0.5, max_auto)

del(pat_name, pat_res, cross_vals, diag_idx)

In [17]:
# One figure with the causal graph of every pattern (3x3 panels) and two shared colorbars
sns.set_context('paper', font_scale=2)

fig, axes = plt.subplots(3, 3, figsize=(18, 18))
axes = axes.flatten()

# shared colorbars at the bottom: node colours (auto-MCI) on the left, link colours (cross-MCI) on the right
c_map_ax1 = fig.add_axes([0.15, 0.05, 0.3, 0.02])
ticks_auto = np.linspace(0, max_auto, 5)
cbar1 = ColorbarBase(c_map_ax1, orientation='horizontal', cmap='OrRd', ticks=ticks_auto, 
                     norm=Normalize(vmin=0, vmax=max_auto), label='auto-MCI')
c_map_ax2 = fig.add_axes([0.55, 0.05, 0.3, 0.02])
ticks_cross = [-max_cross*2/2.5, -max_cross/2.5, 0, max_cross/2.5, max_cross*2/2.5]
cbar2 = ColorbarBase(c_map_ax2, orientation='horizontal', cmap='RdBu_r', ticks=ticks_cross, 
                     norm=Normalize(vmin=-max_cross, vmax=max_cross), label='cross-MCI')

for i_c, i in enumerate(pat_names_short):
    
    # keep only the variables that have at least one link towards variable 0 (the Mediterranean pattern
    # of interest, which is the first column of the analysed dataset) at any lag. Testing for "any
    # non-empty entry" does not depend on how many lags the graph array holds, whereas comparing the
    # count of empty entries against t_max-t_min+2 is only right when t_min is 1
    kep_ind = (caus_tigr[i]['graph'][:, 0, :] != '').any(axis=1)
    
    # remove all cross links that are not from the indicators to the pattern: for every variable j
    # other than the pattern, erase the links pointing to j from all other variables
    matrix_used = caus_tigr[i]['val_matrix'].copy()
    graph_used = caus_tigr[i]['graph'].copy()
    for j in range(1, graph_used.shape[0]):
        graph_used[np.delete(np.arange(graph_used.shape[0]), j), j, :] = ''
    
    caus_plot = tp.plot_graph(
        val_matrix=matrix_used[kep_ind, :][:, kep_ind],
        graph=graph_used[kep_ind, :][:, kep_ind],
        var_names=caus_tigr[i]['var_names'][kep_ind],
        cmap_edges='RdBu_r', vmin_edges=-max_cross, vmax_edges=max_cross, edge_ticks=max_cross*2/5,
        cmap_nodes='OrRd', vmin_nodes=0, vmax_nodes=max_auto, node_ticks=max_auto/5,
        node_label_size=20,# int, optional (default: 10)
        link_label_fontsize=20,# : int, optional (default: 6)    
        fig_ax = (fig, axes[i_c]), show_colorbar=False,
        )
    
plt.subplots_adjust(left=0, bottom=0.1, right=1, top=.97, wspace=0.1, hspace=0.1) 
fig.savefig(f'{output_loc}Causality_{temp_aggr}{temp_resolution}.png', dpi=600)
del(fig, axes, c_map_ax1, ticks_auto, cbar1, c_map_ax2, ticks_cross, cbar2, i_c, i, kep_ind, matrix_used,
    graph_used, j, caus_plot)

In [18]:
# One separate figure (with its own colorbars) with the causal graph of every pattern
sns.set_context('paper', font_scale=1)

for i in pat_names_short[:]:
    # keep only the variables that have at least one link towards variable 0 (the Mediterranean pattern
    # of interest, which is the first column of the analysed dataset) at any lag. Testing for "any
    # non-empty entry" does not depend on how many lags the graph array holds, whereas comparing the
    # count of empty entries against t_max-t_min+2 is only right when t_min is 1
    kep_ind = (caus_tigr[i]['graph'][:, 0, :] != '').any(axis=1)
    
    # remove all cross links that are not from the indicators to the pattern: for every variable j
    # other than the pattern, erase the links pointing to j from all other variables
    matrix_used = caus_tigr[i]['val_matrix'].copy()
    graph_used = caus_tigr[i]['graph'].copy()
    for j in range(1, graph_used.shape[0]):
        graph_used[np.delete(np.arange(graph_used.shape[0]), j), j, :] = ''
    
    caus_plot = tp.plot_graph(
        val_matrix=matrix_used[kep_ind, :][:, kep_ind],
        graph=graph_used[kep_ind, :][:, kep_ind],
        var_names=caus_tigr[i]['var_names'][kep_ind],
        link_colorbar_label='cross-MCI', vmin_edges=-max_cross, vmax_edges=max_cross, edge_ticks=max_cross*2/5,
        node_colorbar_label='auto-MCI', vmin_nodes=0, vmax_nodes=max_auto, node_ticks=max_auto/5,
        figsize=(4, 4),
        save_name=f'{output_loc}Causality_{i}_{temp_aggr}{temp_resolution}.png'
        )
    
    # close every figure except the one of the last pattern
    if i!=pat_names_short[-1]: plt.close()
        
del(i, kep_ind, matrix_used, graph_used, j, caus_plot, max_cross, max_auto)

Multilinear with Feature Selection (sklearn)

In [19]:
def ols_incl_pattern(input_data):
    """Fit OLS models of one pattern on every combination of its best-ranked (shifted) predictors.

    The candidate predictors are the selected indices plus the pattern timeseries itself ('Pat'), all
    shifted by shift_used steps. They are ranked with the F-statistic of f_regression, a subset of them is
    kept, and one OLS model (with constant) is fitted for each non-empty combination of that subset.

    Parameters
    ----------
    input_data : tuple
        (type_used, ind_used, shift_used, cl_used): labels of 'Type', 'Indicator' and 'cluster' to select,
        and the number of time steps by which the predictors are shifted.

    Returns
    -------
    DataArray with dims (lag, Type, Indicator, cluster, combination, ols_param); the first four have
    length 1 and lag equals -shift_used. 'ols_param' holds the fitted coefficients (NaN for predictors not
    in the model), 'corr', 'corr_adj', 'n_features' and 'pattern_used'.
    """
    type_used, ind_used, shift_used, cl_used = input_data

    # target: the pattern timeseries, stored in column 'Pat'
    target = ts_patterns_final.sel(Type=type_used, Indicator=ind_used, cluster=cl_used).to_dataframe('Pat')

    # predictors: the selected indices in wide format, with the names used in the plots
    indices_da = ts_indices_final.sel(Type=type_used, index=indices_used_feat_sel)
    plot_names = [indices_used_final_names[name] for name in indices_da.index.values]
    predictors = (
        indices_da
        .assign_coords({'index': plot_names})
        .to_dataframe('type_used')
        .pivot_table(index='time', columns='index', values='type_used')
    )
    predictors = pd.concat([target[['Pat']], predictors], axis=1) # use also the actual pattern timeseries

    # shift the predictors, drop incomplete rows and keep the analysed months; then align the target
    predictors = predictors.shift(shift_used).dropna()
    predictors = predictors[predictors.index.month.isin(months_analysed)]
    target = target.loc[predictors.index]

    # feature-selection score per predictor, scaled so that the best one has score 1
    # (alternative that was tried: mutual_info_regression)
    f_scores = np.abs(f_regression(predictors, target[['Pat']])[0]) # absolute, as sometimes negative values occur
    f_scores /= np.max(f_scores)

    ranking = pd.DataFrame({'Feat_sel_test': f_scores}, index=predictors.columns)
    ranking = ranking.sort_values('Feat_sel_test', ascending=False)

    # the 7 best predictors, or all those with score above 0.05 when these are more than 7,
    # and never more than 10 (in case too many have Feat_sel_test>0.05)
    above_thr = ranking.query('Feat_sel_test>0.05').index
    final_ind_used = above_thr if len(above_thr)>7 else ranking.index[:7]
    final_ind_used = final_ind_used[:10]

    # every non-empty combination of the kept predictors, ordered by increasing size
    all_combs = []
    for n_feat in range(1, len(final_ind_used) + 1):
        all_combs.extend(combinations(final_ind_used, n_feat))

    ols_col_names = ['const', 'Pat']+list(indices_used_final_names.values())
    ols_col_names += ['corr', 'corr_adj', 'n_features', 'pattern_used']
    ols_results = pd.DataFrame(columns=ols_col_names, index=range(len(all_combs)))
    for row, comb in enumerate(all_combs):
        comb_cols = list(comb)
        fit = OLS(target['Pat'], add_constant(predictors[comb_cols])).fit()
        ols_results.loc[row, ['const']+comb_cols] = fit.params
        ols_results.loc[row, 'corr'] = fit.rsquared**.5
        ols_results.loc[row, 'corr_adj'] = fit.rsquared_adj**.5
        ols_results.loc[row, 'n_features'] = len(comb)
        ols_results.loc[row, 'pattern_used'] = 'Pat' in comb_cols

    ols_results = xr.DataArray(ols_results.astype('float'))
    ols_results = ols_results.rename({'dim_0': 'combination', 'dim_1': 'ols_param'})
    ols_results = ols_results.assign_coords({'cluster': cl_used})
    ols_results = ols_results.assign_coords({'lag': -shift_used, 'Type': type_used, 'Indicator': ind_used})

    # the scalar coordinates become length-1 dimensions, as needed for xr.combine_by_coords
    return ols_results.expand_dims(['lag', 'Type', 'Indicator', 'cluster'])

In [20]:
# OLS fits for every Type, Indicator, shift and cluster; range() excludes ols_max, so the shifts
# run from ols_min to ols_max-1
combs = list(product(ts_patterns_final.Type.values, ts_patterns_final.Indicator.values, 
                     range(ols_min, ols_max), ts_patterns_final.cluster.values))
# the context manager shuts the worker pool down also when one of the tasks raises, so no worker
# processes are left behind; all results are collected inside the block, before the pool is torn down
with multiprocessing.Pool() as pool: # object for multiprocessing
    multivar_corr_feat_sel_all = list(tqdm.tqdm(pool.imap(ols_incl_pattern, combs), total=len(combs), position=0))

multivar_corr_feat_sel_all = xr.combine_by_coords(multivar_corr_feat_sel_all)
file_name = f'{output_loc}MultiVarCorrelationsIndicesFeatSel_InclPat{temp_aggr}{temp_resolution}.nc'
multivar_corr_feat_sel_all.to_netcdf(file_name)

del(combs, pool, file_name)

In [21]:
def max_corr(input_data):
    """Pick, for every number of predictors, the OLS fit with the highest adjusted correlation.

    Parameters
    ----------
    input_data : tuple
        (multi_corr_used, type_used, ind_used, lag_used, cl_used): the xarray object with the OLS results
        (indexed here by lag, Type, Indicator, cluster, combination and ols_param) and the labels of the
        slice to process.

    Returns
    -------
    The selected fits along 'combination', which is relabelled with the number of predictors of each
    fit, plus one auxiliary entry labelled -1; 'lag', 'Type', 'Indicator' and 'cluster' are added back
    as length-1 dimensions.
    """
    
    multi_corr_used, type_used, ind_used, lag_used, cl_used = input_data
    
    data = multi_corr_used.sel(lag=lag_used, Type=type_used, Indicator=ind_used, cluster=cl_used)

    # relabel every fit with its number of predictors, so that fits can be grouped by model size
    data = data.assign_coords({'combination': data.sel(ols_param='n_features').values})
    
    # nans cause problem, so change them to 0
    combs_modified = [i if ~np.isnan(i) else 0 for i in data.combination.values]
    data = data.assign_coords({'combination': combs_modified})
    
#     reslt = data.sel(ols_param='corr_adj').groupby('combination').max()
    
    data_corr_adj = data.sel(ols_param='corr_adj').fillna(-999) # again convert nan, otherwise function breaks

    # position, inside each group of equal model size, of the fit with the largest corr_adj
    reslt = data_corr_adj.groupby('combination').apply(lambda c: c.argmax(dim='combination'))
    # take that fit (all its ols_param values) from the same groups of `data`; this pairs the groups of
    # `data` with the entries of `reslt` by position, i.e. it relies on both being in the same group order
    reslt = [(i[1].isel(combination=j.values)) for i, j in zip(data.groupby('combination'), reslt)]
    reslt = xr.concat(reslt, dim='combination')
    
    # auxiliary first entry labelled -1 (first selected fit multiplied by 0), so that something is left
    # after the removal of the label 0 below
    reslt_aux = (reslt.isel(combination=0)*0).assign_coords({'combination': -1}) # in case only comb=0 exists
    reslt = xr.concat([reslt_aux, reslt], dim='combination')

    reslt = reslt.isel(combination=reslt.combination!=0) # NaN's combs are changed to 0, so should be removed
    # NOTE(review): reslt is picked from `data`, not from the -999-filled copy, so this looks like a
    # safeguard only - confirm
    reslt = reslt.where(reslt != -999) # convert the -999 back to nan
    reslt = reslt.assign_coords({'lag': lag_used, 'Type': type_used, 'Indicator': ind_used, 'cluster': cl_used})

    reslt = reslt.expand_dims(['lag', 'Type', 'Indicator', 'cluster']) # for xr.combine_by_coords

    return reslt

In [22]:
# Best fit per number of predictors, for three subsets of the OLS fits:
# 'All' (every fit), 'NoSame' (fits without the pattern as predictor), 'YesSame' (fits that include it)
max_corr_feat_sel = []

test_full = multivar_corr_feat_sel_all
test_no_param = multivar_corr_feat_sel_all.where(multivar_corr_feat_sel_all.sel(ols_param='pattern_used')==0)
test_param = multivar_corr_feat_sel_all.where(multivar_corr_feat_sel_all.sel(ols_param='pattern_used')==1)

for test, i_name in zip([test_full, test_no_param, test_param], ['All', 'NoSame', 'YesSame']):
    
    combs = list(product([test], multivar_corr_feat_sel_all.Type.values,
                     multivar_corr_feat_sel_all.Indicator.values, 
                     multivar_corr_feat_sel_all.lag.values, multivar_corr_feat_sel_all.cluster.values))
    # the context manager shuts the worker pool down also when one of the tasks raises, so no worker
    # processes are left behind; all results are collected inside the block, before the pool is torn down
    with multiprocessing.Pool() as pool: # object for multiprocessing
        max_corr_feat_sel_i = list(tqdm.tqdm(pool.imap(max_corr, combs), total=len(combs), position=0))

    max_corr_feat_sel_i = xr.combine_by_coords(max_corr_feat_sel_i)
    max_corr_feat_sel_i = max_corr_feat_sel_i.assign_coords({'Subset': i_name})
    max_corr_feat_sel.append(max_corr_feat_sel_i)

max_corr_feat_sel = xr.concat(max_corr_feat_sel, dim='Subset')

# drop the auxiliary combination=-1
max_corr_feat_sel = max_corr_feat_sel.isel(combination=max_corr_feat_sel.combination>0)

del(test_full, test_no_param, test_param, test, i_name, combs, pool, max_corr_feat_sel_i)

In [23]:
# Adjusted correlation of the best fit against lag, one line per number of predictors, one panel per pattern
sns.set_context('paper', font_scale=2)

data_plot = (
    max_corr_feat_sel
    .sel(Type=main_type_used, Indicator=main_indicator_used,
         cluster=pat_names_short, ols_param='corr_adj', Subset='All')
    .to_dataframe('Corr')
    .reset_index()
    .assign(combination=lambda df: df.combination.astype(int))
    .query('combination<=11 and lag<0') # only negative lags and at most 11 predictors
    .rename(columns={'combination': 'Predictors used', 'cluster': 'Pattern'})
)

# one colour per distinct number of predictors
data_plot = sns.relplot(
    data=data_plot, x='lag', y='Corr', kind='line', hue='Predictors used',
    col='Pattern', col_wrap=3,
    palette=sns.color_palette(n_colors=len(data_plot['Predictors used'].unique())),
)

data_plot.fig.savefig(f'{output_loc}MultiCorrelationsIndices_{temp_aggr}{temp_resolution}.png', dpi=600)
del(data_plot)

In [24]:
# Relative change (%) of the adjusted correlation when one more predictor is used, per lag and pattern
sns.set_context('paper', font_scale=2)

data_plot = max_corr_feat_sel.sel(Type=main_type_used, Indicator=main_indicator_used, 
                                  cluster=pat_names_short, ols_param='corr_adj', Subset='All')

# difference between successive entries along 'combination', labelled with the later label minus 1,
# so that the division below pairs each difference with the entry carrying that label
data_plot_diff = data_plot.diff('combination')
data_plot_diff = data_plot_diff.assign_coords({'combination': data_plot.combination.values[1:]-1})
data_plot = data_plot_diff/data_plot*100
# move the labels up by one again
data_plot = data_plot.assign_coords({'combination': data_plot.combination.values+1})

data_plot = (
    data_plot
    .to_dataframe('Improvement (%)')
    .reset_index()
    .query('combination<=11 and lag<0') # only negative lags and at most 11 predictors
    .rename(columns={'combination': 'Predictors used', 'cluster': 'Pattern'})
)

# one colour per distinct lag
data_plot = sns.relplot(
    data=data_plot, x='Predictors used', y='Improvement (%)', kind='line', hue='lag',
    col='Pattern', col_wrap=3,
    palette=sns.color_palette(n_colors=len(data_plot.lag.unique())),
)

data_plot.fig.savefig(f'{output_loc}MultiCorrelationsIndicesPlotElbow_{temp_aggr}{temp_resolution}.png', dpi=600)
del(data_plot, data_plot_diff)

In [25]:
# Share of the best fits (model sizes in used_combs) in which every predictor appears, per lag and pattern
sns.set_context('notebook', font_scale=1)

used_combs = np.arange(1,7)
used_par = max_corr_feat_sel.sel(
    Type=main_type_used, Indicator=main_indicator_used, lag=[-5, -4, -3, -2, -1],
    cluster=list(pat_names_short), combination=used_combs, Subset='All',
)
used_par = used_par.dropna(dim='ols_param', how='all')

# number of selected fits with a (non-NaN) value for each ols_param; zero counts become NaN
n_present = (~np.isnan(used_par)).sum('combination')
used_par = n_present.where(n_present!=0)/len(used_combs)

used_par = (
    used_par
    .to_dataframe('S')
    .reset_index()
    .query("ols_param not in ['const', 'corr', 'corr_adj', 'n_features', 'pattern_used']")
    .rename(columns={'ols_param': 'Predictor', 'cluster': 'Pattern'})
)

def draw_heatmap(*args, **kwargs):
    """Draw one annotated heatmap from the long-format facet data passed by FacetGrid.map_dataframe.

    The first three positional arguments are the names of the columns used as heatmap columns, heatmap
    rows and cell values; the data of the facet are taken from the keyword argument 'data'.
    """
    x_col, y_col, val_col = args[:3]
    table = kwargs.pop('data').pivot(index=y_col, columns=x_col, values=val_col)
    ax = sns.heatmap(table, cmap='YlOrRd', vmax=1, vmin=0, cbar=False, annot=True, fmt='.1f')
    # one tick, centred on the cell, for every row
    ax.set_yticks(np.arange(len(table.index))+.5)
    ax.set_yticklabels(table.index)

fg = sns.FacetGrid(used_par, col='Pattern', col_wrap=3)
fg.fig.set_size_inches(12, 18)
fg.map_dataframe(draw_heatmap, 'lag', 'Predictor', 'S')
fg.fig.savefig(f'{output_loc}MultiCorrelationsIndicesImportance_{temp_aggr}{temp_resolution}.png', dpi=600)
del(fg, draw_heatmap, used_par, used_combs, n_present)