# Imports & Constants

In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import ast
import matplotlib.pyplot as plt
import numpy as np
import pickle
import pandas as pd
import random
import seaborn as sns
import statsmodels.api as sm

In [4]:
from random import choices
from scipy import stats
from tqdm.notebook import tqdm
from matplotlib.colors import ListedColormap

In [5]:
from jupyter_utils import style, mean_std, display_test, display_group_test, \
    scatter_annotate, show_corrtest_mask_corr, pointplot, pointplot_horizontal, add_grey, \
    prep_horizontal_pointplot_errobar_data, map_model, prep_LM_pointplot, draw_sample_with_replacement, t_test
from ortogonolize_utils import compute_coefficient

In [6]:
import warnings
warnings.filterwarnings(action='ignore', category=np.VisibleDeprecationWarning)
warnings.filterwarnings(action='ignore', message='All-NaN slice encountered')
warnings.filterwarnings(action='ignore', message='Precision loss occurred in moment calculation due to catastrophic cancellation. This occurs when the data are nearly identical. Results may be unreliable.')
warnings.filterwarnings(action='ignore', message='Mean of empty slice')
warnings.filterwarnings(action='ignore', category=stats.ConstantInputWarning)
warnings.filterwarnings(action='ignore', message='indexing past lexsort depth may impact performance')

In [7]:
# Global seaborn theme applied to every figure in this notebook.
sns.set_theme(style="whitegrid")

In [8]:
# Directory holding the processed metric tables (de_averaged.csv, de_all.csv).
# NOTE(review): absolute, machine-specific path — consider making it configurable.
PATH = '/Users/galina.ryazanskaya/Downloads/thesis?/code?/processed_values'

In [9]:
# Figure output directory. File names are appended directly (f'{path}{m_type}/...'),
# so it must end with '/'.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
PATH_FIG = '/Users/galina.ryazanskaya/Downloads/thesis?/figures/de/'

# Load data

In [10]:
# Participant-level metrics averaged over tasks; read via PATH so the data
# directory is configured in one place.
combined_data_averaged = pd.read_csv(f'{PATH}/de_averaged.csv', index_col=0)

In [11]:
# Multi-level column names were written to CSV as tuple strings; parse those back
# into real tuples (e.g. (family, metric)) and keep plain string names unchanged.
combined_data_averaged.columns = [
    ast.literal_eval(name) if '(' in name else name
    for name in combined_data_averaged.columns
]

In [12]:
# Per-task metrics with a three-level column header; read via PATH like the averaged table.
combined_data_all = pd.read_csv(f'{PATH}/de_all.csv', index_col=0, header=[0, 1, 2])

In [13]:
# Task names; each one is a top-level column key of combined_data_all (see task_data).
TASKS = ['anger', 'fear', 'happiness', 'sadness']

In [14]:
def task_data(df, task, keep_target=True, fill_synt=True):
    """Extract the complete-case metric table for one task.

    Parameters
    ----------
    df : pd.DataFrame
        Table whose top column level holds task names and a 'target' block
        (e.g. combined_data_all).
    task : str
        Task name (one of TASKS).
    keep_target : bool
        If True, append the ``df['target']`` columns for the retained rows.
    fill_synt : bool
        If True, missing values in the 'syntactic' column group are replaced
        by 0.0 *before* incomplete rows are dropped. A missing syntactic value
        alone therefore no longer excludes a participant. Previously the fill
        ran after ``dropna(how='any')`` and had no effect.

    Returns
    -------
    pd.DataFrame
        Rows of ``df`` without missing values for this task.
    """
    subset = df[task].copy()
    if fill_synt and 'syntactic' in subset.columns:
        # Must happen before dropna: afterwards there are no NaNs left to fill.
        subset['syntactic'] = subset['syntactic'].fillna(0.0)
    subset = subset.dropna(axis=0, how='any')
    if task == 'fear' and 'KG_018' in subset.index:
        subset = subset.drop(['KG_018'])   # incorrect task
    if keep_target:
        subset = pd.concat([subset, df['target'].loc[subset.index]], axis=1)
    return subset

In [15]:
def aplly_to_all_tasks(df, f, tasks=TASKS, to_df=True, *args, **kwargs):
    """Apply ``f`` to the per-task table (see task_data) of every task in ``tasks``.

    Extra positional and keyword arguments are forwarded to ``f``. Because
    ``*args`` follows ``tasks``/``to_df`` in the signature, pass extra
    arguments by keyword.

    With ``to_df`` True, results are combined as follows:
    * all-Series results become one DataFrame with a column per task;
    * all-DataFrame results are concatenated side by side under a 'task'
      column level;
    * any other mix (e.g. dicts) is returned as the plain ``{task: result}``
      dict. This dict is also what ``to_df=False`` returns.
    """
    per_task = {name: f(task_data(df, name), *args, **kwargs) for name in tasks}
    if not to_df:
        return per_task
    outputs = per_task.values()
    if all(isinstance(out, pd.Series) for out in outputs):
        return pd.DataFrame(per_task)
    if all(isinstance(out, pd.DataFrame) for out in outputs):
        return pd.concat(list(outputs), keys=list(per_task.keys()), names=['task'], axis=1)
    return per_task

# Bootstrap

In [16]:
# Rating-scale columns (SAPS/SANS totals and PANSS subscales/total) that the
# metrics are correlated with.
scale_cols = ['saps_total', 'sans_total', 'panss_pos', 'panss_neg', 'panss_o', 'panss_total']

In [17]:
# Split the averaged table's tuple columns (family, metric) by metric family;
# cols_av is the concatenation in LM, syntactic, lexical, graph order.
cols_LM = [col for col in combined_data_averaged.columns if col[0] == 'LM']
cols_synt = [col for col in combined_data_averaged.columns if col[0] == 'syntactic']
cols_lex = [col for col in combined_data_averaged.columns if col[0] == 'lexical']
cols_graph = [col for col in combined_data_averaged.columns if col[0] == 'graph']
cols_av = cols_LM + cols_synt + cols_lex + cols_graph

In [18]:
## 1 sample with replacement
## 1.1 compute r for each scale for each iteration
## 1.2 compute t test for each iteration

