In [1]:
import pandas as pd
import numpy as np

import json

from scripts.utils import SimulateData
from stopsignalmetrics.ssrtmodel import SSRTmodel

from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# SSRT scale values; each one selects a file named
# ssrt_metrics/expected_ssrts_SSRTscale-<scale>.csv.
SSRTscales = [85, 25, 5, 0]

# Display labels for the SSRT method names found in the CSV headers.
SSRT_method_map = dict(standard='Weighted', fixed='Fixed', tracking='Tracking')

# Display labels for the generating models, keyed by their bare names.
gen_map_simple = dict(
    graded_both='Graded-Both',
    graded_go='Graded-Go',
    guesses='Guesses',
    standard='Independent',
)

# The same labels keyed by the 'gen-' prefixed form used in flattened column names.
gen_map = {f'gen-{name}': label for name, label in gen_map_simple.items()}

# Looking at correlations across generating models x SSRT methods, by SSRT scale

In [3]:
# Load each scale's expected-SSRT table and flatten its header into single
# 'gen-<model>_SSRT-<method>' column names, casting all values to float.
# Nothing is accumulated into full_ssrt_df here; the mean-SSRT summary table
# is built in the "Looking at SSRTs from different scales ..." section below.
full_ssrt_df = pd.DataFrame()

for SSRTscale in SSRTscales:
    ssrt_df = pd.read_csv('ssrt_metrics/expected_ssrts_SSRTscale-%d.csv' % SSRTscale, index_col=0)
    # reformatting - was multiindex: fold the 'underlying distribution' row into the column names
    ssrt_df.columns = [f'gen-{gen}_SSRT-{ssrt}'
                       for gen, ssrt in zip(ssrt_df.loc['underlying distribution', :].values, ssrt_df.columns)]
    # Drop the two non-numeric rows, then cast the whole frame at once.
    ssrt_df = ssrt_df.drop(['underlying distribution', 'NARGUID']).astype(float)

In [4]:
def read_in_multiidx_df(ssrtscale):
    """Load one scale's expected-SSRT CSV, keeping its two header rows as a column MultiIndex.

    Columns whose top-level header is 'guesses', 'graded_go' or 'graded_both'
    are dropped. The two column levels are then named 'SSRT Method' and
    'Generating Model' and relabelled through SSRT_method_map and
    gen_map_simple respectively.

    Parameters
    ----------
    ssrtscale : int
        Scale value substituted into the CSV file name.

    Returns
    -------
    pandas.DataFrame
        Frame indexed by the CSV's first column, with relabelled two-level columns.
    """
    csv_path = 'ssrt_metrics/expected_ssrts_SSRTscale-%d.csv' % ssrtscale
    frame = pd.read_csv(csv_path, index_col=0, header=[0, 1])
    frame = frame.drop(columns=['guesses', 'graded_go', 'graded_both'], level=0)
    frame.columns = frame.columns.set_names(['SSRT Method', 'Generating Model'])
    frame = frame.rename(columns=SSRT_method_map, level=0)
    return frame.rename(columns=gen_map_simple, level=1)

# One two-level-column frame per SSRT scale, keyed as 'SSRTscale=<scale>'.
pd_df = {}
for scale in SSRTscales:
    pd_df['SSRTscale=%d' % scale] = read_in_multiidx_df(scale)

In [5]:
# Spearman rank correlations between every (SSRT method, generating model) column at SSRTscale=85.
scale_85_df = pd_df['SSRTscale=85']
scale_85_df.corr(method='spearman')

