# Imports & Constants

In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import ast
import matplotlib.pyplot as plt
import numpy as np
import pickle
import pandas as pd
import random
import seaborn as sns
import statsmodels.api as sm

In [4]:
from random import choices
from scipy import stats
from tqdm.notebook import tqdm
from matplotlib.colors import ListedColormap

In [5]:
from jupyter_utils import style, mean_std, display_test, display_group_test, \
    scatter_annotate, show_corrtest_mask_corr, pointplot, pointplot_horizontal, add_grey, \
    prep_horizontal_pointplot_errobar_data, map_model, prep_LM_pointplot, draw_sample_with_replacement, t_test
from ortogonolize_utils import compute_coefficient

In [6]:
import warnings
warnings.filterwarnings(action='ignore', category=np.VisibleDeprecationWarning)
warnings.filterwarnings(action='ignore', message='All-NaN slice encountered')
warnings.filterwarnings(action='ignore', message='Precision loss occurred in moment calculation due to catastrophic cancellation. This occurs when the data are nearly identical. Results may be unreliable.')
warnings.filterwarnings(action='ignore', message='Mean of empty slice')
warnings.filterwarnings(action='ignore', category=stats.ConstantInputWarning)
warnings.filterwarnings(action='ignore', message='indexing past lexsort depth may impact performance')

In [7]:
# Use seaborn's "whitegrid" theme for every figure created below.
sns.set_theme(style="whitegrid")

In [8]:
# Directory with the processed feature tables.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
PATH = '/Users/galina.ryazanskaya/Downloads/thesis?/code?/processed_values'

In [9]:
# Output directory for figures; default `path` of plot_all / plot_all_LM and prefix of the savefig calls.
# Ends with '/', and the plotting code concatenates sub-paths directly onto it.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
PATH_FIG = '/Users/galina.ryazanskaya/Downloads/thesis?/figures/de/'

# Load data

In [10]:
# Feature table read from `de_averaged.csv` (first CSV column is the row index).
# The directory comes from PATH so it is configured in a single place.
combined_data_averaged = pd.read_csv(f'{PATH}/de_averaged.csv', index_col=0)

In [11]:
# Labels containing '(' are string reprs of tuples (e.g. "('LM', 'x')"); turn them back into
# real tuples and leave every other label as the plain string it already is.
combined_data_averaged.columns = [
    ast.literal_eval(label) if '(' in label else label
    for label in combined_data_averaged.columns
]

In [12]:
# Feature table read from `de_all.csv`; the first three CSV rows form a 3-level column index.
# The directory comes from PATH so it is configured in a single place.
combined_data_all = pd.read_csv(f'{PATH}/de_all.csv', index_col=0, header=[0, 1, 2])

In [13]:
# Task names; used as top-level column keys of the per-task table and as the default
# `tasks` argument of aplly_to_all_tasks.
TASKS = ['anger', 'fear', 'happiness', 'sadness']

In [14]:
def task_data(df, task, keep_target=True, fill_synt=True):
    """Return the rows of one task that have no missing feature values.

    Parameters
    ----------
    df : DataFrame whose top column level contains `task` and 'target'.
    task : top-level column key selecting the task's feature columns.
    keep_target : if True, the `df['target']` columns of the kept rows are appended.
    fill_synt : if True, NaNs in the 'syntactic' columns are replaced with 0.0.

    Returns
    -------
    DataFrame with the task's feature columns (plus the target columns if requested).
    """
    # Rows with a NaN in any feature column of this task are removed first.
    subset = df[task].dropna(axis=0, how='any')

    wrong_task = task == 'fear' and 'KG_018' in subset.index
    if wrong_task:
        subset = subset.drop(['KG_018'])   # incorrect task

    if fill_synt:
        # NOTE(review): every row containing a NaN was already dropped above, so this fill
        # appears to have nothing left to fill — confirm whether it was meant to run first.
        subset['syntactic'] = subset['syntactic'].fillna(0.0)

    if not keep_target:
        return subset
    return pd.concat([subset, df['target'].loc[subset.index]], axis=1)

In [15]:
def aplly_to_all_tasks(df, f, tasks=TASKS, to_df=True, *args, **kwargs):
    """Run `f` on the data of every task and collect the results.

    For each task, `task_data(df, task)` is passed to `f` together with `*args` / `**kwargs`.

    Returns
    -------
    - `to_df` False: dict mapping task -> result.
    - all results are Series: DataFrame with one column per task.
    - all results are DataFrames: column-wise concatenation with an outer 'task' column level.
    - otherwise: the dict mapping task -> result.
    """
    res = {task: f(task_data(df, task), *args, **kwargs) for task in tasks}
    if not to_df:
        return res

    results = list(res.values())
    if all(isinstance(v, pd.Series) for v in results):
        return pd.DataFrame(res)
    if all(isinstance(v, pd.DataFrame) for v in results):
        return pd.concat(results, keys=list(res.keys()), names=['task'], axis=1)
    # Mixed or other result types cannot be combined into one frame.
    return res

In [40]:
# Mean of every column per value of the ('target', 'target', 'group') column, transposed so that
# each row is one column of the original table; the index levels are named without `inplace`.
group_means = (
    combined_data_all
    .groupby([('target', 'target', 'group')])
    .mean()
    .T
    .rename_axis(index=['task', 'feature_group', 'feature'])
)
# NOTE(review): assumes the two group values come out in the order control, NAP — confirm.
group_means.columns = ['control', 'NAP']
group_means.to_csv("feature_means_de.csv")

# Bootstrap

In [16]:
# Columns holding the scale scores that each metric is correlated with.
scale_cols = ['saps_total',
             'sans_total',
             'panss_pos',
             'panss_neg',
             'panss_o',
             'panss_total']

In [17]:
# Metric columns by feature group: tuple labels are (group, metric), so col[0] is the group.
# For plain-string labels col[0] is just the first character, which never equals a group name.
cols_LM = [col for col in combined_data_averaged.columns if col[0] == 'LM']
cols_synt = [col for col in combined_data_averaged.columns if col[0] == 'syntactic']
cols_lex = [col for col in combined_data_averaged.columns if col[0] == 'lexical']
cols_graph = [col for col in combined_data_averaged.columns if col[0] == 'graph']
# All metric columns, in group order LM, syntactic, lexical, graph.
cols_av = cols_LM + cols_synt + cols_lex + cols_graph

In [18]:
# Control columns; each metric's correlation with these is stored as 'r_corr_w_<last tuple element>'.
cols_to_correct_for = [('syntactic', 'mean_sent_len'), ('syntactic', 'n_sents'), ('lexical', 'n_words')]

In [19]:
## 1 sample with replacement
## 1.1 compute r for each scale for each iteration
## 1.2 compute t test for each iteration

def bootstrap_with_corrections(df, cols_av, scale_cols, N, columns_to_correct_for=cols_to_correct_for, group=None):
    """Bootstrap metric statistics, including correlations with several control columns.

    For each of `N` resamples (resample `i` is drawn with `seed=i`) and each (scale, metric) pair,
    the following are appended to per-pair lists:

    - 'sample_raw': first element of `compute_coefficient(sample, scale, metric)`;
    - 'r': Pearson r between metric and scale on rows where both are present;
    - 't': `t_test(sample, metric, group)`, only when `group` is truthy;
    - 'r_corr_w_<name>': Pearson r between the metric and each control column, where <name> is the
      last element of the control column label (skipped when the metric is that control column).

    Returns
    -------
    dict: measure -> scale -> metric -> list of bootstrap values.
    """
    controls = [(control, f'r_corr_w_{control[-1]}') for control in columns_to_correct_for]
    measure_names = ('sample_raw', 'r', 't') + tuple(name for _, name in controls)

    results = {}
    for measure_name in measure_names:
        results[measure_name] = {scale: {metric: [] for metric in cols_av} for scale in scale_cols}

    for seed in tqdm(range(N)):
        resampled = draw_sample_with_replacement(df, seed=seed)
        for scale in scale_cols:
            for metric in cols_av:
                if group:
                    results['t'][scale][metric].append(t_test(resampled, metric, group))

                raw_coefficient = compute_coefficient(resampled, scale, metric)[0]
                results['sample_raw'][scale][metric].append(raw_coefficient)

                # Pearson r needs complete pairs, so rows missing either value are dropped.
                complete = resampled.dropna(subset=[metric, scale])
                pearson_r = stats.pearsonr(complete[metric], complete[scale])[0]
                results['r'][scale][metric].append(pearson_r)

                for control, control_name in controls:
                    if metric == control:
                        # A column is not correlated with itself; its list stays shorter.
                        continue
                    complete_control = resampled.dropna(subset=[metric, control])
                    control_r = stats.pearsonr(complete_control[metric], complete_control[control])[0]
                    results[control_name][scale][metric].append(control_r)

    return results

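**Expensive to compute** — the cells that run the bootstrap and save its result are left commented out; the cached result is loaded from the pickle file below instead.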

In [274]:
# N = 1000
# dict_scales_sapmles = bootstrap_with_corrections(combined_data_averaged, cols_av, 
#                                                  scale_cols, N, 
#                                                  columns_to_correct_for=cols_to_correct_for, 
#                                                  group='group')

In [275]:
# reform_v = {(scale, measure): dict_scales_sapmles[measure][scale] for scale in scale_cols for measure in dict_scales_sapmles}

In [276]:
# with open('processed_values/de_scales_samples_w_verbosity.pickle', 'wb') as f:
#     pickle.dump(reform_v, f)

In [20]:
# Load the cached result of bootstrap_with_corrections, keyed by (scale, measure).
# NOTE: pickle.load can execute arbitrary code — only load files produced by this notebook.
with open('processed_values/de_scales_samples_w_verbosity.pickle', 'rb') as f:
    reform_v = pickle.load(f)

In [21]:
# One column per dict key of reform_v; the inner dict keys become the row index.
reformed_d_w_verbosity = pd.DataFrame(reform_v)

In [22]:
## 1 sample with replacement
## 1.1 compute r for each scale for each iteration
## 1.2 compute t test for each iteration

def bootstrap(df, cols_av, scale_cols, N, col_to_correct_for=('syntactic', 'mean_sent_len'), group=None):
    """Bootstrap metric statistics with a single control column.

    For each of `N` resamples (resample `i` is drawn with `seed=i`) and each (scale, metric) pair,
    the following are appended to per-pair lists:

    - 'sample_raw': first element of `compute_coefficient(sample, scale, metric)`;
    - 'r': Pearson r between metric and scale on rows where both are present;
    - 't': `t_test(sample, metric, group)`, only when `group` is truthy;
    - 'r_corr_w_control': Pearson r between the metric and `col_to_correct_for`
      (skipped when the metric is the control column itself).

    Returns
    -------
    dict: measure -> scale -> metric -> list of bootstrap values.
    """
    measure_names = ('sample_raw', 'r', 't', 'r_corr_w_control')
    results = {}
    for measure_name in measure_names:
        results[measure_name] = {scale: {metric: [] for metric in cols_av} for scale in scale_cols}

    for seed in tqdm(range(N)):
        resampled = draw_sample_with_replacement(df, seed=seed)
        for scale in scale_cols:
            for metric in cols_av:
                if group:
                    results['t'][scale][metric].append(t_test(resampled, metric, group))

                raw_coefficient = compute_coefficient(resampled, scale, metric)[0]
                results['sample_raw'][scale][metric].append(raw_coefficient)

                # Pearson r needs complete pairs, so rows missing either value are dropped.
                complete = resampled.dropna(subset=[metric, scale])
                pearson_r = stats.pearsonr(complete[metric], complete[scale])[0]
                results['r'][scale][metric].append(pearson_r)

                if metric == col_to_correct_for:
                    # The control column is not correlated with itself; its list stays empty.
                    continue
                complete_control = resampled.dropna(subset=[metric, col_to_correct_for])
                control_r = stats.pearsonr(complete_control[metric], complete_control[col_to_correct_for])[0]
                results['r_corr_w_control'][scale][metric].append(control_r)

    return results

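**Expensive to compute** — the cells that run the bootstrap and save its result are left commented out; the cached result is loaded from the pickle file below instead.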

In [280]:
# N = 1000
# dict_scales_sapmles = bootstrap(combined_data_averaged, cols_av, scale_cols, N, col_to_correct_for=('syntactic', 'mean_sent_len'), group='group')

In [281]:
# reform = {(scale, measure): dict_scales_sapmles[measure][scale] for scale in scale_cols for measure in dict_scales_sapmles}

In [282]:
# with open('processed_values/de_scales_samples_wo_o.pickle', 'wb') as f:
#     pickle.dump(reform, f)

In [23]:
# Load the cached result of bootstrap (single control column), keyed by (scale, measure).
# NOTE: pickle.load can execute arbitrary code — only load files produced by this notebook.
with open('processed_values/de_scales_samples_wo_o.pickle', 'rb') as f:
    reform = pickle.load(f)

In [24]:
# One column per dict key of reform; the inner dict keys become the row index.
reformed_d = pd.DataFrame(reform)

### Bootstrap for each task

In [25]:
def bootstrap_tasks_with_corrections(df, cols_av, scale_cols, N, columns_to_correct_for=cols_to_correct_for, group=None):
    """Bootstrap metric statistics for one task, with several control columns.

    `df` is expected to hold the scale columns under ('target', <scale>). For each of `N`
    resamples (resample `i` is drawn with `seed=i`) the scale-independent statistics are computed
    once per metric and then copied into the result of every scale.

    Collected measures:

    - 'sample_raw': first element of `compute_coefficient(sample, ('target', scale), metric)`;
    - 'r': Pearson r between metric and scale on rows where both are present;
    - 't': `t_test(sample, metric, group)`, only when `group` is truthy;
    - 'r_corr_w_<name>': Pearson r between the metric and each column in
      `columns_to_correct_for`, where <name> is the last element of the control column label.

    Returns
    -------
    dict: measure -> scale -> metric -> list with one entry per resample.
    """
    correction_corr_names = tuple(f'r_corr_w_{x[-1]}' for x in columns_to_correct_for)
    res_c = ('sample_raw', 'r', 't') + correction_corr_names
    dict_scales_sapmles = {k: {scale: {metric: [] for metric in cols_av} for scale in scale_cols} for k in res_c}
    for i in tqdm(range(N)):
        sample = draw_sample_with_replacement(df, seed=i)
        scale_independent = {k: {metric: [] for metric in cols_av} for k in ('t', ) + correction_corr_names}
        for col in cols_av:
            if group:
                t_test_res = t_test(sample, col, group)
                scale_independent['t'][col].append(t_test_res)
            # Fixed: iterate the `columns_to_correct_for` argument, not the module-level
            # `cols_to_correct_for`, so a caller-supplied list is honoured and always matches
            # the `correction_corr_names` keys built above.
            for col_to_correct_for in columns_to_correct_for:
                if col != col_to_correct_for:
                    dropped_c = sample.dropna(subset=[col, col_to_correct_for])
                    r_c = stats.pearsonr(dropped_c[col], dropped_c[col_to_correct_for])[0]
                    scale_independent[f'r_corr_w_{col_to_correct_for[-1]}'][col] = r_c
            for scale in scale_cols:
                scale_ = ('target', scale)
                r_raw = compute_coefficient(sample, scale_, col)[0]
                dict_scales_sapmles['sample_raw'][scale][col].append(r_raw)

                # Pearson r needs complete pairs, so rows missing either value are dropped.
                dropped = sample.dropna(subset=[col, scale_])
                r = stats.pearsonr(dropped[col], dropped[scale_])[0]
                dict_scales_sapmles['r'][scale][col].append(r)

                # NOTE(review): the entry shapes differ by measure and are kept as-is so the output
                # stays compatible with the cached pickle: 't' entries are lists ([t], or [] without
                # `group`), control-correlation entries are floats ([] for the control column itself).
                for k in scale_independent:
                    dict_scales_sapmles[k][scale][col].append(scale_independent[k][col])

    return dict_scales_sapmles

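**Expensive to compute** — the cells that run the per-task bootstrap and save its result are left commented out; the cached result is loaded from the pickle file below instead.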

In [286]:
# N = 1000
# dict_scales_sapmles_tasks = aplly_to_all_tasks(combined_data_all, bootstrap_tasks_with_corrections, cols_av=cols_av, scale_cols=scale_cols, N=N, columns_to_correct_for=cols_to_correct_for, group=('target', 'group'))

In [287]:
# reform_tasks_v = {(task, scale, measure): dict_scales_sapmles_tasks[task][measure][scale] for scale in scale_cols 
#                 for task in dict_scales_sapmles_tasks
#                 for measure in dict_scales_sapmles_tasks[task]}

In [288]:
# with open('processed_values/de_scales_samples_wo_o_tasks_w_verbosity.pickle', 'wb') as f:
#     pickle.dump(reform_tasks_v, f)

In [26]:
# Load the cached per-task result of bootstrap_tasks_with_corrections, keyed by (task, scale, measure).
# NOTE: pickle.load can execute arbitrary code — only load files produced by this notebook.
with open('processed_values/de_scales_samples_wo_o_tasks_w_verbosity.pickle', 'rb') as f:
    reform_tasks_v = pickle.load(f)

In [27]:
# One column per dict key of reform_tasks_v; the three column levels are then named.
reformed_tasks_v = pd.DataFrame(reform_tasks_v)
reformed_tasks_v.columns.names = ['TASK', 'scale', 'measure']

In [28]:
def bootstrap_tasks(df, cols_av, scale_cols, N, col_to_correct_for=('syntactic', 'mean_sent_len'), group=None):
    """Bootstrap metric statistics for one task, with a single control column.

    `df` is expected to hold the scale columns under ('target', <scale>). For each of `N`
    resamples (resample `i` is drawn with `seed=i`) the scale-independent statistics ('t' and
    'r_corr_w_control') are computed once per metric and copied into the result of every scale.

    Entry shapes are uneven: 't' entries are lists ([t], or [] when `group` is falsy), while
    'r_corr_w_control' entries are floats ([] for the control column itself).

    Returns
    -------
    dict: measure -> scale -> metric -> list with one entry per resample.
    """
    shared_measures = ('t', 'r_corr_w_control')
    measure_names = ('sample_raw', 'r') + shared_measures
    results = {}
    for measure_name in measure_names:
        results[measure_name] = {scale: {metric: [] for metric in cols_av} for scale in scale_cols}

    for seed in tqdm(range(N)):
        resampled = draw_sample_with_replacement(df, seed=seed)
        per_metric = {name: {metric: [] for metric in cols_av} for name in shared_measures}
        for metric in cols_av:
            if group:
                per_metric['t'][metric].append(t_test(resampled, metric, group))
            if metric != col_to_correct_for:
                complete_control = resampled.dropna(subset=[metric, col_to_correct_for])
                control_r = stats.pearsonr(complete_control[metric], complete_control[col_to_correct_for])[0]
                # Replaces the placeholder list with the scalar correlation.
                per_metric['r_corr_w_control'][metric] = control_r

            for scale in scale_cols:
                target_scale = ('target', scale)
                raw_coefficient = compute_coefficient(resampled, target_scale, metric)[0]
                results['sample_raw'][scale][metric].append(raw_coefficient)

                # Pearson r needs complete pairs, so rows missing either value are dropped.
                complete = resampled.dropna(subset=[metric, target_scale])
                pearson_r = stats.pearsonr(complete[metric], complete[target_scale])[0]
                results['r'][scale][metric].append(pearson_r)

                # The same scale-independent value is recorded under every scale.
                for name, values in per_metric.items():
                    results[name][scale][metric].append(values[metric])

    return results

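**Expensive to compute** — the cells that run the per-task bootstrap and save its result are left commented out; the cached result is loaded from the pickle file below instead.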

In [292]:
# N = 1000
# dict_scales_sapmles_tasks = aplly_to_all_tasks(combined_data_all, bootstrap_tasks, cols_av=cols_av, scale_cols=scale_cols, N=N, col_to_correct_for=('syntactic', 'mean_sent_len'), group=('target', 'group'))

In [293]:
# reform_tasks = {(task, scale, measure): dict_scales_sapmles_tasks[task][measure][scale] for scale in scale_cols 
#                 for task in dict_scales_sapmles_tasks
#                 for measure in dict_scales_sapmles_tasks[task]}

In [294]:
# with open('processed_values/de_scales_samples_wo_o_tasks.pickle', 'wb') as f:
#     pickle.dump(reform_tasks, f)

In [29]:
# Load the cached per-task result of bootstrap_tasks, keyed by (task, scale, measure).
# NOTE: pickle.load can execute arbitrary code — only load files produced by this notebook.
with open('processed_values/de_scales_samples_wo_o_tasks.pickle', 'rb') as f:
    reform_tasks = pickle.load(f)

In [30]:
# One column per dict key of reform_tasks; the three column levels are then named.
reformed_tasks = pd.DataFrame(reform_tasks)
reformed_tasks.columns.names = ['TASK', 'scale', 'measure']

# Plot & Analyze

In [31]:
# Layout parameters per metric type, read by get_fparams:
#   subplot_size - (width, height) of one subplot, multiplied by the grid dimensions;
#   wspace / hspace - spacing values handed to plt.subplots_adjust by the plotting functions;
#   yt - base y position used for figure suptitles.
figprms = {'syntactic': 
               {'subplot_size': (9, 4.5),
                'wspace': 0.25,
                'hspace': 0.125,
                'yt': 0.925
                }, 
           'LM': 
               {'subplot_size': (9, 5.5),
                'wspace': 0.275,
                'hspace': 0.125,
                'yt': 0.92
               }, 
           'lexical': 
               {'subplot_size': (9, 2),
                'wspace': 0.2,
                'hspace': 0.25,
                'yt': 0.925
               }, 
           'graph': 
               {'subplot_size': (9, 3.5), 
                'wspace': 0.3,
                'hspace': 0.125,
                'yt': 0.925
               }}

In [32]:
def get_fparams(m_type, n_sublots_height, n_sublots_width, figparams):
    """Look up the layout parameters of a metric type and scale the figure size to a grid.

    Parameters
    ----------
    m_type : key into `figparams` (e.g. 'syntactic', 'LM').
    n_sublots_height, n_sublots_width : number of subplot rows / columns.
    figparams : dict mapping metric type -> dict with 'subplot_size', 'wspace', 'hspace', 'yt'.

    Returns
    -------
    (figsize, wspace, hspace, yt) where figsize is (width, height) of the whole figure.

    Raises
    ------
    KeyError if `m_type` or one of the expected entries is missing from `figparams`.
    """
    # Fixed: read from the `figparams` argument; the global `figprms` was used before,
    # which silently ignored whatever the caller passed in.
    params = figparams[m_type]
    subplot_size = params['subplot_size']
    figsize = (subplot_size[0] * n_sublots_width, subplot_size[1] * n_sublots_height)
    return figsize, params['wspace'], params['hspace'], params['yt']

In [33]:
# Order in which the scales are laid out in the 3x2 subplot grids (row by row).
ORDERED_SCALES = ['panss_pos', 'panss_neg', 'panss_o', 'panss_total', 'saps_total', 'sans_total']

### Plot horizontal bar plots

In [300]:
def plot_horizontal_tasks(df, title, measure, m_type='syntactic', plot_abs=False, figparams=figprms):
    """Draw a 3x2 grid of horizontal point plots, one subplot per scale in ORDERED_SCALES.

    Parameters
    ----------
    df : frame indexed as `df[scale].loc[m_type]`.
    title : figure suptitle.
    measure : measure passed to the data-preparation helper and used as the x variable.
    m_type : metric type; selects the rows to plot and the layout parameters.
    plot_abs : forwarded to the data-preparation helper; also prefixes titles/labels with 'abs '.
    figparams : layout parameters forwarded to get_fparams.

    Returns
    -------
    The created matplotlib figure (not saved or closed here).
    """
    figsize, wspace, hspace, yt = get_fparams(m_type, 3, 2, figparams)
    fig, axes = plt.subplots(3, 2, figsize=figsize, sharex=True)
    fig.suptitle(title, y=yt+0.25)
    # Only wspace is applied; hspace from get_fparams is unused in this function.
    plt.subplots_adjust(wspace=wspace) #left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)
    
    ab = 'abs ' if plot_abs else ''
    
    axs = axes.flatten()

    for i, scale in enumerate(ORDERED_SCALES):
        ax = axs[i]
        d = prep_horizontal_pointplot_errobar_data(df[scale].loc[m_type], measure, plot_abs=plot_abs)
        pointplot_horizontal(d, x=measure, ax=ax)
        ax.set_title(f'{ab}{measure} {scale}')
    
    add_grey(axes)

    if plot_abs:
        for ax in axes.reshape(-1): 
            ax.set_xlabel('abs ' + measure);
    return fig

In [34]:
# Measure names of the control correlations, passed as `control_cols` to the plotting functions.
verbosity_control_cols = ['r_corr_w_mean_sent_len', 'r_corr_w_n_sents', 'r_corr_w_n_words']

In [35]:
# Display names used in subplot titles, positionally matching verbosity_control_cols.
control_col_names = ['mean sentence length', 'sentence count', 'word count']

In [303]:
def plot_all(df, m_type='syntactic', measure='r', path=PATH_FIG, dpi=150, plot_abs=False, figparams=figprms,
                         control_cols=['r_corr_w_control'], control_col_names=['mean sentence length']):
    """Create and save the four standard figures for one metric type.

    Files written under `{path}{m_type}/`:

    - `{abs_}scale_r.png`: cross-scale comparison of `measure` (3x2 grid);
    - `t.png`: group difference (t values, sign flipped);
    - `corr_verbosity.png`: correlation with each control column;
    - `t_test_corr_verbosity.png`: control correlations next to the t values.

    Parameters
    ----------
    df : frame indexed as `df[scale].loc[m_type]`.
    m_type : metric type; selects rows, layout parameters and the output sub-directory.
    measure : measure shown in the cross-scale figure.
    path : output directory prefix (concatenated directly, so it should end with a separator).
    dpi : resolution of the saved figures.
    plot_abs : plot absolute values in the cross-scale figure and prefix its file name with 'abs_'.
    figparams : layout parameters forwarded to get_fparams and plot_horizontal_tasks.
    control_cols : measure names of the control correlations to plot.
    control_col_names : display names, positionally matching `control_cols`.

    Raises
    ------
    ValueError if `control_cols` and `control_col_names` differ in length.
    """
    ab = 'abs_' if plot_abs else ''
    absolute_value = f' (absolute {measure} value)' if plot_abs else ''
    if len(control_cols) != len(control_col_names):
        raise ValueError('The names of the columns must match the columns in length.')

    # Fixed: forward the `figparams` argument instead of the global `figprms`.
    fig = plot_horizontal_tasks(df, title=f'cross-scale comparison for {m_type} metrics{absolute_value}', 
                                measure=measure, m_type=m_type, plot_abs=plot_abs, figparams=figparams)
    plt.savefig(f'{path}{m_type}/{ab}scale_r.png', dpi=dpi)
    plt.close(fig)

    # The t values and control correlations are read from the 'panss_o' columns.
    figsize, wspace, hspace, yt = get_fparams(m_type, 1, 1, figparams)
    d_t = prep_horizontal_pointplot_errobar_data(df['panss_o'].loc[m_type], 't')
    d_t['t'] = d_t['t'] * -1  # flip the sign of the plotted t values
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    fig.suptitle(f'group difference (t-test) for {m_type} metrics')
    pointplot_horizontal(d_t, 't', ax=ax)
    add_grey(ax, r=2)
    plt.savefig(f'{path}{m_type}/t.png', dpi=dpi, bbox_inches = 'tight')
    plt.close(fig)

    figsize, wspace, hspace, yt = get_fparams(m_type, len(control_cols), 1, figparams)
    fig, axes = plt.subplots(len(control_cols), figsize=figsize)
    fig.suptitle(f'correlation with verbosity for {m_type} metrics', y=yt-0.02)
    plt.subplots_adjust(wspace=wspace, hspace=hspace+0.08)
    for i, control_col in enumerate(control_cols):
        # plt.subplots returns a single Axes (not an array) when only one subplot is requested.
        ax = axes[i] if len(control_cols) > 1 else axes
        name = control_col_names[i]
        d_c = prep_horizontal_pointplot_errobar_data(df['panss_o'].loc[m_type], control_col)

        pointplot_horizontal(d_c, control_col, ax=ax)
        ax.set_title(f'correlation with {name}');
        ax.set_xlabel('r');
        add_grey(ax)
    plt.savefig(f'{path}{m_type}/corr_verbosity.png', dpi=dpi, bbox_inches = 'tight')
    plt.close(fig)

    figsize, wspace, hspace, yt = get_fparams(m_type, len(control_cols), 2, figparams)
    fig, axes = plt.subplots(len(control_cols), 2, figsize=figsize)
    fig.suptitle(f'group difference and correlation with verbosity for {m_type} metrics', y=yt+0.02)
    plt.subplots_adjust(wspace=wspace, hspace=hspace+0.2)
    for i, control_col in enumerate(control_cols):
        # With a single row, `axes` is 1-D and is indexed by column only.
        ax = axes[i, 0] if len(control_cols) > 1 else axes[0]
        name = control_col_names[i]
        d_c = prep_horizontal_pointplot_errobar_data(df['panss_o'].loc[m_type], control_col)

        pointplot_horizontal(d_c, control_col, ax=ax)
        ax.set_title(f'correlation with {name}');
        ax.set_xlabel('r');
        add_grey(ax)

        # The same t-value plot is repeated in the right-hand column of every row.
        ax_2 = axes[i, 1] if len(control_cols) > 1 else axes[1]
        pointplot_horizontal(d_t, x='t', ax=ax_2)
        ax_2.set_title('group difference (t-test)')
        add_grey(ax_2, r=2)
    plt.savefig(f'{path}{m_type}/t_test_corr_verbosity.png', dpi=dpi, bbox_inches = 'tight')
    plt.close(fig)

In [304]:
# Save all figures per metric type using the three verbosity control correlations.
# NOTE(review): the metric types are taken from reformed_d's index while the plotted frame is
# reformed_d_w_verbosity — presumably both share the same index; confirm.
# NOTE(review): the next cell calls plot_all with the same default `path`, so it writes the same
# file names and overwrites the figures saved here.
for m_type in reformed_d.index.unique(level=0):
    plot_all(reformed_d_w_verbosity, m_type, plot_abs=True,
                                 control_cols=verbosity_control_cols, control_col_names=control_col_names)
    plot_all(reformed_d_w_verbosity, m_type, plot_abs=False,
                                 control_cols=verbosity_control_cols, control_col_names=control_col_names)

In [305]:
# Save all figures per metric type using the single default control column ('r_corr_w_control').
# NOTE(review): uses the same default `path` and file names as the previous cell, so the figures
# written there are overwritten by this loop.
for m_type in reformed_d.index.unique(level=0):
    plot_all(reformed_d, m_type, plot_abs=True)
    plot_all(reformed_d, m_type, plot_abs=False)

### Plot vertical bar plots for LMs

In [36]:
# Order of the models on the x axis of the LM plots (default `order` of plot_LM_scales).
order = ['bert', 'glove_tf', 'glove_avg', 'w2v_tf', 'w2v_avg']

In [307]:
def plot_LM_scales(df, title, measure='r', plot_abs=False, figsize=(18, 18), order=order):
    """Draw a 3x2 grid of vertical point plots for the LM metrics, one subplot per scale.

    Parameters
    ----------
    df : frame indexed as `df.loc['LM', scale]`.
    title : figure suptitle; a note is appended when `plot_abs` is True.
    measure : measure passed to prep_LM_pointplot and used as the y variable.
    plot_abs : forwarded to prep_LM_pointplot; also prefixes titles/labels with 'abs '.
    figsize : size of the whole figure.
    order : order of the models on the x axis.

    Returns
    -------
    The created matplotlib figure (not saved or closed here).
    """
    ab = 'abs ' if plot_abs else ''
    absolute_value = f' (absolute {measure} value)' if plot_abs else ''
    
    fig, axes = plt.subplots(3, 2, figsize=figsize, sharey=True)
    fig.suptitle(title + absolute_value, y=0.91)
    plt.subplots_adjust(wspace=0.075)
    
    axs = axes.flatten()
    for i, scale in enumerate(ORDERED_SCALES):
        ax = axs[i]
        d = prep_LM_pointplot(df.loc['LM', scale], measure, plot_abs=plot_abs)
        pointplot(d, x='model', y=measure, hue='metric', ax=ax, order=order, use_errorbar=True)
        ax.set_title(f'{ab}{measure} {scale}')
    
    add_grey(axes, line_dir='h')
    if plot_abs:
        for ax in axes.reshape(-1): 
            ax.set_ylabel('abs ' + measure);
    return fig

In [313]:
def plot_all_LM(df, path=PATH_FIG, dpi=150, plot_abs=False, figsize=(9, 9), measure='r',
                         control_cols=['r_corr_w_control'], control_col_names=['mean sentence length'],
               figparams=figprms):
    """Create and save the four standard per-model figures for the LM metrics.

    Files written under `{path}LM/model/`:

    - `{abs_}scale_r.png`: cross-scale comparison across models (3x2 grid);
    - `t.png`: group difference (t values, sign flipped);
    - `{abs_}corr_verbosity.png`: correlation with each control column;
    - `t_test_corr_verbosity.png`: control correlations next to the t values.

    Parameters
    ----------
    df : frame indexed as `df.loc['LM', scale]`.
    path : output directory prefix (concatenated directly, so it should end with a separator).
    dpi : resolution of the saved figures.
    plot_abs : plot absolute values where supported and prefix those file names with 'abs_'.
    figsize : size of the single-panel t-test figure.
    measure : only used in the '(absolute ... value)' title suffix.
    control_cols : measure names of the control correlations to plot.
    control_col_names : display names, positionally matching `control_cols`.
    figparams : layout parameters forwarded to get_fparams.

    Raises
    ------
    ValueError if `control_cols` and `control_col_names` differ in length.
    """
    ab = 'abs_' if plot_abs else ''
    if len(control_cols) != len(control_col_names):
        raise ValueError('The names of the columns must match the columns in length.')

    fig = plot_LM_scales(df, 'cross-scale comparison of LM metrcis across models', plot_abs=plot_abs)
    plt.savefig(f'{path}LM/model/{ab}scale_r.png', dpi=dpi, bbox_inches = 'tight')
    plt.close(fig)

    # The t values and control correlations are read from the 'panss_o' columns.
    d_lm_t = prep_LM_pointplot(df.loc['LM', 'panss_o'], 't')
    d_lm_t['t'] = d_lm_t['t'] * -1  # flip the sign of the plotted t values
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    fig.suptitle('group difference (t-test) for LM metrcis across models')
    pointplot(d_lm_t, x='model', y='t', hue='metric', ax=ax, order=order, use_errorbar=True)
    add_grey(ax, r=2, line_dir='h')
    plt.savefig(f'{path}LM/model/t.png', dpi=dpi, bbox_inches = 'tight')
    plt.close(fig)

    absolute_value = f' (absolute {measure} value)' if plot_abs else ''
    # Fixed: the layout parameters were looked up with `m_type`, which is not a parameter of this
    # function and only existed as a loop variable leaked from an earlier cell (NameError on a
    # fresh kernel). This function plots LM metrics, so the 'LM' entry is used explicitly.
    # The figure sizes below are hard-coded, so the size returned by get_fparams is discarded.
    _, wspace, hspace, yt = get_fparams('LM', len(control_cols), 1, figparams)
    fig, axes = plt.subplots(len(control_cols), 1, figsize=(9, 15))
    fig.suptitle('correlation with verbosity for LM metrcis across models' + absolute_value, y=yt-0.01)
    plt.subplots_adjust(hspace=hspace+0.1)
    for i, control_col in enumerate(control_cols):
        name = control_col_names[i]
        # plt.subplots returns a single Axes (not an array) when only one subplot is requested.
        ax = axes[i] if len(control_cols) > 1 else axes
        d_lm_c = prep_LM_pointplot(df.loc['LM', 'panss_o'], control_col, plot_abs=plot_abs)
        pointplot(d_lm_c, x='model', y=control_col, hue='metric', ax=ax, order=order, use_errorbar=True)
        ax.set_title(f'correlation with {name}')
        add_grey(ax, line_dir='h');
    plt.savefig(f'{path}LM/model/{ab}corr_verbosity.png', dpi=dpi, bbox_inches = 'tight')
    plt.close(fig)

    _, wspace, hspace, yt = get_fparams('LM', len(control_cols), 2, figparams)
    fig, axes = plt.subplots(len(control_cols), 2, figsize=(18, 15))
    fig.suptitle('correlation with verbosity for LM metrcis across models' + absolute_value, y=yt-0.01)
    plt.subplots_adjust(wspace=wspace-0.18, hspace=hspace+0.08)
    for i, control_col in enumerate(control_cols):
        name = control_col_names[i]
        # With a single row, `axes` is 1-D and is indexed by column only.
        ax = axes[i, 0] if len(control_cols) > 1 else axes[0]
        d_lm_c = prep_LM_pointplot(df.loc['LM', 'panss_o'], control_col, plot_abs=plot_abs)
        pointplot(d_lm_c, x='model', y=control_col, hue='metric', ax=ax, order=order, use_errorbar=True)
        ax.set_title(f'correlation with {name}')
        add_grey(ax, line_dir='h');

        # The same t-value plot is repeated in the right-hand column of every row.
        ax_2 = axes[i, 1] if len(control_cols) > 1 else axes[1]
        pointplot(d_lm_t, x='model', y='t', hue='metric', ax=ax_2, order=order, use_errorbar=True)
        ax_2.set_title('group difference (t-test)')
        add_grey(ax_2, r=2, line_dir='h')
    plt.savefig(f'{path}LM/model/t_test_corr_verbosity.png', dpi=dpi, bbox_inches = 'tight')
    plt.close(fig)

In [314]:
# Save the per-model LM figures using the three verbosity control correlations.
# NOTE(review): the next cell calls plot_all_LM with the same default `path`, so it writes the
# same file names and overwrites the figures saved here.
plot_all_LM(reformed_d_w_verbosity, plot_abs=True, 
            control_cols=verbosity_control_cols, control_col_names=control_col_names)
plot_all_LM(reformed_d_w_verbosity, plot_abs=False,
            control_cols=verbosity_control_cols, control_col_names=control_col_names)

In [315]:
# Save the per-model LM figures using the single default control column ('r_corr_w_control').
# NOTE(review): uses the same default `path` and file names as the previous cell, so the figures
# written there are overwritten by these calls.
plot_all_LM(reformed_d, plot_abs=True)
plot_all_LM(reformed_d, plot_abs=False)

## Cross-metric comparison

### R squared

In [317]:
def select_ok_metrics(row, t=1, r=0.3, rc=0.3):
    """Decide whether a metric is informative and not dominated by the control column.

    A row passes when (|r| >= `r` for at least one scale in `scale_cols`, or |t| >= `t`) and its
    control correlation is missing or at most `rc` in absolute value. The t value and the control
    correlation are read from the 'saps_total' columns.
    """
    strong_group_difference = abs(row['saps_total', 't']) >= t
    control_ok = (
        pd.isna(row['saps_total', 'r_corr_w_control'])
        or abs(row['saps_total', 'r_corr_w_control']) <= rc
    )
    # Stops at the first scale whose correlation reaches the threshold.
    correlates_with_some_scale = any(abs(row[scale, 'r']) >= r for scale in scale_cols)
    return (correlates_with_some_scale or strong_group_difference) and control_ok

In [318]:
def select_ok_metrics_for_one_scale(row, scale, r=0.3, rc=0.3):
    """Check one scale: |r| >= `r` and control correlation missing or at most `rc` in absolute value.

    The control correlation is read from the 'panss_total' columns of `row`.
    """
    control_r = row['panss_total', 'r_corr_w_control']
    control_ok = pd.isna(control_r) or abs(control_r) <= rc
    correlates = abs(row[scale, 'r']) >= r
    return correlates and control_ok

In [319]:
def select_control_corr_ms_better_than_len(row, row_len, r=0.3, rc=0.3):
    """Select control-correlated metrics that match or beat the reference row on some scale.

    Returns False for rows accepted by select_ok_metrics and for rows whose control correlation
    (read from the 'saps_total' columns) is missing or at most `rc` in absolute value. Otherwise
    returns True if |r| of `row` is at least |r| of `row_len` for any scale in `scales_`.
    """
    if select_ok_metrics(row, r=r, rc=rc):
        return False

    control_r = row['saps_total', 'r_corr_w_control']
    if pd.isna(control_r) or abs(control_r) <= rc:
        return False

    # NOTE(review): `scales_` is not defined in the visible part of the notebook (the sibling
    # functions use `scale_cols`); confirm it exists before this function is called.
    return any(abs(row[scale, 'r']) >= abs(row_len[scale, 'r']) for scale in scales_)

In [320]:
def select_control_corr_ms_better_than_len_for_one_scale(row, row_len, scale, r=0.3, rc=0.3):
    """Select control-correlated metrics that match or beat the reference row on one scale.

    Returns False for rows accepted by select_ok_metrics and for rows whose control correlation
    (read from the 'saps_total' columns) is missing or at most `rc` in absolute value. Otherwise
    returns True only if |r| of `row` on `scale` is at least |r| of `row_len` and strictly above `r`.
    """
    if select_ok_metrics(row, r=r, rc=rc):
        return False

    control_r = row['saps_total', 'r_corr_w_control']
    if pd.isna(control_r) or abs(control_r) <= rc:
        return False

    beats_reference = abs(row[scale, 'r']) >= abs(row_len[scale, 'r'])
    if not (beats_reference and abs(row[scale, 'r']) > r):
        return False
    return True

In [321]:
def select_bad_len_metrics(row, rc=0.3):
    """Return True when the control correlation is present and exceeds `rc` in absolute value.

    The control correlation is read from the 'panss_total' columns of `row`.
    """
    control_r = row['panss_total', 'r_corr_w_control']
    within_limit = pd.isna(control_r) or abs(control_r) <= rc
    return not within_limit

In [37]:
# Collapse every cell (a list of bootstrap values) to its NaN-ignoring median.
# NOTE(review): `applymap` is deprecated in recent pandas releases in favour of `DataFrame.map` —
# confirm against the installed pandas version.
median_d = reformed_d.applymap(np.nanmedian)

In [None]:
# Per scale, collect two sets of metric labels from the median table:
#   idxs_scale[scale]       - metrics accepted by select_ok_metrics_for_one_scale;
#   idxs_bad_bet_len[scale] - control-correlated metrics whose |r| on that scale is at least that of
#                             the ('syntactic', 'mean_sent_len') row (see the selection function).
idxs_scale = {}
idxs_bad_bet_len = {}
for scale in scale_cols:
    ids = median_d[median_d.apply(lambda x: select_ok_metrics_for_one_scale(x, scale=scale), axis=1)].index
    idxs_scale[scale] = ids
    idsl = median_d[median_d.apply(lambda x: select_control_corr_ms_better_than_len_for_one_scale(x, 
                                                                                                  row_len=median_d.loc[('syntactic', 'mean_sent_len')], 
                                                                                                  scale=scale), axis=1)].index
    idxs_bad_bet_len[scale] = idsl

In [None]:
# Sorted, de-duplicated union of the metrics accepted for at least one scale.
good_ms = sorted(list(set([y for x in idxs_scale.values() for y in x])))

In [None]:
# Sorted, de-duplicated union of the metrics in idxs_bad_bet_len across all scales.
bad_ms_better_than_len = sorted(list(set([y for x in idxs_bad_bet_len.values() for y in x])))

In [None]:
# Metrics whose median control correlation exceeds the default threshold of select_bad_len_metrics.
bad_ms = median_d[median_d.apply(select_bad_len_metrics, axis=1)].index

In [327]:
# Plot the control correlation ('r_corr_w_control') of the metrics in bad_ms and save the figure.
# The values are read from the 'panss_pos' columns of reformed_d.
fig, ax = plt.subplots(1, 1, figsize=(8, 6), sharex=True)
fig.suptitle('cross-type metric comparison for correlation with mean sent len')
measure = 'r_corr_w_control'
plot_abs = False


d = prep_horizontal_pointplot_errobar_data(reformed_d['panss_pos'].loc[bad_ms], measure, plot_abs=plot_abs)
pointplot_horizontal(d, x=measure, ax=ax)
ax.set_xlabel('r')
add_grey(ax)
# PATH_FIG already ends with '/', so this path contains a double slash.
plt.savefig(f'{PATH_FIG}/compare_corr_len.png', dpi=150, bbox_inches = 'tight')
plt.close(fig)

In [328]:
# Metrics shown in the cross-type comparison: both selections concatenated and sorted
# (duplicates, if any, are kept).
ms_to_plot = sorted(bad_ms_better_than_len + good_ms)

In [329]:
# Number of metrics selected for the comparison plots.
len(ms_to_plot)

21

In [330]:
# Number of metrics accepted for at least one scale.
len(good_ms)

5

In [331]:
# Per scale: size of idxs_scale plus size of idxs_bad_bet_len.
{s: len(idxs_scale[s]) + len(idxs_bad_bet_len[s]) for s in idxs_scale}

{'saps_total': 4,
 'sans_total': 10,
 'panss_pos': 1,
 'panss_neg': 14,
 'panss_o': 14,
 'panss_total': 17}

In [332]:
# Per scale: number of metrics accepted by select_ok_metrics_for_one_scale.
{s: len(idxs_scale[s]) for s in idxs_scale}

{'saps_total': 1,
 'sans_total': 3,
 'panss_pos': 0,
 'panss_neg': 4,
 'panss_o': 2,
 'panss_total': 3}

In [39]:
def sort_index(idxs):
    """Return the labels sorted, with the mean-sentence-length label(s) moved to the end.

    The three recognised spellings are checked in a fixed order; each one that is present has its
    first occurrence moved to the end of the list. The input is not modified.
    """
    ordered = list(idxs)
    ordered.sort()
    sentinels = ('syntactic: mean_sent_len', 'mean_sent_len', ('syntactic', 'mean_sent_len'))
    for sentinel in sentinels:
        if sentinel not in ordered:
            continue
        ordered.append(ordered.pop(ordered.index(sentinel)))
    return ordered

In [40]:
def map_marker(m, scale, idxs_scale):
    """Return marker 'o' if metric `m` is listed in `idxs_scale[scale]`, otherwise 'x'."""
    return 'o' if m in idxs_scale[scale] else 'x'

In [335]:
def add_len_lines(ax, median_d, scale, measure='r'):
    """Draw dashed vertical lines at plus and minus the mean-sentence-length value.

    The value is `median_d.loc[('syntactic', 'mean_sent_len')][scale, measure]`.
    """
    reference = median_d.loc[('syntactic', 'mean_sent_len')][scale, measure]
    for position in (reference, -reference):
        ax.axvline(position, linestyle='--')

In [336]:
def plot_one_scale(ax, scale, idxs_scale, reformed_d, ms_to_plot, measure, plot_abs, use_markers=True):
    """Draw the horizontal point plot of `measure` for one scale on `ax`.

    Metrics listed in `idxs_scale[scale]` get marker 'o', all others 'x' (every metric gets 'o'
    when `use_markers` is False). If ('syntactic', 'mean_sent_len') is among the selected metrics,
    reference lines are added from the module-level `median_d`, always for measure 'r'.
    """
    selected = idxs_scale[scale]
    # Same marker rule as map_marker.
    marker_per_metric = ['o' if metric in selected else 'x' for metric in ms_to_plot]

    if ('syntactic', 'mean_sent_len') in selected:
        add_len_lines(ax, median_d, scale)

    plot_data = prep_horizontal_pointplot_errobar_data(
        reformed_d[scale].loc[ms_to_plot], measure, plot_abs=plot_abs
    )
    chosen_markers = marker_per_metric if use_markers else 'o'
    pointplot_horizontal(plot_data, x=measure, ax=ax, markers=chosen_markers)
    ax.set_title(f'{measure} {scale}')
In [337]:
# Cross-scale comparison of r for the selected metrics: one subplot per scale, saved to disk.
fig, axes = plt.subplots(3, 2, figsize=(18, 14), sharex=True)
fig.suptitle('cross-scale cross-type metric comparison', y=0.915)
plt.subplots_adjust(wspace=0.4, hspace=0.1) #left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)
measure = 'r'
plot_abs = False
# Sorted metric labels with mean sentence length moved to the end.
idxs = sort_index(ms_to_plot)
axs = axes.flatten()

for i, scale in enumerate(ORDERED_SCALES):
    ax = axs[i]
    # Fixed: `idxs` was computed above but never used; the unsorted-for-plotting `ms_to_plot`
    # was passed instead, so the reordering had no effect.
    plot_one_scale(ax, scale, idxs_scale, reformed_d, idxs, measure, plot_abs)

add_grey(axes)
# PATH_FIG already ends with '/', so this path contains a double slash.
plt.savefig(f'{PATH_FIG}/compare_r.png', dpi=150, bbox_inches = 'tight')
plt.close(fig)

#### Images for the Presentation

In [338]:
# Selected metrics without ('syntactic', 'AUX') and ('LM', 'w2v_tf_cgcoh').
idx_to_plot = [m for m in ms_to_plot if m not in [('syntactic', 'AUX'),  ('LM', 'w2v_tf_cgcoh')]]

In [339]:
# Two panels for the metrics in idx_to_plot: r with 'sans_total' (left) and the sign-flipped
# correlation with the control column (right).
# `measure` and `plot_abs` are the notebook globals set in the cross-scale comparison cell above.
fig, axes = plt.subplots(1, 2, figsize=(18, 4))
plt.subplots_adjust(wspace=0.4)
fig.suptitle('cross-type metric comparison for correlation with SANS and inverse mean sentence length', y=1)

# The mean-sentence-length row itself is left out of the right-hand panel.
dcr = reformed_d['sans_total'].loc[idx_to_plot].drop([('syntactic', 'mean_sent_len')])
d_c = prep_horizontal_pointplot_errobar_data(dcr, 'r_corr_w_control', plot_abs=False)
d_c['r_corr_w_control'] = -1 * d_c['r_corr_w_control']

plot_one_scale(axes[0], 'sans_total', idxs_scale, reformed_d, idx_to_plot, measure, plot_abs, use_markers=False)

pointplot_horizontal(d_c, x='r_corr_w_control', ax=axes[1], palette='tab20')
axes[1].set_title('-1 * r mean sentence length')
axes[1].set_xlabel('-r');
add_grey(axes)
plt.savefig(f'{PATH_FIG}/compare_sans_to_minus_corr_len_more_metrics.png', dpi=150, bbox_inches = 'tight')
plt.close(fig)

In [340]:
# Both metric selections for 'sans_total', concatenated and sorted.
idxs_sans = sorted([m for m in idxs_scale['sans_total']] + [m for m in idxs_bad_bet_len['sans_total']])

In [341]:
# Same two-panel figure as the previous plot, restricted to the metrics in idxs_sans.
# `measure` and `plot_abs` are the notebook globals set in the cross-scale comparison cell above.
fig, axes = plt.subplots(1, 2, figsize=(18, 4))
plt.subplots_adjust(wspace=0.4)
fig.suptitle('cross-type metric comparison for correlation with SANS and inverse mean sentence length', y=1)

# The mean-sentence-length row itself is left out of the right-hand panel.
dcr = reformed_d['sans_total'].loc[idxs_sans].drop([('syntactic', 'mean_sent_len')])
d_c = prep_horizontal_pointplot_errobar_data(dcr, 'r_corr_w_control', plot_abs=False)
d_c['r_corr_w_control'] = -1 * d_c['r_corr_w_control']

plot_one_scale(axes[0], 'sans_total', idxs_scale, reformed_d, idxs_sans, measure, plot_abs, use_markers=False)

pointplot_horizontal(d_c, x='r_corr_w_control', ax=axes[1], palette='tab20')
axes[1].set_title('-1 * r mean sentence length')
axes[1].set_xlabel('-r');
add_grey(axes)
plt.savefig(f'{PATH_FIG}/compare_sans_to_minus_corr_len.png', dpi=150, bbox_inches = 'tight')
plt.close(fig)

### t-test

In [342]:
def select_ok_metrics_t(row):
    """Tell whether the interquartile range of a row's t values excludes zero.

    The values are read from ``row['saps_total', 't']``.  The result is truthy
    when the 0.25 quantile is above zero or the 0.75 quantile is below zero.
    """
    t_values = row['saps_total', 't']
    lower_quartile, upper_quartile = (np.quantile(t_values, level) for level in (0.25, 0.75))
    entirely_positive = lower_quartile > 0
    entirely_negative = upper_quartile < 0
    return entirely_positive or entirely_negative

In [343]:
# Index of the rows of reformed_d that pass select_ok_metrics_t.
ok_for_t = reformed_d.apply(select_ok_metrics_t, axis=1)
idx_comp_t = reformed_d[ok_for_t].index

In [344]:
# Two-panel figure for the metrics in idx_comp_t: sign-flipped t values on the left,
# 'r_corr_w_control' values on the right.
fig, axes = plt.subplots(1, 2, figsize=(18, 4))
plt.subplots_adjust(wspace=0.4)
fig.suptitle('cross-type metric comparison for group difference and correlation with mean sentence length', y=1)

# the ('syntactic', 'mean_sent_len') row is removed from the control-correlation panel only
dcr = reformed_d['sans_total'].loc[idx_comp_t].drop([('syntactic', 'mean_sent_len')])
d_c = prep_horizontal_pointplot_errobar_data(dcr, 'r_corr_w_control', plot_abs=False)
d_t = prep_horizontal_pointplot_errobar_data(reformed_d['sans_total'].loc[idx_comp_t], 't', plot_abs=False)
# the t values are plotted with their sign flipped
d_t['t'] = d_t['t'] * -1

pointplot_horizontal(d_t, x='t', ax=axes[0])
axes[0].set_title('group difference (t-test)')
# was a second set_title('t') that the title above overwrote; label the x axis like the right panel
axes[0].set_xlabel('t');
# NOTE(review): `scale` is not defined in this cell; it is whatever an earlier cell left behind - confirm.
add_len_lines(axes[0], median_d, scale, 't')

pointplot_horizontal(d_c, x='r_corr_w_control', ax=axes[1])
axes[1].set_title('correlation with mean sentence length')
axes[1].set_xlabel('r');
add_grey(axes[0], r=2)
add_grey(axes[1])
plt.savefig(f'{PATH_FIG}/compare_t.png', dpi=150, bbox_inches = 'tight')
plt.close(fig)

#### Images for the Presentation

In [345]:
# Presentation variant of the t-test figure: sign-flipped t values on the left,
# sign-flipped 'r_corr_w_control' values on the right.
fig, axes = plt.subplots(1, 2, figsize=(18, 4))
plt.subplots_adjust(wspace=0.4)
fig.suptitle('cross-type metric comparison for group difference and correlation with inverse mean sentence length', y=1)

# the ('syntactic', 'mean_sent_len') row is removed from the control-correlation panel only
dcr = reformed_d['sans_total'].loc[idx_comp_t].drop([('syntactic', 'mean_sent_len')])
d_c = prep_horizontal_pointplot_errobar_data(dcr, 'r_corr_w_control', plot_abs=False)
d_t = prep_horizontal_pointplot_errobar_data(reformed_d['sans_total'].loc[idx_comp_t], 't', plot_abs=False)
# both plotted quantities have their sign flipped
d_t['t'] = d_t['t'] * -1
d_c['r_corr_w_control'] = -1 * d_c['r_corr_w_control']

pointplot_horizontal(d_t, x='t', ax=axes[0])
axes[0].set_title('t-test')
# was a second set_title('t') that the title above overwrote; label the x axis like the right panel
axes[0].set_xlabel('t');
# NOTE(review): `scale` is not defined in this cell; it is whatever an earlier cell left behind - confirm.
add_len_lines(axes[0], median_d, scale, 't')

pointplot_horizontal(d_c, x='r_corr_w_control', ax=axes[1])
axes[1].set_title('-1 * r mean sentence length')
axes[1].set_xlabel('-r');
add_grey(axes[0], r=2)
add_grey(axes[1])
plt.savefig(f'{PATH_FIG}/compare_t_minus_corr_len.png', dpi=150, bbox_inches = 'tight')
plt.close(fig)

### average LM model / metric performance medians across scales

In [346]:
# Scales whose mean absolute r is tabulated by mean_model_metric_medians.
scales_ = (
    'sans_total',
    'saps_total',
    'panss_pos',
    'panss_neg',
    'panss_o',
    'panss_total',
)
# Additional columns that are averaged under their own names.
sc_ind_ = (
    't',
    'r_corr_w_control',
)
# Row labels of the per-model table.
models_ = (
    'bert',
    'glove_tf',
    'glove_avg',
    'w2v_tf',
    'w2v_avg',
)
# Row labels of the per-metric table.
metrics_ = (
    'cgcoh',
    'gcoh',
    'lcoh',
    'scoh',
    'sprob',
    'pppl',
)

In [347]:
def mean_model_metric_medians(median_df, leave_out=()):
    """Average the LM statistics per model and per metric.

    For every scale in ``scales_`` the 'r' column of
    ``prep_LM_pointplot(median_df.loc['LM', scale], col='r', use_errorbar=False, plot_abs=True)``
    is averaged with ``np.nanmean`` and stored under ``'<scale> abs r'``.  The columns
    named in ``sc_ind_`` are averaged the same way and stored under their own names.

    Parameters
    ----------
    median_df : object indexed with ``.loc['LM', scale]``
    leave_out : collection of metric names excluded from the per-model averages
        (the per-metric averages always use every row of the metric).

    Returns
    -------
    tuple
        ``(resp_d_model, resp_d_metric)``: tables indexed by ``models_`` and ``metrics_``.
    """
    column_names = [f'{sc} abs r' for sc in scales_] + list(sc_ind_)
    resp_d_model = pd.DataFrame(columns=column_names, index=models_)
    resp_d_metric = pd.DataFrame(columns=column_names, index=metrics_)

    def store_means(frame, source_col, target_col):
        # per-model mean: rows of that model, minus the metrics listed in leave_out
        for model in models_:
            of_model = frame[frame['model'] == model]
            kept = of_model[~of_model['metric'].isin(leave_out)]
            resp_d_model.loc[model, target_col] = np.nanmean(kept[source_col])
        # per-metric mean: every row of that metric
        for metric in metrics_:
            of_metric = frame[frame['metric'] == metric]
            resp_d_metric.loc[metric, target_col] = np.nanmean(of_metric[source_col])

    for scale in scales_:
        ex_d = prep_LM_pointplot(median_df.loc['LM', scale], col='r', use_errorbar=False, plot_abs=True)
        store_means(ex_d, 'r', f'{scale} abs r')

    # The sc_ind_ columns are read from the frame prepared for the LAST scale in scales_.
    # NOTE(review): presumably these columns do not depend on the scale - confirm.
    for sc_ind in sc_ind_:
        store_means(ex_d, sc_ind, sc_ind)
    return resp_d_model, resp_d_metric


#### only including cosine-similarity based metrics

In [348]:
# 'pppl' and 'sprob' are passed as leave_out, i.e. excluded from the per-model averages.
resp_d_model, resp_d_metric = mean_model_metric_medians(median_d, leave_out=('pppl', 'sprob'))

In [349]:
# Display the per-model table through the `style` helper from jupyter_utils.
style(resp_d_model)

Unnamed: 0,sans_total abs r,saps_total abs r,panss_pos abs r,panss_neg abs r,panss_o abs r,panss_total abs r,t,r_corr_w_control
bert,0.236577,0.04418,0.07079,0.225726,0.252314,0.226234,0.511409,0.157991
glove_tf,0.165904,0.208092,0.234331,0.192519,0.164074,0.198569,0.39171,0.429151
glove_avg,0.224203,0.195832,0.200804,0.234329,0.184796,0.23216,0.394616,0.567291
w2v_tf,0.280315,0.303707,0.258881,0.312226,0.254946,0.324928,0.729381,0.592533
w2v_avg,0.306171,0.246063,0.189779,0.314721,0.233608,0.291453,0.613459,0.618035


In [350]:
# Mean of the six per-scale |r| columns for each model, sorted ascending.
abs_r_columns = ['sans_total abs r', 'saps_total abs r', 'panss_pos abs r',
                 'panss_neg abs r', 'panss_o abs r', 'panss_total abs r']
resp_d_model[abs_r_columns].mean(axis=1).sort_values()

bert         0.175970
glove_tf     0.193915
glove_avg    0.212021
w2v_avg      0.263633
w2v_tf       0.289167
dtype: float64

In [351]:
# Per-model mean of the 'r_corr_w_control' column.
resp_d_model['r_corr_w_control']

bert         0.157991
glove_tf     0.429151
glove_avg    0.567291
w2v_tf       0.592533
w2v_avg      0.618035
Name: r_corr_w_control, dtype: object

In [352]:
# Per-metric mean of the 'r_corr_w_control' column.
resp_d_metric['r_corr_w_control']

cgcoh    0.292816
gcoh     0.493656
lcoh     0.549157
scoh     0.556372
sprob    0.576114
pppl     0.529623
Name: r_corr_w_control, dtype: object

#### including feature based metrics for BERT

In [353]:
# Recompute both tables with the default leave_out=(), i.e. no metric is excluded.
resp_d_model, resp_d_metric = mean_model_metric_medians(median_d)

In [354]:
# Display the per-metric table through the `style` helper from jupyter_utils.
style(resp_d_metric)

Unnamed: 0,sans_total abs r,saps_total abs r,panss_pos abs r,panss_neg abs r,panss_o abs r,panss_total abs r,t,r_corr_w_control
cgcoh,0.105246,0.190676,0.142618,0.101842,0.085081,0.098206,0.872924,0.292816
gcoh,0.176989,0.182976,0.163138,0.200634,0.159281,0.204929,0.134382,0.493656
lcoh,0.314356,0.232184,0.256773,0.328579,0.292895,0.344763,0.729779,0.549157
scoh,0.373945,0.192464,0.201138,0.392562,0.334534,0.370777,0.375375,0.556372
sprob,0.245612,0.131085,0.202997,0.318265,0.278489,0.314706,1.28745,0.576114
pppl,0.471239,0.155275,0.119219,0.454768,0.432778,0.412121,2.038195,0.529623


In [355]:
# Mean of the six per-scale |r| columns for each model, sorted ascending.
abs_r_columns = ['sans_total abs r', 'saps_total abs r', 'panss_pos abs r',
                 'panss_neg abs r', 'panss_o abs r', 'panss_total abs r']
resp_d_model[abs_r_columns].mean(axis=1).sort_values()

glove_tf     0.193915
glove_avg    0.212021
bert         0.215551
w2v_avg      0.263633
w2v_tf       0.289167
dtype: float64

In [356]:
# Per-model mean of the 'r_corr_w_control' column (no metric left out).
resp_d_model['r_corr_w_control']

bert         0.289617
glove_tf     0.429151
glove_avg    0.567291
w2v_tf       0.592533
w2v_avg      0.618035
Name: r_corr_w_control, dtype: object

In [357]:
# Per-metric mean of the 'r_corr_w_control' column.
resp_d_metric['r_corr_w_control']

cgcoh    0.292816
gcoh     0.493656
lcoh     0.549157
scoh     0.556372
sprob    0.576114
pppl     0.529623
Name: r_corr_w_control, dtype: object

## Plot and analyze across parts of NET

In [41]:
# Display the list of control-correlation column names defined in an earlier cell.
verbosity_control_cols

['r_corr_w_mean_sent_len', 'r_corr_w_n_sents', 'r_corr_w_n_words']

In [42]:
def plot_horizontal_tasks(df, title, scale, measure, xname=None, m_type='syntactic', 
                          plot_abs=False, r=0.3, figparams=figprms, 
                          control_cols = ['r_corr_w_control'], control_col_names=['mean sentence length']):
    """Draw one horizontal point plot per task in ``TASKS`` on a 2x2 grid.

    Parameters
    ----------
    df : frame indexed as ``df.loc[m_type, (task, scale)]``
    title : figure title; ' (absolute r value)' is appended when ``plot_abs`` is true
    scale : second element of the ``(task, scale)`` column key
    measure : column handed to ``prep_horizontal_pointplot_errobar_data`` and plotted
        on the x axis; for ``'t'`` the values are multiplied by -1 before plotting
    xname : x-axis label, defaults to ``measure``; prefixed with 'abs ' when ``plot_abs``
    m_type : first-level row key selecting the metric group
    plot_abs : forwarded to the data-preparation helper
    r : forwarded to ``add_grey``
    figparams : forwarded to ``get_fparams`` to obtain figsize, spacing and title height
    control_cols, control_col_names : only checked for equal length

    Returns
    -------
    The created matplotlib figure (not saved, not closed).

    Raises
    ------
    ValueError
        If ``control_cols`` and ``control_col_names`` differ in length.
    """
    if len(control_cols) != len(control_col_names):
        raise ValueError('The names of the columns must match the columns in length.')

    # A metric is not plotted against a control correlation computed from itself.
    # The 'r_corr_w_n_sents' entry used to be guarded by a repeated
    # 'r_corr_w_mean_sent_len' test and could therefore never be reached.
    self_control = {
        ('syntactic', 'r_corr_w_mean_sent_len'): 'mean_sent_len',
        ('syntactic', 'r_corr_w_control'): 'mean_sent_len',
        ('syntactic', 'r_corr_w_n_sents'): 'n_sents',
        ('lexical', 'r_corr_w_n_words'): 'n_words',
    }
    metric_to_drop = self_control.get((m_type, measure))

    absolute_value = ' (absolute r value)' if plot_abs else ''
    figsize, wspace, hspace, yt = get_fparams(m_type, 2, 2, figparams)
    fig, axes = plt.subplots(2, 2, figsize=figsize, sharex=True)
    fig.suptitle(title + absolute_value, y=yt)
    plt.subplots_adjust(wspace=wspace, hspace=hspace)

    axs = axes.flatten()

    for i, task in enumerate(TASKS):
        ax = axs[i]
        data_task = df.loc[m_type, (task, scale)]
        if metric_to_drop is not None:
            data_task = data_task.drop(metric_to_drop)

        data = prep_horizontal_pointplot_errobar_data(data_task, col=measure, plot_abs=plot_abs)
        if measure == 't':
            # t values are shown with their sign flipped
            data['t'] = data['t'] * -1
        pointplot_horizontal(data, x=measure, ax=ax)
        ax.set_title(task)

    add_grey(axes, r=r)
    if xname is None:
        xname = measure

    label = 'abs ' + xname if plot_abs else xname
    for ax in axes.reshape(-1):
        ax.set_xlabel(label)
    return fig

In [375]:
# Trial run: build the t-test figure for the syntactic metrics and close it without saving.
m_type = 'syntactic'
fig = plot_horizontal_tasks(
    reformed_tasks,
    f'cross-task comparison for {m_type} metrics on group difference (t-test)',
    scale='panss_o',
    measure='t',
    m_type=m_type,
    r=2,
    figparams=figprms,
)
plt.close(fig)

In [376]:
# Trial run: build the sans_total correlation figure and close it without saving.
fig = plot_horizontal_tasks(
    reformed_tasks,
    f'cross-task comparison for {m_type} metrics on sans_total',
    scale='sans_total',
    measure='r',
    m_type=m_type,
    figparams=figprms,
)
plt.close(fig)

In [377]:
# Trial run: build the control-correlation figure and close it without saving.
fig = plot_horizontal_tasks(
    reformed_tasks,
    f'cross-task comparison for {m_type} metrics on correlation with mean sentence length',
    scale='panss_o',
    measure='r_corr_w_control',
    xname='r',
    m_type=m_type,
    figparams=figprms,
)
plt.close(fig)

In [378]:
def plot_all_scales(reformed_d, m_type='syntactic', path=PATH_FIG, plot_abs=True, dpi=150, figparams=figprms,
                    control_cols=['r_corr_w_control'], control_col_names=['mean sentence length']):
    """Save every cross-task figure for one metric group under ``{path}{m_type}/``.

    Three sets of figures are produced with ``plot_horizontal_tasks``:
    the t-test figure (``t_across_tasks.png``), one figure per entry of
    ``control_cols`` and one figure per scale in ``ORDERED_SCALES``.  File names
    of the last two sets are prefixed with 'abs_' when ``plot_abs`` is true.

    Parameters
    ----------
    reformed_d : data frame passed through to ``plot_horizontal_tasks``
    m_type : metric group; also the sub-directory the files are written to
    path : output directory prefix (concatenated directly with ``m_type``)
    plot_abs : plot absolute values in the control and per-scale figures
    dpi : resolution passed to ``plt.savefig``
    figparams : layout parameters; never modified by this function
    control_cols : measures plotted in the control figures
    control_col_names : human-readable name of each control column, same order
    """
    fig = plot_horizontal_tasks(reformed_d, 
                                f'cross-task comparison for {m_type} metrics on group difference (t-test)', 
                                scale='panss_o', measure='t', m_type=m_type, r=2, figparams=figparams)
    plt.savefig(f'{path}{m_type}/t_across_tasks.png', dpi=dpi, bbox_inches = 'tight')
    plt.close(fig)

    ab = 'abs_' if plot_abs else ''

    upfig = figparams.copy()
    if m_type == 'lexical':
        # .copy() above is shallow: copy the nested entry too before raising the title,
        # otherwise the caller's figparams is modified and the offset grows on every call
        upfig[m_type] = upfig[m_type].copy()
        upfig[m_type]['yt'] += 0.05
    for i, control_col in enumerate(control_cols):
        name = control_col_names[i]
        fig = plot_horizontal_tasks(reformed_d, 
                                f'cross-task comparison for {m_type} metrics on correlation with {name}', 
                                scale='panss_o', measure=control_col, plot_abs=plot_abs, 
                                xname='r', m_type=m_type, figparams=upfig)
        plt.savefig(f'{path}{m_type}/{ab}corr_{"_".join(name.split())}_across_tasks.png', 
                    dpi=dpi, bbox_inches = 'tight')
        plt.close(fig)

    for scale in ORDERED_SCALES:
        fig = plot_horizontal_tasks(reformed_d, 
                                    f'cross-task comparison for {m_type} metrics on {scale}', 
                                    scale=scale, measure='r', 
                                    plot_abs=plot_abs, m_type=m_type, figparams=figparams)
        plt.savefig(f'{path}{m_type}/{ab}r_{scale}_across_tasks.png', dpi=dpi, bbox_inches = 'tight')
        plt.close(fig)

In [379]:
# For every metric group, save the figures first with absolute values, then with signed values.
for m_type in reformed_tasks.index.unique(level=0):
    for use_abs in (True, False):
        plot_all_scales(reformed_tasks, m_type, plot_abs=use_abs)

In [380]:
def plot_lm_tasks(df, title, scale, measure, order=order, plot_abs=False, use_errorbar=True,
                 figsize=(15, 10), yname=None, r=0.3):
    """Draw one LM point plot (model on x, metric as hue) per task in ``TASKS`` on a 2x2 grid.

    Parameters
    ----------
    df : frame indexed as ``df.loc['LM', (task, scale)]``
    title : figure title; ' (absolute <measure> value)' is appended when ``plot_abs``
    scale : second element of the ``(task, scale)`` column key
    measure : column prepared by ``prep_LM_pointplot`` and plotted on the y axis;
        for ``'t'`` the values are multiplied by -1 before plotting
    order, use_errorbar : forwarded to ``pointplot``
    plot_abs : forwarded to ``prep_LM_pointplot``
    figsize : size of the figure
    yname : y-axis label, defaults to ``measure``; prefixed with 'abs ' when ``plot_abs``
    r : forwarded to ``add_grey``

    Returns
    -------
    The created matplotlib figure (not saved, not closed).
    """
    title_suffix = f' (absolute {measure} value)' if plot_abs else ''
    fig, axes = plt.subplots(2, 2, figsize=figsize, sharey=True)
    fig.suptitle(title + title_suffix, y=0.925)
    plt.subplots_adjust(wspace=0.1)

    panels = axes.flatten()
    for position, task in enumerate(TASKS):
        panel = panels[position]
        prepared = prep_LM_pointplot(df.loc['LM', (task, scale)], col=measure, plot_abs=plot_abs)
        if measure == 't':
            # t values are shown with their sign flipped
            prepared['t'] = prepared['t'] * -1
        pointplot(prepared, x='model', y=measure, hue='metric', ax=panel, order=order, use_errorbar=use_errorbar)
        panel.set_title(task)

    add_grey(axes, line_dir='h', r=r)

    base_label = measure if yname is None else yname
    y_label = 'abs ' + base_label if plot_abs else base_label
    for panel in axes.reshape(-1):
        panel.set_ylabel(y_label)
    return fig

In [381]:
# Trial run: build the LM t-test figure and close it without saving.
fig = plot_lm_tasks(
    reformed_tasks,
    'cross-task comparison for LM metrics across models on group difference (t-test)',
    scale='panss_o',
    measure='t',
    use_errorbar=True,
    figsize=(15, 10),
    r=2,
)
plt.close(fig)

In [382]:
# Trial run: build the LM sans_total correlation figure and close it without saving.
fig = plot_lm_tasks(
    reformed_tasks,
    'cross-task comparison for LM metrics across models on sans_total',
    scale='sans_total',
    measure='r',
    use_errorbar=True,
    figsize=(15, 10),
)
plt.close(fig)

In [383]:
def plot_all_LM_across_tasks(reformed_d, m_type='LM', path=PATH_FIG, plot_abs=False, figsize=(18, 12), dpi = 150,
                    control_cols=['r_corr_w_control'], control_col_names=['mean sentence length']):
    """Save every cross-task LM figure under ``{path}{m_type}/model/``.

    Three sets of figures are produced with ``plot_lm_tasks``: the t-test figure
    (``t_across_tasks.png``), one figure per entry of ``control_cols`` and one
    figure per scale in ``ORDERED_SCALES``.  File names of the last two sets are
    prefixed with 'abs_' when ``plot_abs`` is true.

    Parameters
    ----------
    reformed_d : data frame passed through to ``plot_lm_tasks``
    m_type : sub-directory (below ``path``) the files are written to
    path : output directory prefix (concatenated directly with ``m_type``)
    plot_abs : plot absolute values in the control and per-scale figures
    figsize : size of every figure
    dpi : resolution passed to ``plt.savefig``
    control_cols : measures plotted in the control figures
    control_col_names : human-readable name of each control column, same order
    """
    ab = 'abs_' if plot_abs else ''

    fig = plot_lm_tasks(reformed_d, 
                        'cross-task comparison for LM metrics across models on group difference (t-test)',
                        scale='panss_o', measure='t', use_errorbar=True, figsize=figsize, r=2)
    plt.savefig(f'{path}{m_type}/model/t_across_tasks.png', dpi=dpi, bbox_inches = 'tight')
    plt.close(fig)

    for i, control_col in enumerate(control_cols):
        name = control_col_names[i]
        fig = plot_lm_tasks(reformed_d, 
                            f'cross-task comparison for LM metrics across models on correlation with {name}',
                            scale='panss_o', measure=control_col, yname='r',
                            plot_abs=plot_abs, use_errorbar=True, figsize=figsize)
        plt.savefig(f'{path}{m_type}/model/{ab}corr_{"_".join(name.split())}_across_tasks.png', 
                    dpi=dpi, bbox_inches = 'tight')
        plt.close(fig)

    for scale in ORDERED_SCALES:
        # plot_lm_tasks appends ' (absolute r value)' itself when plot_abs is set,
        # so the suffix must not be part of the title passed in (it used to appear twice)
        fig = plot_lm_tasks(reformed_d, 
                            f'cross-task comparison for LM metrics across models on {scale}', 
                            scale=scale, measure='r', 
                            plot_abs=plot_abs, use_errorbar=True, figsize=figsize)
        plt.savefig(f'{path}{m_type}/model/{ab}r_{scale}_across_tasks.png', dpi=dpi, bbox_inches = 'tight')
        plt.close(fig)

In [384]:
# Save the LM figures first with absolute values, then with signed values.
for use_abs in (True, False):
    plot_all_LM_across_tasks(reformed_tasks, plot_abs=use_abs)

# Reform the dataset to long format

index: unique

language: de / ru

task: (de tasks) / (ru tasks) - 4 for each

scale: (de: panss sans saps t test) / (ru: panss dep td t test)

metric: name

metric_group: 4

values: median CI_low CI_high corr_mean_sent_len corr_n_sent corr_n_word

In [43]:
# Control-correlation columns that are summarised for every row of the long-format table.
control_corr_names = ['r_corr_w_mean_sent_len', 'r_corr_w_n_sents', 'r_corr_w_n_words']

In [44]:
# Quantile levels used for the CI_low / CI_high values of the long-format table.
low = 0.25
high = 0.75
# Language tag written into the first field of every long-format row.
lang = 'de'

In [45]:
# Display the task names that the long-format loop iterates over.
TASKS

['anger', 'fear', 'happiness', 'sadness']

In [60]:
# Build one long-format row per (task, scale, metric):
# (lang, task, scale, performance_metric, metric_group, metric_name,
#  median, mean, CI_low, CI_high) followed by the same four statistics
# for every column in control_corr_names.

# Statistic read for the two extra pseudo-scales; every scale in ORDERED_SCALES uses 'r'.
performance_metric_d = {'sample_raw': 'sample_raw', 'group_diff': 't'}

long_data = []
for task in TASKS:
    for scale_ in ORDERED_SCALES + ['group_diff', 'sample_raw']:
        performance_metric = performance_metric_d.get(scale_, 'r')
        # rows for the two pseudo-scales are read under the 'panss_total' key
        scale_key = scale_ if scale_ in ORDERED_SCALES else 'panss_total'
        for metric in cols_av:
            metric_group, metric_name = metric
            data = reform_tasks_v[(task, scale_key, performance_metric)][(metric_group, metric_name)]
            median = np.nanmedian(data)
            mean = np.nanmean(data)
            CI_low = np.nanquantile(np.array(data), low)
            CI_high = np.nanquantile(np.array(data), high)
            long_line = (lang, task, scale_, performance_metric, metric_group, metric_name,
                         median, mean, CI_low, CI_high)
            for control_col in control_corr_names:
                control_data = reform_tasks_v[(task, scale_key, control_col)][metric]
                # NOTE(review): a falsy first element is treated as "no control data";
                # presumably missing entries are stored that way - confirm against the producer.
                if control_data[0]:
                    long_line += (np.nanmedian(control_data), np.nanmean(control_data),
                                  np.nanquantile(np.array(control_data), low),
                                  np.nanquantile(np.array(control_data), high))
                else:
                    long_line += (np.nan, np.nan, np.nan, np.nan)
            if long_line not in long_data:
                long_data.append(long_line)
            else:
                # report the skipped duplicate; this branch used to print an undefined name (`line`)
                print('duplicate row skipped:', (lang, task, scale_, metric_group, metric_name))

In [61]:
# Column names follow the field order of the tuples collected in long_data.
long_columns = [
    'lang', 'task', 'scale', 'performance_metric', 'metric_group', 'metric_name',
    'median', 'mean', 'CI_low', 'CI_high',
    'corr_mean_sent_len_median', 'corr_mean_sent_len_mean',
    'corr_mean_sent_len_CI_low', 'corr_mean_sent_len_CI_high',
    'corr_n_sents_median', 'corr_n_sents_mean',
    'corr_n_sents_CI_low', 'corr_n_sents_CI_high',
    'corr_n_words_median', 'corr_n_words_mean',
    'corr_n_words_CI_low', 'corr_n_words_CI_high',
]
long_df = pd.DataFrame(long_data, columns=long_columns)
long_df.tail()

Unnamed: 0,lang,task,scale,performance_metric,metric_group,metric_name,median,mean,CI_low,CI_high,...,corr_mean_sent_len_CI_low,corr_mean_sent_len_CI_high,corr_n_sents_median,corr_n_sents_mean,corr_n_sents_CI_low,corr_n_sents_CI_high,corr_n_words_median,corr_n_words_mean,corr_n_words_CI_low,corr_n_words_CI_high
1627,de,sadness,sample_raw,sample_raw,graph,PE,0.072718,0.092412,0.028025,0.144375,...,0.161599,0.30102,0.482556,0.483487,0.430986,0.540829,0.413537,0.414719,0.368688,0.460928
1628,de,sadness,sample_raw,sample_raw,graph,degree_average,0.053097,0.080566,0.014317,0.124021,...,0.231586,0.390299,0.434464,0.432834,0.386642,0.482012,0.411861,0.410253,0.36708,0.455254
1629,de,sadness,sample_raw,sample_raw,graph,degree_std,0.10082,0.11755,0.044424,0.174172,...,0.309745,0.450923,0.463002,0.460349,0.41412,0.509225,0.466058,0.464205,0.423372,0.503094
1630,de,sadness,sample_raw,sample_raw,graph,number_of_edges,0.209177,0.212705,0.144147,0.277617,...,0.402361,0.510313,0.530675,0.529984,0.499753,0.560742,0.560653,0.561677,0.528005,0.592891
1631,de,sadness,sample_raw,sample_raw,graph,number_of_nodes,0.207853,0.209475,0.143152,0.270386,...,0.339156,0.452717,0.412672,0.404037,0.358598,0.452844,0.453008,0.452813,0.410813,0.496868


In [62]:
# Sanity check: list every row whose (lang, task, scale, metric_group, metric_name) key occurs more than once.
key_columns = ['lang', 'task', 'scale', 'metric_group', 'metric_name']
long_df[long_df.duplicated(subset=key_columns, keep=False)]

Unnamed: 0,lang,task,scale,performance_metric,metric_group,metric_name,median,mean,CI_low,CI_high,...,corr_mean_sent_len_CI_low,corr_mean_sent_len_CI_high,corr_n_sents_median,corr_n_sents_mean,corr_n_sents_CI_low,corr_n_sents_CI_high,corr_n_words_median,corr_n_words_mean,corr_n_words_CI_low,corr_n_words_CI_high


In [63]:
# Distinct values of the 'scale' column.
long_df['scale'].unique()

array(['panss_pos', 'panss_neg', 'panss_o', 'panss_total', 'saps_total',
       'sans_total', 'group_diff', 'sample_raw'], dtype=object)

In [64]:
# Distinct values of the 'performance_metric' column.
long_df['performance_metric'].unique()

array(['r', 't', 'sample_raw'], dtype=object)

In [65]:
# Write the long-format table next to the other processed values.
long_df.to_csv(f'{PATH}/long_de.csv')