def bootstrap(df, cols_av, scale_cols, N, col_to_correct_for=('syntactic', 'mean_sent_len'), group=None):
    """Bootstrap per-metric statistics for every (scale, metric) pair.

    For each of ``N`` resamples (drawn with ``seed=i`` for iteration ``i``) and
    each metric column in ``cols_av``, the following are recorded:

    * ``'sample_raw'`` -- first element of ``compute_coefficient(sample, scale, col)``;
    * ``'r'`` -- Pearson r between metric and scale on pairwise-complete rows;
    * ``'t'`` -- ``t_test(sample, col, group)`` (only when ``group`` is given);
    * ``'r_corr_w_control'`` -- Pearson r between the metric and
      ``col_to_correct_for`` (skipped for that column itself, whose list stays empty).

    ``'t'`` and ``'r_corr_w_control'`` do not depend on the scale. They are
    computed once per resample and metric, and the same value is appended
    under every scale. This keeps the per-scale layout while avoiding
    recomputing them for each scale.

    Returns
    -------
    dict
        ``{statistic: {scale: {metric: [value per resample]}}}``
    """
    dict_scales_sapmles = {k: {scale: {metric: [] for metric in cols_av} for scale in scale_cols}
                           for k in ('sample_raw', 'r', 't', 'r_corr_w_control')}
    for i in tqdm(range(N)):
        sample = draw_sample_with_replacement(df, seed=i)
        for col in cols_av:
            # Scale-independent statistics, hoisted out of the scale loop
            # (assumes t_test / pearsonr are deterministic for a fixed sample).
            t_test_res = t_test(sample, col, group) if group else None
            r_c = None
            if col != col_to_correct_for:
                droped_c = sample.dropna(subset=[col, col_to_correct_for])
                r_c = stats.pearsonr(droped_c[col], droped_c[col_to_correct_for])[0]

            for scale in scale_cols:
                if group:
                    dict_scales_sapmles['t'][scale][col].append(t_test_res)

                r_raw = compute_coefficient(sample, scale, col)[0]
                dict_scales_sapmles['sample_raw'][scale][col].append(r_raw)

                droped = sample.dropna(subset=[col, scale])
                r = stats.pearsonr(droped[col], droped[scale])[0]
                dict_scales_sapmles['r'][scale][col].append(r)

                if col != col_to_correct_for:
                    dict_scales_sapmles['r_corr_w_control'][scale][col].append(r_c)
    return dict_scales_sapmles

**Expensive to compute** — the cells below are commented out; the cached results are loaded from a pickle instead.

In [19]:
# N = 1000
# dict_scales_sapmles = bootstrap(combined_data_averaged, cols_av, scale_cols, N, col_to_correct_for=('syntactic', 'mean_sent_len'), group='group')

In [20]:
# reform = {(scale, measure): dict_scales_sapmles[measure][scale] for scale in scale_cols for measure in dict_scales_sapmles}

In [21]:
# with open('processed_values/de_scales_samples_wo_o.pickle', 'wb') as f:
#     pickle.dump(reform, f)

In [22]:
# Load the cached bootstrap results (the computing cells above are commented out).
# NOTE(review): relative path (unlike PATH) — resolves against the kernel's working directory.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open('processed_values/de_scales_samples_wo_o.pickle', 'rb') as f:
    reform = pickle.load(f)

In [23]:
# Columns: (scale, statistic) keys of `reform`; rows: metric columns. Each cell holds
# the list of bootstrap values for that metric/scale/statistic.
reformed_d = pd.DataFrame(reform)

### Bootstrap for each task

In [24]:
def bootstrap_tasks(df, cols_av, scale_cols, N, col_to_correct_for=('syntactic', 'mean_sent_len'), group=None):
    """Per-task variant of ``bootstrap``; scale columns are read from ``('target', scale)``.

    The output layout is the same as in ``bootstrap``:
    ``{statistic: {scale: {metric: [value per resample]}}}``.

    The scale-independent statistics ('t' and 'r_corr_w_control') are computed
    once per resample and metric, then appended as plain scalars under every
    scale. A statistic that is not computed for a metric (no ``group``, or the
    control column itself) leaves that metric's list empty, as in ``bootstrap``.
    Earlier versions stored 't' as one-element lists and appended empty lists
    for skipped statistics. np.nanmedian / np.quantile give the same result on
    such cached pickles.
    """
    res_c = ('sample_raw', 'r', 't', 'r_corr_w_control')
    dict_scales_sapmles = {k: {scale: {metric: [] for metric in cols_av} for scale in scale_cols} for k in res_c}
    for i in tqdm(range(N)):
        sample = draw_sample_with_replacement(df, seed=i)
        for col in cols_av:
            # Only statistics actually computed for this metric are stored here.
            scale_independent = {}
            if group:
                scale_independent['t'] = t_test(sample, col, group)
            if col != col_to_correct_for:
                dropped_c = sample.dropna(subset=[col, col_to_correct_for])
                scale_independent['r_corr_w_control'] = stats.pearsonr(
                    dropped_c[col], dropped_c[col_to_correct_for])[0]
            for scale in scale_cols:
                scale_ = ('target', scale)
                r_raw = compute_coefficient(sample, scale_, col)[0]
                dict_scales_sapmles['sample_raw'][scale][col].append(r_raw)

                dropped = sample.dropna(subset=[col, scale_])
                r = stats.pearsonr(dropped[col], dropped[scale_])[0]
                dict_scales_sapmles['r'][scale][col].append(r)

                for k, value in scale_independent.items():
                    dict_scales_sapmles[k][scale][col].append(value)
    return dict_scales_sapmles

In [25]:
# Show the scale columns the bootstrap iterates over.
scale_cols

['saps_total',
 'sans_total',
 'panss_pos',
 'panss_neg',
 'panss_o',
 'panss_total']

**Expensive to compute** — the cells below are commented out; the cached per-task results are loaded from a pickle instead.

In [26]:
# N = 1000
# dict_scales_sapmles_tasks = aplly_to_all_tasks(combined_data_all, bootstrap_tasks, cols_av=cols_av, scale_cols=scale_cols, N=N, col_to_correct_for=('syntactic', 'mean_sent_len'), group=('target', 'group'))

In [27]:
# reform_tasks = {(task, scale, measure): dict_scales_sapmles_tasks[task][measure][scale] for scale in scale_cols 
#                 for task in dict_scales_sapmles_tasks
#                 for measure in dict_scales_sapmles_tasks[task]}

In [28]:
# with open('processed_values/de_scales_samples_wo_o_tasks.pickle', 'wb') as f:
#     pickle.dump(reform_tasks, f)

In [29]:
# Load the cached per-task bootstrap results (the computing cells above are commented out).
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open('processed_values/de_scales_samples_wo_o_tasks.pickle', 'rb') as f:
    reform_tasks = pickle.load(f)

In [30]:
# Columns: (TASK, scale, measure); rows: metric columns; cells: lists of bootstrap values.
reformed_tasks = pd.DataFrame(reform_tasks)
reformed_tasks.columns.names = ['TASK', 'scale', 'measure']

# Plot & Analyze