Unnamed: 0_level_0,SSRT Method,Weighted,Weighted,Weighted,Weighted,Fixed,Fixed,Fixed,Fixed,Tracking,Tracking,Tracking,Tracking
Unnamed: 0_level_1,Generating Model,Graded-Both,Graded-Go,Guesses,Independent,Graded-Both,Graded-Go,Guesses,Independent,Graded-Both,Graded-Go,Guesses,Independent
SSRT Method,Generating Model,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
Weighted,Graded-Both,1.0,0.758175,0.954717,0.84807,0.832774,0.830987,0.802775,0.823524,0.797558,0.831956,0.828053,0.794088
Weighted,Graded-Go,0.758175,1.0,0.800954,0.932132,0.955442,0.960172,0.893299,0.967107,0.942721,0.963803,0.961764,0.940713
Weighted,Guesses,0.954717,0.800954,1.0,0.883052,0.916206,0.915362,0.912179,0.883527,0.840504,0.909605,0.903124,0.831826
Weighted,Independent,0.84807,0.932132,0.883052,1.0,0.945407,0.948271,0.859374,0.98929,0.98704,0.961287,0.964597,0.98811
Fixed,Graded-Both,0.832774,0.955442,0.916206,0.945407,1.0,0.995241,0.972835,0.973492,0.934561,0.992431,0.990233,0.927491
Fixed,Graded-Go,0.830987,0.960172,0.915362,0.948271,0.995241,1.0,0.969973,0.975533,0.937234,0.992267,0.989591,0.930623
Fixed,Guesses,0.802775,0.893299,0.912179,0.859374,0.972835,0.969973,1.0,0.90689,0.843852,0.957535,0.945583,0.832058
Fixed,Independent,0.823524,0.967107,0.883527,0.98929,0.973492,0.975533,0.90689,1.0,0.9847,0.983529,0.985174,0.98327
Tracking,Graded-Both,0.797558,0.942721,0.840504,0.98704,0.934561,0.937234,0.843852,0.9847,1.0,0.953058,0.957436,0.99481
Tracking,Graded-Go,0.831956,0.963803,0.909605,0.961287,0.992431,0.992267,0.957535,0.983529,0.953058,1.0,0.988424,0.948916


In [6]:
# Spearman rank correlations between every (SSRT method, generating model) column at SSRTscale=25.
scale_25_df = pd_df['SSRTscale=25']
scale_25_df.corr(method='spearman')

Unnamed: 0_level_0,SSRT Method,Weighted,Weighted,Weighted,Weighted,Fixed,Fixed,Fixed,Fixed,Tracking,Tracking,Tracking,Tracking
Unnamed: 0_level_1,Generating Model,Graded-Both,Graded-Go,Guesses,Independent,Graded-Both,Graded-Go,Guesses,Independent,Graded-Both,Graded-Go,Guesses,Independent
SSRT Method,Generating Model,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
Weighted,Graded-Both,1.0,0.364069,0.928638,0.730268,0.64639,0.633387,0.594147,0.616733,0.453262,0.637547,0.634508,0.477163
Weighted,Graded-Go,0.364069,1.0,0.427898,0.661625,0.838623,0.854347,0.760262,0.849623,0.67099,0.847368,0.85451,0.698325
Weighted,Guesses,0.928638,0.427898,1.0,0.713065,0.760578,0.754876,0.755591,0.675903,0.420635,0.748634,0.726886,0.449477
Weighted,Independent,0.730268,0.661625,0.713065,1.0,0.743179,0.751169,0.587324,0.915837,0.876471,0.779265,0.836898,0.89192
Fixed,Graded-Both,0.64639,0.838623,0.760578,0.743179,1.0,0.9774,0.952675,0.890551,0.587794,0.971362,0.94686,0.623727
Fixed,Graded-Go,0.633387,0.854347,0.754876,0.751169,0.9774,1.0,0.943722,0.895893,0.599108,0.968865,0.949542,0.634749
Fixed,Guesses,0.594147,0.760262,0.755591,0.587324,0.952675,0.943722,1.0,0.765161,0.388202,0.926997,0.859342,0.429281
Fixed,Independent,0.616733,0.849623,0.675903,0.915837,0.890551,0.895893,0.765161,1.0,0.838642,0.911624,0.951832,0.860441
Tracking,Graded-Both,0.453262,0.67099,0.420635,0.876471,0.587794,0.599108,0.388202,0.838642,1.0,0.645137,0.719878,0.965121
Tracking,Graded-Go,0.637547,0.847368,0.748634,0.779265,0.971362,0.968865,0.926997,0.911624,0.645137,1.0,0.943904,0.67735


In [7]:
# Spearman rank correlations between every (SSRT method, generating model) column at SSRTscale=5.
scale_5_df = pd_df['SSRTscale=5']
scale_5_df.corr(method='spearman')

Unnamed: 0_level_0,SSRT Method,Weighted,Weighted,Weighted,Weighted,Fixed,Fixed,Fixed,Fixed,Tracking,Tracking,Tracking,Tracking
Unnamed: 0_level_1,Generating Model,Graded-Both,Graded-Go,Guesses,Independent,Graded-Both,Graded-Go,Guesses,Independent,Graded-Both,Graded-Go,Guesses,Independent
SSRT Method,Generating Model,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
Weighted,Graded-Both,1.0,0.102276,0.918045,0.819518,0.576365,0.558566,0.561902,0.540865,0.047667,0.560797,0.548669,0.177228
Weighted,Graded-Go,0.102276,1.0,0.214396,0.179188,0.723337,0.746706,0.725822,0.681968,0.072301,0.71553,0.70553,0.248062
Weighted,Guesses,0.918045,0.214396,1.0,0.812287,0.716607,0.71332,0.727794,0.674501,0.020389,0.714339,0.695235,0.194518
Weighted,Independent,0.819518,0.179188,0.812287,1.0,0.564718,0.563712,0.542127,0.687041,0.252986,0.568062,0.594996,0.381794
Fixed,Graded-Both,0.576365,0.723337,0.716607,0.564718,1.0,0.961454,0.975825,0.892868,0.025841,0.959179,0.934345,0.255687
Fixed,Graded-Go,0.558566,0.746706,0.71332,0.563712,0.961454,1.0,0.968315,0.887588,0.030622,0.952217,0.928753,0.256504
Fixed,Guesses,0.561902,0.725822,0.727794,0.542127,0.975825,0.968315,1.0,0.879787,-0.028331,0.965392,0.927807,0.204531
Fixed,Independent,0.540865,0.681968,0.674501,0.687041,0.892868,0.887588,0.879787,1.0,0.176855,0.886664,0.902551,0.395418
Tracking,Graded-Both,0.047667,0.072301,0.020389,0.252986,0.025841,0.030622,-0.028331,0.176855,1.0,0.0424,0.100646,0.554048
Tracking,Graded-Go,0.560797,0.71553,0.714339,0.568062,0.959179,0.952217,0.965392,0.886664,0.0424,1.0,0.925121,0.270117


In [8]:
# Spearman rank correlations between every (SSRT method, generating model) column at SSRTscale=0.
scale_0_df = pd_df['SSRTscale=0']
scale_0_df.corr(method='spearman')

Unnamed: 0_level_0,SSRT Method,Weighted,Weighted,Weighted,Weighted,Fixed,Fixed,Fixed,Fixed,Tracking,Tracking,Tracking,Tracking
Unnamed: 0_level_1,Generating Model,Graded-Both,Graded-Go,Guesses,Independent,Graded-Both,Graded-Go,Guesses,Independent,Graded-Both,Graded-Go,Guesses,Independent
SSRT Method,Generating Model,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
Weighted,Graded-Both,1.0,0.082797,0.915809,0.845988,0.561281,0.542712,0.550469,0.520519,-0.107311,0.540991,0.528603,0.129091
Weighted,Graded-Go,0.082797,1.0,0.186984,0.093039,0.726581,0.748673,0.73533,0.677372,-0.153037,0.71727,0.701549,0.146954
Weighted,Guesses,0.915809,0.186984,1.0,0.847432,0.696445,0.693348,0.710136,0.662723,-0.142006,0.691736,0.675901,0.15718
Weighted,Independent,0.845988,0.093039,0.847432,1.0,0.535537,0.529837,0.537537,0.612766,-0.09622,0.530211,0.513717,0.120431
Fixed,Graded-Both,0.561281,0.726581,0.696445,0.535537,1.0,0.959347,0.978521,0.903127,-0.213625,0.957482,0.937401,0.18381
Fixed,Graded-Go,0.542712,0.748673,0.693348,0.529837,0.959347,1.0,0.970373,0.895301,-0.20753,0.948691,0.928647,0.181769
Fixed,Guesses,0.550469,0.73533,0.710136,0.537537,0.978521,0.970373,1.0,0.91356,-0.212949,0.968291,0.94783,0.185221
Fixed,Independent,0.520519,0.677372,0.662723,0.612766,0.903127,0.895301,0.91356,1.0,-0.201775,0.896265,0.879421,0.175608
Tracking,Graded-Both,-0.107311,-0.153037,-0.142006,-0.09622,-0.213625,-0.20753,-0.212949,-0.201775,1.0,-0.209105,-0.242532,0.021964
Tracking,Graded-Go,0.540991,0.71727,0.691736,0.530211,0.957482,0.948691,0.968291,0.896265,-0.209105,1.0,0.927836,0.182984