In [31]:
# Figure layout settings per metric family, read by get_fparams:
#   subplot_size -- (width, height) in inches of a single subplot cell,
#   wspace/hspace -- subplot spacing passed to plt.subplots_adjust,
#   yt            -- y position used for the figure suptitle.
figprms = {
    'syntactic': {'subplot_size': (9, 4.5), 'wspace': 0.25, 'hspace': 0.125, 'yt': 0.925},
    'LM': {'subplot_size': (9, 5.5), 'wspace': 0.275, 'hspace': 0.125, 'yt': 0.92},
    'lexical': {'subplot_size': (9, 2), 'wspace': 0.2, 'hspace': 0.25, 'yt': 0.925},
    'graph': {'subplot_size': (9, 3.5), 'wspace': 0.3, 'hspace': 0.125, 'yt': 0.925},
}

In [32]:
def get_fparams(m_type, n_sublots_height, n_sublots_width, figparams=None):
    """Compute layout parameters for a grid of subplots.

    Parameters
    ----------
    m_type : str
        Metric-family key into ``figparams`` (e.g. 'syntactic', 'LM', 'lexical', 'graph').
    n_sublots_height, n_sublots_width : int
        Number of subplot rows / columns.
    figparams : dict, optional
        Layout table shaped like ``figprms``. The argument was previously
        ignored and the global ``figprms`` was always read. It is now used,
        falling back to ``figprms`` only when omitted.

    Returns
    -------
    tuple
        ``(figsize, wspace, hspace, yt)``, where ``figsize`` is the per-subplot
        size scaled by the grid dimensions.
    """
    if figparams is None:
        figparams = figprms
    params = figparams[m_type]
    width, height = params['subplot_size']
    figsize = (width * n_sublots_width, height * n_sublots_height)
    return figsize, params['wspace'], params['hspace'], params['yt']

In [33]:
# Panel order of the scales in the 3x2 cross-scale figure grids.
ORDERED_SCALES = ['panss_pos', 'panss_neg', 'panss_o', 'panss_total', 'saps_total', 'sans_total']

### Plot horizontal point plots

In [34]:
def plot_horizontal_tasks(df, title, measure, m_type='syntactic', plot_abs=False, figparams=figprms):
    """Draw a 3x2 grid of horizontal point plots, one panel per scale in ORDERED_SCALES.

    Each panel shows the bootstrap distribution of ``measure`` for the metrics
    of family ``m_type`` (rows of ``df[scale].loc[m_type]``). With ``plot_abs``,
    absolute values are plotted and the axes are labelled accordingly.

    NOTE(review): a second function with this same name (and an incompatible
    signature that requires ``scale``) is defined later in the notebook and
    shadows this one. plot_all only works while this definition is active.

    Returns the created figure.
    """
    figsize, wspace, hspace, yt = get_fparams(m_type, 3, 2, figparams)
    fig, axes = plt.subplots(3, 2, figsize=figsize, sharex=True)
    fig.suptitle(title, y=yt + 0.25)
    plt.subplots_adjust(wspace=wspace)

    prefix = 'abs ' if plot_abs else ''
    panels = axes.flatten()
    for idx, scale in enumerate(ORDERED_SCALES):
        panel = panels[idx]
        panel_data = prep_horizontal_pointplot_errobar_data(df[scale].loc[m_type], measure, plot_abs=plot_abs)
        pointplot_horizontal(panel_data, x=measure, ax=panel)
        panel.set_title(f'{prefix}{measure} {scale}')

    add_grey(axes)

    if plot_abs:
        for panel in axes.reshape(-1):
            panel.set_xlabel(prefix + measure)
    return fig