In [9]:
# For every scale, pool the Spearman correlations between the 'Independent'
# generating model and each other generating model, computed separately within
# each SSRT method, then report the minimum and mean of the pooled values.
SSRT_types = ['Weighted', 'Fixed', 'Tracking']
for scale_key, scale_df in pd_df.items():
    corr_array = []
    for ssrt_type in SSRT_types:
        # Selecting on the top column level leaves the generating models as
        # plain column labels, so no regex over tuple labels is needed.
        method_corr = scale_df[ssrt_type].corr(method='spearman')
        # Remove the self-correlation entry by label instead of writing NaN
        # into `.values`: that array is read-only under pandas Copy-on-Write,
        # so np.fill_diagonal on it raises.
        corr_array.extend(method_corr['Independent'].drop('Independent').to_numpy())
    print(scale_key)
    # nan-aware reducers are kept so any undefined correlation is still skipped.
    print('min: ', np.nanmin(corr_array))
    print('mean:', np.nanmean(corr_array))

SSRTscale=85
min:  0.8480698857634738
mean: 0.9348751745773315
SSRTscale=25
min:  0.6616250893656151
mean: 0.7837094289967573
SSRTscale=5
min:  0.17918847359029216
mean: 0.6264209665301154
SSRTscale=0
min:  0.021963680877580986
mean: 0.5421801827423268


# Looking at SSRTs from different scales, generating models, and SSRT methods

In [None]:
# Build one tidy table of mean SSRTs with one row per
# (generating model, SSRT method, SSRT scale) combination.
scale_summaries = []

for SSRTscale in SSRTscales:
    ssrt_df = pd.read_csv('ssrt_metrics/expected_ssrts_SSRTscale-%d.csv' % SSRTscale, index_col=0)
    # reformatting - was multiindex: fold the 'underlying distribution' row into the column names
    ssrt_df.columns = [f'gen-{gen}_SSRT-{ssrt}'
                       for gen, ssrt in zip(ssrt_df.loc['underlying distribution', :].values, ssrt_df.columns)]
    # Drop the two non-numeric rows, then cast the whole frame at once.
    ssrt_df = ssrt_df.drop(['underlying distribution', 'NARGUID']).astype(float)

    # The alternation is grouped so every branch is anchored to the 'SSRT-'
    # prefix; ungrouped, 'tracking' and 'fixed' would match anywhere in the name.
    ssrt_means = (ssrt_df.filter(regex=r'SSRT-(?:standard|tracking|fixed)')
                  .mean()
                  .to_frame(name='mean SSRT'))

    # Split each 'gen-<model>_SSRT-<method>' label back into its two parts;
    # the result shares ssrt_means' index, so the assignments below align by label.
    label_parts = ssrt_means.index.to_series().str.split('_SSRT-', expand=True)
    ssrt_means['SSRTscale'] = SSRTscale
    ssrt_means['Generating Model'] = label_parts[0].map(gen_map)
    # Anything after the first '.' is discarded before the lookup (presumably
    # the '.N' suffix pandas appends to repeated header names - confirm).
    ssrt_means['SSRT Method'] = label_parts[1].apply(lambda x: SSRT_method_map[x.split('.')[0]])

    scale_summaries.append(ssrt_means.reset_index(drop=True))

# Concatenate once, passing axis by keyword: pandas >= 2.0 rejects a positional axis.
full_ssrt_df = pd.concat(scale_summaries, axis=0)
full_ssrt_df['mean SSRT'] = full_ssrt_df['mean SSRT'].round(2)

column_order = ['Generating Model', 'SSRT Method', 'SSRTscale', 'mean SSRT']
full_ssrt_df = full_ssrt_df.sort_values(by=column_order[:-1], ascending=[False, True, False])
full_ssrt_df = full_ssrt_df[column_order]

In [None]:
# Render full_ssrt_df as a matplotlib table on an axes with no frame and no visible x/y axes.
ax = plt.subplot(111, frame_on=False)
for hidden_axis in (ax.xaxis, ax.yaxis):
    hidden_axis.set_visible(False)

table = ax.table(
    cellText=full_ssrt_df.values,
    colLabels=full_ssrt_df.columns,
    loc='center',
    cellLoc='left',
    colWidths=[0.35, 0.32, 0.22, 0.30],
    fontsize=18,
)

# Bold the header row (row 0) and any cell in column -1.
for (row_idx, col_idx), cell in table.get_celld().items():
    is_label_cell = row_idx == 0 or col_idx == -1
    if is_label_cell:
        cell.set_text_props(fontproperties=FontProperties(weight='bold'))

# Fix the font size explicitly (auto-sizing off), then stretch the cells.
table.auto_set_font_size(False)
table.set_fontsize(20)
table.scale(2, 3)

# Thicken every cell border.
for cell in table.get_celld().values():
    cell.set_linewidth(3)

# plt.savefig('Figures/Full_sim_SSRT_table.png', bbox_inches='tight', transparent=True)