In [35]:
def plot_all(df, m_type='syntactic', measure='r', path=PATH_FIG, dpi=150, plot_abs=False, figparams=figprms):
    """Save the summary figures for one metric family ``m_type``.

    Files written under ``{path}{m_type}/``:

    * ``{abs_}scale_r.png``   -- cross-scale grid of ``measure`` (plot_horizontal_tasks);
    * ``t.png``               -- bootstrap t-statistics for the group difference;
    * ``corr_len.png``        -- correlation with mean sentence length;
    * ``t_test_corr_len.png`` -- the previous two side by side.

    bootstrap stores 't' and 'r_corr_w_control' identically under every scale,
    so they are read from the 'panss_o' columns. ``plot_abs`` only affects the
    cross-scale grid. ``figparams`` is forwarded everywhere; previously the
    grid always received the global ``figprms``.
    """
    ab = 'abs_' if plot_abs else ''
    absolute_value = f' (absolute {measure} value)' if plot_abs else ''
    fig = plot_horizontal_tasks(df, title=f'cross-scale comparison for {m_type} metrics{absolute_value}',
                                measure=measure, m_type=m_type, plot_abs=plot_abs, figparams=figparams)
    # The grid's suptitle sits at y = yt + 0.25 (above the canvas for the default
    # figprms); without bbox_inches='tight' it was cropped from the saved image.
    fig.savefig(f'{path}{m_type}/{ab}scale_r.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

    figsize, wspace, hspace, yt = get_fparams(m_type, 1, 1, figparams)
    d_t = prep_horizontal_pointplot_errobar_data(df['panss_o'].loc[m_type], 't')
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    fig.suptitle(f'group difference (t-test) for {m_type} metrics')
    pointplot_horizontal(d_t, 't', ax=ax)
    add_grey(ax, r=2)
    fig.savefig(f'{path}{m_type}/t.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

    d_c = prep_horizontal_pointplot_errobar_data(df['panss_o'].loc[m_type], 'r_corr_w_control')
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    fig.suptitle(f'correlation with mean sentence length for {m_type} metrics')
    pointplot_horizontal(d_c, 'r_corr_w_control', ax=ax)
    ax.set_xlabel('r')
    add_grey(ax)
    fig.savefig(f'{path}{m_type}/corr_len.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

    figsize, wspace, hspace, yt = get_fparams(m_type, 1, 2, figparams)
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    fig.suptitle(f'group difference and correlation with mean sentence length for {m_type} metrics', y=yt + 0.12)
    plt.subplots_adjust(wspace=wspace)
    pointplot_horizontal(d_t, x='t', ax=axes[0])
    axes[0].set_title('group difference (t-test)')
    pointplot_horizontal(d_c, x='r_corr_w_control', ax=axes[1])
    axes[1].set_title('correlation with mean sentence length')
    axes[1].set_xlabel('r')
    add_grey(axes[0], r=2)
    add_grey(axes[1])
    fig.savefig(f'{path}{m_type}/t_test_corr_len.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

In [36]:
# Save the figure set for every metric family (first row-index level), absolute and signed.
for m_type in reformed_d.index.unique(level=0):
    plot_all(reformed_d, m_type, plot_abs=True)
    plot_all(reformed_d, m_type, plot_abs=False)

### Plot vertical point plots for LMs

In [37]:
# Order of language models on the x-axis of the LM point plots.
order = ['bert', 'glove_tf', 'glove_avg', 'w2v_tf', 'w2v_avg']

In [38]:
def plot_LM_scales(df, title, measure='r', plot_abs=False, figsize=(18, 18), order=order):
    """Draw a 3x2 grid of vertical LM point plots, one panel per scale in ORDERED_SCALES.

    Each panel plots ``measure`` from ``df.loc['LM', scale]`` by model on the
    x-axis (in ``order``), coloured by metric. With ``plot_abs``, absolute
    values are plotted and titles and y-labels say so.

    Returns the created figure.
    """
    label_prefix = 'abs ' if plot_abs else ''
    title_suffix = f' (absolute {measure} value)' if plot_abs else ''

    fig, axes = plt.subplots(3, 2, figsize=figsize, sharey=True)
    fig.suptitle(title + title_suffix, y=0.91)
    plt.subplots_adjust(wspace=0.075)

    panels = axes.flatten()
    for idx, scale in enumerate(ORDERED_SCALES):
        panel = panels[idx]
        panel_data = prep_LM_pointplot(df.loc['LM', scale], measure, plot_abs=plot_abs)
        pointplot(panel_data, x='model', y=measure, hue='metric', ax=panel, order=order, use_errorbar=True)
        panel.set_title(f'{label_prefix}{measure} {scale}')

    add_grey(axes, line_dir='h')
    if plot_abs:
        for panel in axes.reshape(-1):
            panel.set_ylabel(label_prefix + measure)
    return fig

In [39]:
def plot_all_LM(df, path=PATH_FIG, dpi=150, plot_abs=False, figsize=(9, 9), measure='r'):
    """Save the summary figures for LM metrics, grouped by model.

    Files written under ``{path}LM/model/``:

    * ``{abs_}scale_r.png``   -- cross-scale grid (plot_LM_scales at its default size);
    * ``t.png``               -- bootstrap group-difference t-statistics;
    * ``{abs_}corr_len.png``  -- correlation with mean sentence length;
    * ``t_test_corr_len.png`` -- t-statistics and length correlation side by side.
      NOTE: the file name is the same for both ``plot_abs`` settings, so the
      later call wins.

    ``figsize`` applies to the single-panel figures. Models are on the x-axis
    in these plots, so the statistic is labelled on the y-axis; previously the
    'r' label overwrote the x-axis label instead.
    """
    ab = 'abs_' if plot_abs else ''
    r_label = 'abs r' if plot_abs else 'r'
    absolute_value = f' (absolute {measure} value)' if plot_abs else ''

    fig = plot_LM_scales(df, 'cross-scale comparison of LM metrics across models',
                         measure=measure, plot_abs=plot_abs)
    fig.savefig(f'{path}LM/model/{ab}scale_r.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

    d_lm_t = prep_LM_pointplot(df.loc['LM', 'panss_o'], 't')
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    fig.suptitle('group difference (t-test) for LM metrics across models')
    pointplot(d_lm_t, x='model', y='t', hue='metric', ax=ax, order=order, use_errorbar=True)
    add_grey(ax, r=2, line_dir='h')
    fig.savefig(f'{path}LM/model/t.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

    d_lm_c = prep_LM_pointplot(df.loc['LM', 'panss_o'], 'r_corr_w_control', plot_abs=plot_abs)
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    fig.suptitle('correlation with mean sentence length for LM metrics across models' + absolute_value)
    pointplot(d_lm_c, x='model', y='r_corr_w_control', hue='metric', ax=ax, order=order, use_errorbar=True)
    ax.set_ylabel(r_label)
    add_grey(ax, line_dir='h')
    fig.savefig(f'{path}LM/model/{ab}corr_len.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

    fig, axes = plt.subplots(1, 2, figsize=(18, 6))
    fig.suptitle('group difference and correlation with mean sentence length of LM metrics across models')

    pointplot(d_lm_t, x='model', y='t', hue='metric', ax=axes[0], order=order, use_errorbar=True)
    axes[0].set_title('group difference (t-test)')

    pointplot(d_lm_c, x='model', y='r_corr_w_control', hue='metric', ax=axes[1], order=order, use_errorbar=True)
    axes[1].set_title('correlation with mean sentence length')
    axes[1].set_ylabel(r_label)
    add_grey(axes[0], r=2, line_dir='h')
    add_grey(axes[1], line_dir='h')
    fig.savefig(f'{path}LM/model/t_test_corr_len.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

In [40]:
# Save the LM-by-model figures with absolute and signed values.
plot_all_LM(reformed_d, plot_abs=True)
plot_all_LM(reformed_d, plot_abs=False)

## Cross-metric comparison

### Correlation with scales (r)

In [41]:
def select_ok_metrics(row, t=1, r=0.3, rc=0.3):
    """Return True if a metric is informative and not confounded by the length control.

    ``row`` holds one metric's median values keyed by (scale, statistic). The
    metric passes when two conditions hold:
    * |t| is at least ``t``, **or** |r| with at least one scale in
      ``scale_cols`` is at least ``r``;
    * |r| with the control column is at most ``rc``. A missing control
      correlation counts as acceptable.

    't' and 'r_corr_w_control' are the same under every scale, so they are
    read from the 'saps_total' entries.
    """
    group_ok = abs(row['saps_total', 't']) >= t
    control_r = row['saps_total', 'r_corr_w_control']
    length_ok = pd.isna(control_r) or abs(control_r) <= rc
    corr_ok = any(abs(row[scale, 'r']) >= r for scale in scale_cols)
    return (corr_ok or group_ok) and length_ok

In [42]:
def select_ok_metrics_for_one_scale(row, scale, r=0.3, rc=0.3):
    """Return True if |r| with ``scale`` is at least ``r`` and |r| with the length control is at most ``rc``.

    A missing control correlation counts as acceptable. The control correlation
    is the same under every scale and is read from the 'panss_total' entry.
    """
    control_r = row['panss_total', 'r_corr_w_control']
    length_ok = pd.isna(control_r) or abs(control_r) <= rc
    scale_ok = abs(row[scale, 'r']) >= r
    return scale_ok and length_ok

In [43]:
def select_control_corr_ms_better_than_len(row, row_len, r=0.3, rc=0.3):
    """Return True for a length-confounded metric that still matches or beats length itself.

    The metric must meet three conditions:
    * it fails ``select_ok_metrics``;
    * its |r| with the length control exceeds ``rc``;
    * for at least one scale in ``scale_cols``, its |r| is at least that of
      ``row_len`` (the mean-sentence-length row of the medians table).
    """
    if select_ok_metrics(row, r=r, rc=rc):
        return False
    control_r = row['saps_total', 'r_corr_w_control']
    if pd.isna(control_r) or abs(control_r) <= rc:
        return False
    # Iterate over scale_cols (defined before this cell). The previously used
    # `scales_` is only defined much later, so a fresh top-to-bottom run raised
    # a NameError here.
    return any(abs(row[scale, 'r']) >= abs(row_len[scale, 'r']) for scale in scale_cols)

In [44]:
def select_control_corr_ms_better_than_len_for_one_scale(row, row_len, scale, r=0.3, rc=0.3):
    """For one scale, return True for a length-confounded metric that still correlates well.

    Requires all of the following:
    * the metric fails ``select_ok_metrics``;
    * its |r| with the length control exceeds ``rc``;
    * its |r| with ``scale`` is at least that of ``row_len``;
    * its |r| with ``scale`` is strictly greater than ``r``.
    """
    if select_ok_metrics(row, r=r, rc=rc):
        return False
    control_r = row['saps_total', 'r_corr_w_control']
    confounded = not (pd.isna(control_r) or abs(control_r) <= rc)
    if not confounded:
        return False
    metric_r = abs(row[scale, 'r'])
    return bool(metric_r >= abs(row_len[scale, 'r']) and metric_r > r)

In [45]:
def select_bad_len_metrics(row, rc=0.3):
    """Return True if |r| with the length control exceeds ``rc`` (a missing value counts as fine)."""
    control_r = row['panss_total', 'r_corr_w_control']
    if pd.isna(control_r):
        return False
    return not abs(control_r) <= rc

In [46]:
# Median of each bootstrap distribution: same layout as reformed_d, one scalar per cell.
# NOTE(review): DataFrame.applymap is deprecated since pandas 2.1 (renamed DataFrame.map).
median_d = reformed_d.applymap(np.nanmedian)

In [47]:
# Per scale, collect two sets of metrics:
# (a) metrics passing select_ok_metrics_for_one_scale;
# (b) length-confounded metrics whose |r| still matches or beats mean sentence length.
idxs_scale = {}
idxs_bad_bet_len = {}
for scale in scale_cols:
    ids = median_d[median_d.apply(lambda x: select_ok_metrics_for_one_scale(x, scale=scale), axis=1)].index
    idxs_scale[scale] = ids
    idsl = median_d[median_d.apply(lambda x: select_control_corr_ms_better_than_len_for_one_scale(x, 
                                                                                                  row_len=median_d.loc[('syntactic', 'mean_sent_len')], 
                                                                                                  scale=scale), axis=1)].index
    idxs_bad_bet_len[scale] = idsl

In [48]:
# Union of the per-scale passing metrics, sorted.
good_ms = sorted({metric for ids in idxs_scale.values() for metric in ids})

In [49]:
# Union of the per-scale length-confounded-but-competitive metrics, sorted.
bad_ms_better_than_len = sorted({metric for ids in idxs_bad_bet_len.values() for metric in ids})

In [50]:
# Metrics whose median |r| with mean sentence length exceeds the default 0.3 threshold.
bad_ms = median_d[median_d.apply(select_bad_len_metrics, axis=1)].index

In [51]:
# Correlation with mean sentence length for the length-confounded metrics (bad_ms).
# r_corr_w_control is identical under every scale, so any scale's column works.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
fig.suptitle('cross-type metric comparison for correlation with mean sent len')
measure = 'r_corr_w_control'
plot_abs = False

d = prep_horizontal_pointplot_errobar_data(reformed_d['panss_pos'].loc[bad_ms], measure, plot_abs=plot_abs)
pointplot_horizontal(d, x=measure, ax=ax)
ax.set_xlabel('r')
add_grey(ax)
# PATH_FIG already ends with '/', so no extra separator is inserted.
fig.savefig(f'{PATH_FIG}compare_corr_len.png', dpi=150, bbox_inches='tight')
plt.close(fig)

In [52]:
# Metrics shown in the cross-scale figure: passing metrics plus length-confounded
# metrics that match or beat mean sentence length.
ms_to_plot = sorted(bad_ms_better_than_len+ good_ms)

In [53]:
# Number of metrics in the cross-scale comparison figure.
len(ms_to_plot)

21

In [54]:
# Number of metrics passing the correlation and length criteria for at least one scale.
len(good_ms)

4

In [55]:
# Per scale: passing metrics plus length-confounded metrics that match or beat length.
{s: len(idxs_scale[s]) + len(idxs_bad_bet_len[s]) for s in idxs_scale}

{'saps_total': 5,
 'sans_total': 10,
 'panss_pos': 1,
 'panss_neg': 12,
 'panss_o': 14,
 'panss_total': 16}

In [56]:
# Per scale: passing metrics only.
{s: len(idxs_scale[s]) for s in idxs_scale}

{'saps_total': 1,
 'sans_total': 3,
 'panss_pos': 0,
 'panss_neg': 3,
 'panss_o': 2,
 'panss_total': 3}

In [57]:
def sort_index(idxs):
    """Return ``idxs`` sorted, with the mean-sentence-length entry moved to the end.

    Any of the spellings 'syntactic: mean_sent_len', 'mean_sent_len' and
    ('syntactic', 'mean_sent_len') is recognised; only its first occurrence
    is moved.
    """
    ordered = sorted(idxs)
    for last_key in ('syntactic: mean_sent_len', 'mean_sent_len', ('syntactic', 'mean_sent_len')):
        if last_key in ordered:
            ordered.append(ordered.pop(ordered.index(last_key)))
    return ordered

In [58]:
def map_marker(m, scale, idxs_scale):
    """Return marker 'o' if metric ``m`` passed the selection for ``scale`` (is in idxs_scale[scale]), else 'x'."""
    return 'o' if m in idxs_scale[scale] else 'x'

In [59]:
def add_len_lines(ax, median_d, scale):
    """Draw dashed vertical reference lines at +/- the median r of mean sentence length with ``scale``."""
    len_r = median_d.loc[('syntactic', 'mean_sent_len')][scale, 'r']
    for x in (len_r, -len_r):
        ax.axvline(x, linestyle='--')

In [60]:
def plot_one_scale(ax, scale, idxs_scale, reformed_d, ms_to_plot, measure, plot_abs, median_df=None):
    """Plot the bootstrap ``measure`` of ``ms_to_plot`` against one scale on ``ax``.

    Markers come from map_marker: 'o' for metrics that passed the selection
    for ``scale``, 'x' otherwise. If mean sentence length itself is among the
    passing metrics, dashed lines at +/- its median r are added as a reference.

    Parameters
    ----------
    median_df : DataFrame, optional
        Medians table used for those reference lines. Defaults to the notebook
        global ``median_d``, which was previously always used implicitly.
    """
    if median_df is None:
        median_df = median_d
    markers = [map_marker(m, scale, idxs_scale) for m in ms_to_plot]
    if ('syntactic', 'mean_sent_len') in idxs_scale[scale]:
        add_len_lines(ax, median_df, scale)
    d = prep_horizontal_pointplot_errobar_data(reformed_d[scale].loc[ms_to_plot], measure, plot_abs=plot_abs)
    pointplot_horizontal(d, x=measure, ax=ax, markers=markers)
    ax.set_title(f'{measure} {scale}')

In [61]:
# Cross-scale comparison of all selected metrics. sort_index moves mean sentence
# length to the end. `idxs` was previously computed but never used: ms_to_plot
# was passed instead.
fig, axes = plt.subplots(3, 2, figsize=(18, 14), sharex=True)
fig.suptitle('cross-scale cross-type metric comparison', y=0.915)
plt.subplots_adjust(wspace=0.4, hspace=0.1)
measure = 'r'
plot_abs = False
idxs = sort_index(ms_to_plot)
axs = axes.flatten()

for i, scale in enumerate(ORDERED_SCALES):
    ax = axs[i]
    plot_one_scale(ax, scale, idxs_scale, reformed_d, idxs, measure, plot_abs)

add_grey(axes)
# PATH_FIG already ends with '/', so no extra separator is inserted.
fig.savefig(f'{PATH_FIG}compare_r.png', dpi=150, bbox_inches='tight')
plt.close(fig)

### t-test

In [62]:
def select_ok_metrics_t(row):
    """Return True if the interquartile range of the bootstrap t-values excludes zero."""
    q25, q75 = np.quantile(row['saps_total', 't'], [0.25, 0.75])
    return q25 > 0 or q75 < 0

In [63]:
# Metrics whose bootstrap t interquartile range excludes zero.
idx_comp_t = reformed_d[reformed_d.apply(select_ok_metrics_t, axis=1)].index

In [64]:
# Group difference (t) and length correlation for the metrics in idx_comp_t.
# Both statistics are identical under every scale, so the 'sans_total' columns are used.
fig, axes = plt.subplots(1, 2, figsize=(18, 4))
plt.subplots_adjust(wspace=0.4)
fig.suptitle('cross-type metric comparison for group difference and correlation with mean sentence length', y=1)

# Mean sentence length has no correlation with itself (bootstrap skips it); drop it if selected.
dcr = reformed_d['sans_total'].loc[idx_comp_t].drop([('syntactic', 'mean_sent_len')], errors='ignore')
d_c = prep_horizontal_pointplot_errobar_data(dcr, 'r_corr_w_control', plot_abs=False)
d_t = prep_horizontal_pointplot_errobar_data(reformed_d['sans_total'].loc[idx_comp_t], 't', plot_abs=False)

pointplot_horizontal(d_t, x='t', ax=axes[0])
axes[0].set_title('group difference (t-test)')

pointplot_horizontal(d_c, x='r_corr_w_control', ax=axes[1])
axes[1].set_title('correlation with mean sentence length')
axes[1].set_xlabel('r')
add_grey(axes[0], r=2)
add_grey(axes[1])
# PATH_FIG already ends with '/', so no extra separator is inserted.
fig.savefig(f'{PATH_FIG}compare_t.png', dpi=150, bbox_inches='tight')
plt.close(fig)

### Average LM model / metric performance (bootstrap medians) across scales

In [65]:
# Constants for the LM model/metric summary tables below:
# scales_  -- scales whose |r| is averaged; sc_ind_ -- scale-independent statistics;
# models_  -- LM model names; metrics_ -- LM metric names.
scales_ = ('sans_total', 'saps_total', 'panss_pos', 'panss_neg', 'panss_o', 'panss_total')
sc_ind_ = ('t', 'r_corr_w_control')
models_ = ('bert', 'glove_tf', 'glove_avg', 'w2v_tf', 'w2v_avg')
metrics_ = ('cgcoh', 'gcoh', 'lcoh', 'scoh', 'sprob', 'pppl')

In [66]:
def mean_model_metric_medians(median_df, leave_out=()):
    """Average LM bootstrap medians per model and per metric.

    For every scale in ``scales_``, the LM rows of ``median_df`` are prepared
    with ``prep_LM_pointplot(..., plot_abs=True)``. Their 'r' values are
    NaN-averaged in two ways:
    * over the metrics of each model, excluding metric names in ``leave_out``;
    * over the models of each metric.
    The scale-independent statistics in ``sc_ind_`` are averaged the same way.

    Returns
    -------
    (resp_d_model, resp_d_metric) : tuple of pd.DataFrame
        Float frames indexed by ``models_`` / ``metrics_`` with columns
        ``'<scale> abs r'`` plus ``sc_ind_``. ``leave_out`` only affects the
        per-model table.
    """
    columns = [f'{sc} abs r' for sc in scales_] + list(sc_ind_)
    # Start from float NaN frames. An empty DataFrame(columns=...) has object
    # dtype, so the results were stored and aggregated as Python objects.
    resp_d_model = pd.DataFrame(np.nan, index=list(models_), columns=columns)
    resp_d_metric = pd.DataFrame(np.nan, index=list(metrics_), columns=columns)

    def kept_rows(ex_d, model):
        """Rows of ``ex_d`` for ``model``, excluding metrics listed in ``leave_out``."""
        rows = ex_d[ex_d['model'] == model]
        return rows[~rows['metric'].isin(leave_out)]

    ex_d = None
    for scale in scales_:
        ex_d = prep_LM_pointplot(median_df.loc['LM', scale], col='r', use_errorbar=False, plot_abs=True)
        for model in models_:
            resp_d_model.loc[model, f'{scale} abs r'] = np.nanmean(kept_rows(ex_d, model)['r'])
        for metric in metrics_:
            resp_d_metric.loc[metric, f'{scale} abs r'] = np.nanmean(ex_d[ex_d['metric'] == metric]['r'])

    if ex_d is None:
        return resp_d_model, resp_d_metric

    # 't' and 'r_corr_w_control' are the same for every scale. They are read
    # from the table of the last scale processed, as before, but explicitly
    # rather than via a loop variable leaking out of the loop.
    for sc_ind in sc_ind_:
        for model in models_:
            resp_d_model.loc[model, sc_ind] = np.nanmean(kept_rows(ex_d, model)[sc_ind])
        for metric in metrics_:
            resp_d_metric.loc[metric, sc_ind] = np.nanmean(ex_d[ex_d['metric'] == metric][sc_ind])
    return resp_d_model, resp_d_metric


#### only including cosine-similarity based metrics

In [67]:
# Per-model averages without the 'pppl' and 'sprob' metrics (per-metric table is unaffected).
resp_d_model, resp_d_metric = mean_model_metric_medians(median_d, leave_out=('pppl', 'sprob'))

In [68]:
# Render the per-model table (style comes from jupyter_utils).
style(resp_d_model)

Unnamed: 0,sans_total abs r,saps_total abs r,panss_pos abs r,panss_neg abs r,panss_o abs r,panss_total abs r,t,r_corr_w_control
bert,0.236577,0.044427,0.070759,0.225894,0.252356,0.226347,0.512648,0.16449
glove_tf,0.165831,0.208567,0.234527,0.192332,0.163751,0.198657,0.39451,0.423193
glove_avg,0.223965,0.196259,0.201146,0.234413,0.184835,0.232335,0.39749,0.559332
w2v_tf,0.280048,0.303887,0.259121,0.312481,0.255188,0.325277,0.733133,0.588689
w2v_avg,0.306273,0.246191,0.189937,0.314826,0.233992,0.292005,0.620401,0.611941


In [69]:
# Average |r| across all scales per model, ascending.
resp_d_model[[f'{sc} abs r' for sc in scales_]].mean(axis=1).sort_values()

bert         0.176060
glove_tf     0.193944
glove_avg    0.212159
w2v_avg      0.263871
w2v_tf       0.289334
dtype: float64

In [70]:
# Average correlation with mean sentence length per model
# (prepared with plot_abs=True — presumably absolute values; confirm in prep_LM_pointplot).
resp_d_model['r_corr_w_control']

bert          0.16449
glove_tf     0.423193
glove_avg    0.559332
w2v_tf       0.588689
w2v_avg      0.611941
Name: r_corr_w_control, dtype: object

In [71]:
# Average correlation with mean sentence length per metric.
resp_d_metric['r_corr_w_control']

cgcoh    0.283809
gcoh     0.489772
lcoh      0.54957
scoh     0.554965
sprob     0.57879
pppl     0.536921
Name: r_corr_w_control, dtype: object

#### including feature based metrics for BERT

In [72]:
# Recompute with every metric included in the per-model averages.
resp_d_model, resp_d_metric = mean_model_metric_medians(median_d)

In [73]:
# Render the per-metric table.
style(resp_d_metric)

Unnamed: 0,sans_total abs r,saps_total abs r,panss_pos abs r,panss_neg abs r,panss_o abs r,panss_total abs r,t,r_corr_w_control
cgcoh,0.105299,0.190716,0.142744,0.102059,0.085015,0.098072,0.876023,0.283809
gcoh,0.176891,0.183033,0.163362,0.200699,0.159443,0.205444,0.136936,0.489772
lcoh,0.314002,0.232792,0.256959,0.328495,0.292893,0.345097,0.734223,0.54957
scoh,0.373963,0.192923,0.201327,0.392703,0.334748,0.371083,0.379362,0.554965
sprob,0.245612,0.13136,0.203768,0.318323,0.278572,0.314921,1.292476,0.57879
pppl,0.471239,0.155317,0.119219,0.454768,0.432778,0.412121,2.039498,0.536921


In [74]:
# Average |r| across all scales per model (all metrics included), ascending.
resp_d_model[[f'{sc} abs r' for sc in scales_]].mean(axis=1).sort_values()

glove_tf     0.193944
glove_avg    0.212159
bert         0.215651
w2v_avg      0.263871
w2v_tf       0.289334
dtype: float64

In [75]:
# Average correlation with mean sentence length per model (all metrics included).
resp_d_model['r_corr_w_control']

bert         0.295612
glove_tf     0.423193
glove_avg    0.559332
w2v_tf       0.588689
w2v_avg      0.611941
Name: r_corr_w_control, dtype: object

In [76]:
# Average correlation with mean sentence length per metric (unchanged by leave_out).
resp_d_metric['r_corr_w_control']

cgcoh    0.283809
gcoh     0.489772
lcoh      0.54957
scoh     0.554965
sprob     0.57879
pppl     0.536921
Name: r_corr_w_control, dtype: object

## Plot and analyze across parts of NET

In [77]:
def plot_horizontal_tasks(df, title, scale, measure, xname=None, m_type='syntactic', 
                          plot_abs=False, r=0.3, figparams=figprms):
    """Draw a cross-task 2x2 grid of ``measure`` for one scale, one panel per task in TASKS.

    Each panel plots ``df.loc[m_type, (task, scale)]`` (the per-task bootstrap
    table). For syntactic metrics, 'mean_sent_len' is omitted from
    'r_corr_w_control' plots because its correlation with itself is not computed.

    ``xname`` is the axis label and defaults to ``measure``; it is prefixed
    with 'abs ' and used in the title suffix when ``plot_abs`` is set. The
    title previously always said 'r'. ``r`` is forwarded to add_grey.

    NOTE(review): this redefines the earlier cross-scale ``plot_horizontal_tasks``
    with an incompatible signature, so plot_all breaks if it is re-run after
    this cell. Consider giving one of them a different name.

    Returns the created figure.
    """
    if xname is None:
        xname = measure
    label = 'abs ' + xname if plot_abs else xname
    absolute_value = f' (absolute {xname} value)' if plot_abs else ''

    figsize, wspace, hspace, yt = get_fparams(m_type, 2, 2, figparams)
    fig, axes = plt.subplots(2, 2, figsize=figsize, sharex=True)
    fig.suptitle(title + absolute_value, y=yt)
    plt.subplots_adjust(wspace=wspace, hspace=hspace)

    for ax, task in zip(axes.flatten(), TASKS):
        data_task = df.loc[m_type, (task, scale)]
        if m_type == 'syntactic' and measure == 'r_corr_w_control':
            data_task = data_task.drop('mean_sent_len')
        d = prep_horizontal_pointplot_errobar_data(data_task, col=measure, plot_abs=plot_abs)
        pointplot_horizontal(d, x=measure, ax=ax)
        ax.set_title(task)

    add_grey(axes, r=r)
    for ax in axes.reshape(-1):
        ax.set_xlabel(label)
    return fig

In [78]:
m_type = 'graph'
# Preview only: the figure is closed right away, so nothing is displayed or saved here.
fig = plot_horizontal_tasks(reformed_tasks, 
                      f'cross-task comparison for {m_type} metrics on group difference (t-test)', 
                      scale='panss_o', measure='t', m_type=m_type, r=2, figparams=figprms)
plt.close(fig)

In [79]:
# Preview only: the figure is closed right away, so nothing is displayed or saved here.
fig = plot_horizontal_tasks(reformed_tasks, 
                      f'cross-task comparison for {m_type} metrics on sans_total', 
                      scale='sans_total', measure='r', m_type=m_type, figparams=figprms)
plt.close(fig)

In [80]:
# Preview only: the figure is closed right away, so nothing is displayed or saved here.
fig = plot_horizontal_tasks(reformed_tasks, 
                            f'cross-task comparison for {m_type} metrics on correlation with mean sentence length', 
                            scale='panss_o', measure='r_corr_w_control', 
                            xname='r', m_type=m_type, figparams=figprms)
plt.close(fig)

In [81]:
def plot_all_scales(reformed_d, m_type='syntactic', path=PATH_FIG, plot_abs=True, dpi=150, figparams=figprms):
    """Save the cross-task figures for one metric family.

    Files written under ``{path}{m_type}/``:

    * ``t_across_tasks.png``               -- group-difference t per task;
    * ``{abs_}corr_len_across_tasks.png``  -- correlation with mean sentence length;
    * ``{abs_}r_{scale}_across_tasks.png`` -- r with each scale in ORDERED_SCALES.

    ``plot_abs`` now also applies to the length-correlation figure, whose file
    name already carried the ``abs_`` prefix.
    """
    ab = 'abs_' if plot_abs else ''

    fig = plot_horizontal_tasks(reformed_d, 
                                f'cross-task comparison for {m_type} metrics on group difference (t-test)', 
                                scale='panss_o', measure='t', m_type=m_type, r=2, figparams=figparams)
    fig.savefig(f'{path}{m_type}/t_across_tasks.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

    # Per-family title offset for the length-correlation figure. The nested dicts
    # are copied: a shallow .copy() shared them with `figparams` (the global
    # figprms by default), so each call permanently shifted the lexical titles up.
    upfig = {key: dict(value) for key, value in figparams.items()}
    if m_type == 'lexical':
        upfig[m_type]['yt'] += 0.05
    fig = plot_horizontal_tasks(reformed_d, 
                                f'cross-task comparison for {m_type} metrics on correlation with mean sentence length', 
                                scale='panss_o', measure='r_corr_w_control', 
                                xname='r', plot_abs=plot_abs, m_type=m_type, figparams=upfig)
    fig.savefig(f'{path}{m_type}/{ab}corr_len_across_tasks.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

    for scale in ORDERED_SCALES:
        fig = plot_horizontal_tasks(reformed_d, 
                                    f'cross-task comparison for {m_type} metrics on {scale}', 
                                    scale=scale, measure='r', 
                                    plot_abs=plot_abs, m_type=m_type, figparams=figparams)
        fig.savefig(f'{path}{m_type}/{ab}r_{scale}_across_tasks.png', dpi=dpi, bbox_inches='tight')
        plt.close(fig)

In [82]:
# Save the cross-task figures for every metric family, absolute and signed.
for m_type in reformed_tasks.index.unique(level=0):
    plot_all_scales(reformed_tasks, m_type, plot_abs=True)
    plot_all_scales(reformed_tasks, m_type, plot_abs=False)

In [83]:
def plot_lm_tasks(df, title, scale, measure, order=order, plot_abs=False, use_errorbar=True,
                 figsize=(15, 10), yname=None, r=0.3):
    """Draw a 2x2 grid of LM point plots for one scale, one panel per task in TASKS.

    Each panel plots ``measure`` from ``df.loc['LM', (task, scale)]`` by model
    (x-axis, in ``order``) and metric (hue). With ``plot_abs``, the title gets
    an "(absolute <measure> value)" suffix and the y-label an 'abs ' prefix.
    ``yname`` overrides the y-label (defaults to ``measure``); ``r`` is passed
    to add_grey.

    Returns the created figure.
    """
    suffix = f' (absolute {measure} value)' if plot_abs else ''
    fig, axes = plt.subplots(2, 2, figsize=figsize, sharey=True)
    fig.suptitle(title + suffix, y=0.925)
    plt.subplots_adjust(wspace=0.1)

    panels = axes.flatten()
    for idx, task in enumerate(TASKS):
        panel = panels[idx]
        task_points = prep_LM_pointplot(df.loc['LM', (task, scale)], col=measure, plot_abs=plot_abs)
        pointplot(task_points, x='model', y=measure, hue='metric', ax=panel, order=order,
                  use_errorbar=use_errorbar)
        panel.set_title(task)

    add_grey(axes, line_dir='h', r=r)

    axis_name = measure if yname is None else yname
    ylabel = 'abs ' + axis_name if plot_abs else axis_name
    for panel in axes.reshape(-1):
        panel.set_ylabel(ylabel)
    return fig

In [84]:
# Preview only: the figure is closed right away, so nothing is displayed or saved here.
fig = plot_lm_tasks(reformed_tasks, 
                    'cross-task comparison for LM metrics across models on group difference (t-test)',
                    scale='panss_o', measure='t', use_errorbar=True, figsize=(15, 10), r=2)
plt.close(fig)

In [85]:
# Preview only: the figure is closed right away, so nothing is displayed or saved here.
fig = plot_lm_tasks(reformed_tasks, 
                    'cross-task comparison for LM metrics across models on sans_total',
                    scale='sans_total', measure='r', use_errorbar=True, figsize=(15, 10))
plt.close(fig)

In [86]:
def plot_all_LM_across_tasks(reformed_d, m_type='LM', path=PATH_FIG, plot_abs=False, figsize=(18, 12), dpi=150):
    """Save the cross-task LM figures (grouped by model) under ``{path}{m_type}/model/``.

    Files written:

    * ``t_across_tasks.png``;
    * ``{abs_}corr_len_across_tasks.png``;
    * ``{abs_}r_{scale}_across_tasks.png`` for each scale in ORDERED_SCALES.
    """
    ab = 'abs_' if plot_abs else ''

    fig = plot_lm_tasks(reformed_d, 
                        'cross-task comparison for LM metrics across models on group difference (t-test)',
                        scale='panss_o', measure='t', use_errorbar=True, figsize=figsize, r=2)
    fig.savefig(f'{path}{m_type}/model/t_across_tasks.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

    fig = plot_lm_tasks(reformed_d, 
                        'cross-task comparison for LM metrics across models on correlation with mean sentence length',
                        scale='panss_o', measure='r_corr_w_control', yname='r',
                        plot_abs=plot_abs, use_errorbar=True, figsize=figsize)
    fig.savefig(f'{path}{m_type}/model/{ab}corr_len_across_tasks.png', dpi=dpi, bbox_inches='tight')
    plt.close(fig)

    for scale in ORDERED_SCALES:
        # plot_lm_tasks itself appends "(absolute r value)" when plot_abs is set;
        # adding it here as well duplicated the suffix in the title.
        fig = plot_lm_tasks(reformed_d, 
                            f'cross-task comparison for LM metrics across models on {scale}', 
                            scale=scale, measure='r', 
                            plot_abs=plot_abs, use_errorbar=True, figsize=figsize)
        fig.savefig(f'{path}{m_type}/model/{ab}r_{scale}_across_tasks.png', dpi=dpi, bbox_inches='tight')
        plt.close(fig)

In [87]:
# Save the cross-task LM figures with absolute and signed values.
plot_all_LM_across_tasks(reformed_tasks, plot_abs=True)
plot_all_LM_across_tasks(reformed_tasks, plot_abs=False)