In [27]:
from collections import defaultdict
import matplotlib.pyplot as plt
from scipy.stats import chi2_contingency, ttest_ind, ttest_rel, shapiro, wilcoxon
import os
import pandas as pd
import seaborn as sns
import  matplotlib
import numpy as np

plt.rcParams['font.family'] = "Times New Roman"
plt.rcParams['font.size'] = 30
plt.rcParams['font.sans-serif'] = ""
plt.rcParams['text.usetex'] = True

plt.set_cmap('Blues')

def from_df_dict(df):
    #Excluding 'other' cases --> TODO print how many
    df = df[df.interpretation != 'other']
    values_df = dict(df.interpretation.value_counts(normalize = True))
    return values_df

def identify_datapoints_to_exclude(model, starting_dir = '.'):
    to_exclude = {}
    for data_type in ['NP-S', 'N-V', 'NP-Z']:
        to_exclude[data_type] = []
        if data_type.startswith('N-V'):
            categories= ['_N', '_V']
        else:
            categories = ['']
        model = 'GPT2'
        to_discard_all = set()
        
        # From Noun/Verb collect cases that are not disambiguated by syntactic agreement
        if data_type == 'N-V':
            original = pd.read_csv(starting_dir + '/data/prompts_N-V_V_ambiguous.tsv',sep ='\t')
            to_discard = list(original[original.disambiguation != 'agreement'].item)
            for i in to_discard:
                to_discard_all.add(i)
        
        # From all categories collect cases where the original sentence is correctly parsed
        for category in categories:
            to_eval = starting_dir + '/generated/'  + '/'.join([data_type + category, model,'original_stimuli', ''])
            if data_type == 'N-V':
                look_for = category[-1]
            else:
                look_for = data_type[-1]
            for condition in ['ambiguous', 'unambiguous']:
                df = pd.read_csv(to_eval + condition + '_interpretations_original.tsv', sep = '\t')
                to_discard = list(df[df.interpretation != look_for].stimulus_id)
                for i in to_discard:
                    to_discard_all.add(i)
        to_exclude[data_type] = list(to_discard_all)
    return to_exclude


<Figure size 432x288 with 0 Axes>

In [36]:
def collect_data(to_eval, to_exclude, sampling_type = 'p', data_type = 'N-V', output = '../analyzed', parser = 'allennlp', 
                 category = 'N', other_category = 'V', model = 'lstm_gulordava'):
    categories = [category]
    category_main = category
    to_eval_dirs = [to_eval]
    if data_type == 'N-V':
        to_eval_dirs.append(to_eval.replace('N-V_V', 'N-V_N'))
        categories.append('N')
    records_interpretations = []
    records_labels = []
    results_pre = {}
    results_post ={}
    for i in range(len(to_eval_dirs)):
        to_eval=  to_eval_dirs[i]
        category = categories[i]
        # Get values for all conditions
        for condition in ['ambiguous', 'unambiguous']:
            data_dir = to_eval + condition + '/evaluated_' + parser + '/'
            pre = pd.read_csv(data_dir + 'interpretations_pre_probs.tsv', sep = '\t')
            post = pd.read_csv(data_dir + 'interpretations_post_probs.tsv', sep = '\t')
            #exclude cases
            pre = pre[~pre.stimulus_id.isin(to_exclude[data_type])]
            post = post[~post.stimulus_id.isin(to_exclude[data_type])]
            # normalized quantity for each category overall
            interpretations_pre = pre.interpretation.value_counts(normalize = True)
            interpretations_post = post.interpretation.value_counts(normalize = True)
            for label in list(pre.interpretation):
                if condition == 'ambiguous':
                    prompt_type = 'no cue'
                else:
                    prompt_type = 'pre-target cue'
                prompt_type = '\\textsc{' + prompt_type + '}'
                records_labels.append([label, prompt_type, category])
            for label in list(post.interpretation):
                if condition == 'ambiguous':
                    prompt_type = 'post-target cue'
                else:
                    prompt_type = 'pre\&post-target cue'
                prompt_type = '\\textsc{' + prompt_type + '}'
                records_labels.append([label, prompt_type, category])
            # normalized quantity for each category per item
            for stimulus_id in list(set(pre.stimulus_id)):
                smaller_df = pre[pre.stimulus_id == stimulus_id]
                values = from_df_dict(smaller_df)
                value_category = values[category_main] if category_main in values else 0
                if condition == 'ambiguous':
                    prompt_type = 'no cue'
                else:
                    prompt_type = 'pre-target cue'
                prompt_type = '\\textsc{' + prompt_type + '}'
                records_interpretations.append([value_category, prompt_type, category, stimulus_id])
            for stimulus_id in list(set(post.stimulus_id)):
                smaller_df = post[post.stimulus_id == stimulus_id]
                values = from_df_dict(smaller_df)
                value_category = values[category_main] if category_main in values else 0
                if condition == 'ambiguous':
                    prompt_type = 'post-target cue'
                else:
                    prompt_type = 'pre\&post-target cue'
                prompt_type = '\\textsc{' + prompt_type + '}'
                records_interpretations.append([value_category, prompt_type, category, stimulus_id])
            if i == 0:
                results_pre[condition] = interpretations_pre
                results_post[condition] = interpretations_post
            else:
                results_pre[condition] = results_pre[condition], interpretations_pre
                results_post[condition] = results_post[condition], interpretations_post
    dist_values = pd.DataFrame.from_records(records_interpretations, 
                                                columns = ['P(' + category_main + ')', 'prompt_type', 'category', 'stimulus_id'])
    dist_labels = pd.DataFrame.from_records(records_labels, 
                                                columns = ['interpretation', 'prompt_type', 'category'])
    return dist_values, dist_labels

def plot_distributions(dist_values, sampling_type = 'p', data_type = 'N-V', output = '../analyzed', 
                       parser = 'allennlp', category = 'N', other_category = 'V', model = 'lstm_gulordava'):
    if data_type == 'N-V':
        categories = ['N', 'V']
        fig, axes = plt.subplots(1, 2, figsize = (24, 12), sharey = True)
        for i in range(len(categories)):
            sns.boxplot(x  = 'prompt_type', y = 'P(' + category + ')', data = dist_values[dist_values.category == categories[i]], 
                        color = 'White', fliersize = 0, showmeans = True, 
                        meanprops={"marker":"o", "markerfacecolor":"None",  "markeredgecolor":"black", "markersize":"10"}, 
                        ax = axes[i])
            sns.swarmplot(x  = 'prompt_type', y = 'P(' + category + ')', data = dist_values[dist_values.category == categories[i]],
                        alpha=0.6, dodge = True, ax = axes[i], palette = sns.color_palette("colorblind"))
            axes[i].set_xticklabels(axes[i].get_xticklabels(),rotation=15,  ha='right')
            axes[i].set_xlabel('prompt type')
    else:
        fig, axes = plt.subplots(1, 1, figsize = (12, 12), sharey = True)
        axes = sns.boxplot(x  = 'prompt_type', y = 'P(' + category + ')',data = dist_values, 
                           color = 'White', fliersize = 0,showmeans = True, 
                           meanprops={"marker":"o", "markerfacecolor":"None", "markeredgecolor":"black", "markersize":"10"})
        axes = sns.swarmplot(x  = 'prompt_type', y = 'P(' + category + ')', data = dist_values,
                             alpha=0.6, dodge = True, palette = sns.color_palette("colorblind"))
        axes.set_xticklabels(axes.get_xticklabels(),rotation=15, ha='right')
        axes.set_xlabel('prompt type')
    
    plt.tight_layout()
    plt.savefig('../results/p1/' + data_type + '_distribution_perdatapoint_' + model +'.png')
    plt.show()
    plt.clf()

import numpy as np

def plot_labels(dist_labels, sampling_type = 'p', data_type = 'N-V', output = '../analyzed', 
                       parser = 'allennlp', category = 'N', other_category = 'V', model = 'lstm_gulordava'):
    if data_type != 'N-V':
        fig, axes = plt.subplots(1, 1, figsize = (9, 5), sharey = True)
        categories = data_type.split('-') + ['other']
        prompt_types = ['no cue', 'post-target cue', 'pre-target cue', 'pre\&post-target cue']
        prompt_types = ['\\textsc{' + p + '}'for p in prompt_types]
        bottom = np.array([0.0,0.0,0.0,0.0])
        colors = ['teal', 'darkslategrey', 'azure']
        colors = colors[:len(categories)]
        for cat, color in zip(categories, colors):
            height_bar = []
            for p in prompt_types:
                dist_labels_p = dist_labels[dist_labels.prompt_type == p].interpretation.value_counts(normalize = True).mul(100)
                if cat in dist_labels_p:
                    height_bar.append(dist_labels_p[cat])
                else:
                    height_bar.append(0)    
            height_bar = np.array(height_bar)
            axes.bar(prompt_types, height_bar, bottom = bottom, width = 0.3, color = color, label = cat)
            bottom += height_bar
        axes.set_xticklabels(prompt_types,rotation=15)
        axes.set_xlabel('prompt type')
        plt.savefig('../results/p1/' + data_type + '_distribution_' + model +'.png')
        axes.legend()
        plt.show()
        plt.clf()


def report_labels(dist_labels_p1, dist_labels_beam, sampling_type = 'p', data_type = 'N-V', output = '../analyzed', 
                       parser = 'allennlp', category = 'N', other_category = 'V', model = 'lstm_gulordava'):
    prompt_types = ['no cue', 'post-target cue', 'pre-target cue', 'pre\&post-target cue']
    output_file = '../results/p1/' + data_type + '_distributionlabels_' + model +'.tsv'
    labels = data_type.split('-') + ['other']
    with open(output_file, 'w') as output_file:
        for dist_labels in [dist_labels_p1, dist_labels_beam]:
            for prompt_type in prompt_types:
                if data_type != 'N-V':
                    values = dict(dist_labels[dist_labels.prompt_type == '\\textsc{'+ prompt_type + '}'].interpretation.value_counts(normalize = True))
                    for label in labels:
                        if not label in values:
                            values[label] = 0
                        else:
                            values[label] = round(values[label] * 100, 1)
                    output_file.write(prompt_type + '\t'+ '\t'.join([str(values[label]) for label in labels]) + '\n')
                else:
                    for cat in ['N', 'V']:
                        dist_labels_tmp = dist_labels[dist_labels.category == cat]
                        values = dict(dist_labels_tmp[dist_labels_tmp.prompt_type == '\\textsc{'+ prompt_type + '}'].interpretation.value_counts(normalize = True))
                        for label in labels:
                            if not label in values:
                                values[label] = 0
                            else:
                                values[label] = round(values[label] * 100, 1)
                        output_file.write(prompt_type + '-' + cat + '\t' + '\t'.join([str(values[label]) for label in labels]) + '\n')
            output_file.write('\n\n')
            
def significance_tests(dist_values, sampling_type = 'p', data_type = 'N-V', output = '../analyzed', 
                       parser = 'allennlp', category = 'N', other_category = 'V', model = 'lstm_gulordava'):
    distr_analysis_file = '../results/p1/' + data_type + '_distribution_' + model +'.txt'
    categories = [category] if data_type != 'N-V' else ['N', 'V'] 
    category_main = category[0]
    if data_type == 'N-V':
        no_cue_cat, pre_cue_cat, post_cue_cat, pre_post_cue_cat = {}, {}, {}, {}
    with open(distr_analysis_file, 'w') as output_file:
        for category in categories:
            output_file.write('\nWilcoxon signed-rank test across prompt types -' + category + '\n')
            dist_values_cat = dist_values[dist_values.category == category]
            no_cue = dist_values_cat[dist_values_cat.prompt_type == '\\textsc{no cue}']['P(' + category_main + ')']
            pre_cue = dist_values_cat[dist_values_cat.prompt_type == '\\textsc{pre-target cue}']['P(' + category_main + ')']
            post_cue = dist_values_cat[dist_values_cat.prompt_type == '\\textsc{post-target cue}']['P(' + category_main + ')']
            pre_post_cue = dist_values_cat[dist_values_cat.prompt_type == '\\textsc{pre\&post-target cue}']['P(' + category_main + ')']
            output_file.write('No cue vs. post-target cue\t' + str(wilcoxon(no_cue, post_cue)) + '\n')
            output_file.write('No cue vs. pre&post-target cue\t'+ str(wilcoxon(pre_cue, pre_post_cue)) + '\n')
            output_file.write('No cue vs. pre-target cue\t'+ str( wilcoxon(no_cue, pre_cue)) + '\n')
            output_file.write('Post-cue vs. pre&post-target cue\t'+ str(wilcoxon(post_cue, pre_post_cue)) + '\n')
            if data_type == 'N-V':
                for x, y in zip([no_cue_cat, pre_cue_cat, post_cue_cat, pre_post_cue_cat], [no_cue, pre_cue, post_cue, pre_post_cue]):
                    x[category] = y
        if data_type == 'N-V':
            output_file.write('Wilcoxon signed-rank test across categories\n')
            output_file.write('No cue \t' + str(wilcoxon(no_cue_cat['N'], no_cue_cat['V'])) + '\n')
            output_file.write('Post cue\t'+ str(wilcoxon(post_cue_cat['N'], post_cue_cat['V'])) + '\n')
            output_file.write('Pre-target cue\t'+ str( wilcoxon(pre_cue_cat['N'], pre_cue_cat['V'])) + '\n')
            output_file.write('Pre&post-target cue\t'+ str(wilcoxon(pre_post_cue_cat['N'], pre_post_cue_cat['V'])) + '\n')

In [37]:
eval_dir = '../generated/'
data_types = ['N-V','NP-S', 'NP-Z']
models = ['lstm_gulordava', 'GPT2']

dist_data_models = {model:{} for model in models}

# for all combinations of data and model
for data_type in data_types:
    category = data_type.split('-')[-1]
    for model in models:
        #collect cases to exclude
        to_exclude = identify_datapoints_to_exclude(model, starting_dir = '..')
        data_type_tmp = data_type
        if data_type == 'N-V': data_type_tmp = 'N-V_V'
        # folder with data: standard generation by sampling
        to_eval = eval_dir + '/'.join([data_type_tmp, model, 'sampling-p_p1_temperature1_repetition1/'])
        dist_values, dist_labels = collect_data(to_eval, to_exclude, data_type =data_type, category = category,  model = model)
        dist_data_models[model][data_type] = dist_values
        plot_distributions(dist_values, data_type =data_type, category = category, model = model)
        plot_labels(dist_labels, data_type =data_type, category = category, model = model)
        #significance_tests(dist_values, data_type =data_type, category = category, model = model)
        to_eval_beam = eval_dir + '/'.join([data_type_tmp, model, 'beam-search_beams16_temperature1_repetition1/'])
        dist_values_beam, dist_labels_beam = collect_data(to_eval_beam, to_exclude, data_type =data_type, category = category,  model = model)
        report_labels(dist_labels, dist_labels_beam, data_type =data_type, category = category, model = model)

  -9.263269   -11.575602    -7.5252504   -1.2877626 ]
 2.1613958e-01 1.1074665e+00 1.4973709e+01 4.7867794e+00 5.0597434e+00
 -7.2901325e+00 -2.7656593e+00 -4.0863371e+00 -4.2561712e+00
  1.5615695   3.5013907   8.009781    9.746901    4.754291    1.9315523 ]
 -0.11623001 -3.8186455  -5.977291   -6.900938   -3.3176556  -1.1912279
 14.349806   2.2600281]
  7.832646    5.1931095   7.540637   14.438068    0.8915719 ]
  3.109976   7.3599534  8.286135   7.8089204  7.1055913  0.6820721
  1.0490116 10.220833   1.5081336 10.731841   3.0728016  2.195261 ]
  6.866201    3.5275526  10.36005     3.4965076   2.5250685   2.0272667 ]
 5.98179340e+00 8.81511211e+00]
 7.2843966e+00 6.2904336e-02 7.1342602e+00 3.9792850e+00 4.5858140e+00
 36.842014 ]
  -3.6102104   -4.593336   -13.108221    -1.0966702   -1.8237209
  5.9297404 12.775041  36.20532  ]
 -4.703723  -4.146202  -0.5169811]
 1.9663149]
  -2.216239    -7.7660775   -0.22794151]
  7.636613    6.895844    5.1901436   1.828624  ]
 -1.0685921e-01 -1.

  1.9644078   0.59822166]
 -3.7653532  -3.1178112  -4.1598034  -8.866962   -2.1337967  -1.230856
  3.0078871  5.0495977  7.7501616  9.422992  11.962185   5.396159 ]
  -0.97122955  -5.982012    -2.7957497   -2.943657   -10.473287
  -5.025017    -4.309044    -0.34487152  -6.5903635   -0.58148193
 -5.535513  -2.2916336 -5.895545  -0.5468273]
 -1.8079891  -1.3820419 ]
  -6.368997    -5.8102264  -11.967024    -0.47872353  -1.4385128
 -5.1793900e+00 -9.0295029e-01 -4.5288086e-01 -5.4342003e+00
 14.74455    5.7239857  3.1293304  2.8481703  5.615604   6.5887666
  -2.920576    -5.891079    -1.8388815 ]
  -2.2605124   -6.0137587   -2.7923918   -1.3722801   -5.8537474
  -8.682546    -0.67648697]
  -1.1190624 ]
  -8.62946     -4.448063   -10.005384    -8.948774    -3.184124
  -1.8697567   -1.6247301   -0.33823204]
  -3.2248068   -7.7762375   -7.636511    -5.1546354   -9.966798
  8.300992   7.3248925  5.3376756  3.653224   3.3916464  9.686644
 -5.3647938e+00 -3.4361172e+00 -4.0248470e+00 -6.1051693

  0.34209687  9.18979     0.41536143  6.9706717   0.4979572   0.7583719
 -5.067546   -2.161953  ]
  -1.526062  ]
 -7.554803   -0.35938072 -3.1164322  -3.9965572  -7.2267504  -6.7784376
  -4.3056116   -2.616537    -6.632613    -1.9011478   -9.599786
  5.585134    6.976615    2.101998    6.5956144   8.499891   12.820485
 -8.441429   -8.295147   -0.42557335 -0.8337002  -8.097021   -2.363677
 -5.5720310e+00 -2.4655342e-01 -1.7416782e+00 -9.7663879e-02
 14.594253  14.583988  12.719173   9.601067   9.927395   0.9237587
  -0.22102547  -9.469843    -2.2782192   -0.05559158  -7.2556744
 1.5931947e+01 8.4579077e+00 3.3174875e+00 8.9774981e+00 3.9603422e+00
  7.7971816 14.894174   1.174925   2.6549718  1.3828133  3.4985125]
  1.4226886  4.61113    8.805619   3.6781764 10.215725   3.261325
  4.018372    2.620523   14.680265    8.828979    3.2174046   3.4444327
  4.8027725   3.4256606   5.92695     6.028344    8.930493    3.851396
  -5.199581   -1.9638958  -3.2962399  -8.019475   -0.6387558 -25.523

  3.9770424   9.774461    3.9860544   8.66246     0.53219134]
 11.670613    0.37131462  0.90776014 11.672101    2.11802     5.7023253
 1.56510127e+00 1.76016402e+00]
  -1.5762806  -11.502078    -0.5090008   -0.8720999   -6.3884506
  7.491608    0.85548025]
  2.1643755  12.856518    1.0836971   4.3565845   4.5535984  13.38988
  -1.5785675  -1.8491364 -10.197338   -0.7292824]
  -4.384058    -0.762805    -4.5548487   -3.2913647   -2.4764996
  1.423206   6.73112    7.9887786  4.3972607 11.454495   1.294878
 -5.049054  -8.589908  -1.0213432]
  5.85192    3.4099028]
 -1.6422434 -3.2805328 -3.1115456 -1.6521702 -3.9148922 -5.520318 ]
  0.16269006 10.563512    5.1682796  16.295       0.33558083  1.7740079
 -4.2256727  -7.1151924  -2.377015   -2.014248   -0.41457748 -7.1968994
  -0.9135399   -3.764144    -9.132521    -2.1711435   -8.671581
  -3.5825405   -1.2740421   -7.5837746   -3.1887188   -9.652497
  -7.2125044   -0.903656  ]
  1.1292216   6.2068596   0.84590703  6.470348    6.5077343   6.9

  -2.6782198   -4.067712    -9.576471    -1.1837673   -5.981738
  2.5423632   8.065738    4.729757    3.4916275   1.0409821 ]
 10.373585    1.2224032 ]
  -2.908228    -2.2891045  -11.287757    -2.9986277   -6.7568884
 10.0809765  7.9290085  5.6977525  5.2084327  6.6267834 18.911165
  3.3653316]
  1.7682748]
  5.161029   13.774066    8.169169    4.2161283   4.7136383   4.576274
 -5.5182896e+00 -4.2614555e-01]
  3.1444674  12.689187    9.80109     0.29268143  0.56918687 11.984049
  -9.560716    -5.686138    -1.987298    -8.415571    -5.4287825
 -10.516752    -1.6069889   -2.7198448   -2.5016499   -6.9183397
 18.194418    2.316733    9.316391    1.1599281   3.282736    7.7702475
  -1.6583824 ]
  8.545938   3.504668  15.955398   1.526686   2.171834   0.4115889
  -2.3113823  -3.8970118  -7.6062717  -1.0499821]
 -6.30019    -5.952238   -1.835907   -1.1789246 ]
  -3.735242   -6.205735   -2.0469913 -11.702377   -1.2762871]
 -8.26749    -2.0545673  -3.1381283  -0.09843063 -2.0332527 ]
 8.912619

  5.8059278  19.643665    7.1565557   2.532014   14.770652    1.8508195 ]
 -10.12243    -1.8222084  -0.3168602  -5.3229523  -7.327841   -8.720335
  4.436438   8.806751   2.5593412  7.871112   1.1964654]
  2.6704006  9.028089  11.284172   3.1899176  7.5696926 10.332992
 -0.7357998]
  -0.6121807   -7.225564    -6.234603    -6.8986645   -0.11532593
  2.447924  12.175617  11.505464   0.5764831]
 7.11021137e+00 1.98105586e+00 1.55746663e+00 1.31502056e+00
  5.715041    1.5601331   7.17935     2.2404084   0.2891427   6.386727  ]
 -6.914651   -6.91675    -0.29340172 -1.4161644 ]
 -1.532259   -3.459938   -5.4387665  -1.8439865  -5.9470005  -1.1276875 ]
 15.447746   9.618988   0.9642558]
 10.655163   1.5942324  3.3360796 12.635153   3.3011823  8.693184
  0.42516318  8.1318035   2.8898835   9.812741    6.5800853   9.792108
 0.2595947]
 3.4746604  7.921169   0.39289096 0.59682655]
 16.915688    3.798212    5.5054593   4.706883   10.361952    3.19899
 16.884026    5.4894414   4.968467    3.6657553

 -3.675933  -2.2269    -0.9485674 -4.51898   -3.2621403 -0.5176735
 12.228781    1.7495973   8.196995    2.9207718   2.1616404   1.9785173
 16.2382      2.3576593   1.1748341  11.004658    2.4519305   5.520339  ]
  7.5546803  8.789676   7.6841435  2.5422297  2.6500711 15.473546
 -0.89982414]
  9.354111   13.085361    3.1460705   3.5633883   4.6696577   3.6708527
  -4.0777617   -0.82642746]
 -0.7998371]
  3.2634244   5.502856    5.178396    3.2303212   6.4265804   2.3109515
  4.8158555  10.08842     3.1465642  12.430267   15.525694    2.0043972
  -8.398869    -6.03006     -8.488105    -8.332048    -6.5875444
  -7.252819    -7.613165    -2.017579    -2.4503212   -3.9181776
  1.5951337   7.1218867   4.5001917  11.554096    2.1561313  10.849973
 -5.881551  -2.5913582]
  -3.0010977   -8.289825    -7.4636917   -6.4994135  -11.789612
 -9.554848   -3.935728   -3.692484   -5.7136316  -1.2893467  -5.527135
  -7.630928   -1.7480221  -1.4222946  -6.0152283  -1.3992252  -0.9175854]
  1.4601547 ]
 1

  -2.3198395   -6.380419   -14.521317    -0.4226036   -7.231348
 9.0664167e+00 6.9398261e-03 9.8663831e-01 1.0216884e+01 1.1391418e+01]
  1.5947661 ]
 -6.1882696  -0.72192764]
  2.4771557   6.162928    2.6403546   4.6645136   0.17920868  6.705203
  -2.7785587  -1.3812437  -6.030963   -2.5445786  -5.8844967  -8.550406
  6.301019    4.1752048  14.076864   12.155598    9.108012   10.118639  ]
 -7.107892  -0.7413273 -4.894497  -2.3609886 -2.9753237 -1.5230484
  2.939921   3.1410885  5.974559   4.6411486  3.3077905  6.9456143
  3.3813357   6.1989384   2.554915   11.564626    2.666664   16.915745
  3.935249   12.514594    4.972229    1.8335332   6.318227    0.87209505]
 -1.4300013  -4.781349   -3.4262676  -3.270421   -7.210943   -3.1102448 ]
 -3.1103077  -4.554449   -8.283484   -2.4001255  -7.9911346  -0.51901054
  -1.289279   -1.6693115  -3.4067478  -0.9849653  -4.3901215  -9.299112
  -5.06026    -4.0250216  -3.3092718  -1.5774031  -5.395361   -5.856385
  -6.502429    -0.3316841   -0.626722

  3.2597082 ]
  -9.038796    -0.4990406   -1.509366    -3.6485243   -0.45142174]
  -0.30386925  -1.5855141  -14.674344    -5.5753565   -3.7898903
 -2.4889393e+00 -1.2379460e+01 -1.2442158e+01 -1.3059965e+01
 -1.3776302 ]
 -1.179142  -3.386119  -6.071703  -0.7718487]
  3.5887923   1.9776051   9.434097    3.503969    3.1144516   4.913107
 5.53676784e-01 9.34889436e-01 4.65935135e+00 2.73084927e+00
  -7.2108226  -1.5836096]
  9.091468    3.0823529   0.46653256  5.9261284  13.825502    4.9157743
  3.3293834 17.520906   6.8727803 13.007615   0.4332835 36.86934  ]
  5.6805377   2.1779113   0.47585815  0.86199355]
  7.4558167  4.147352   8.246588   7.8348374  9.45943   10.992595
  1.3262147  10.804582    0.06936543  9.368019    5.588272    3.0026658
  -1.8559122  -5.588257   -2.5394192  -0.6005392  -1.9123116  -1.5651073]
  -1.8957157  -11.280846    -3.018631    -6.64729     -0.95547867
  1.4628459  10.349076    0.48885724]
  -4.3571386   -3.7068195   -1.2081394   -3.3423576   -0.9300356
 1.2

  -7.5258408   -1.0303822 ]
 -9.643099   -2.1982708  -3.3633776  -2.882639   -4.097494   -7.744129
  -0.82668304]
 -4.8639317  -3.776474   -3.2636738  -2.1575222  -6.2673473  -3.2187223 ]
 10.605537    1.2529362 ]
  1.6210259   7.883846    4.0707593   7.1629357   3.4019709  15.77079
  -1.5530872  -2.2379265  -2.1600018  -5.4914694  -5.666194  -10.170382
  2.4517984  5.7496967  3.856634  10.195459   9.718692   1.1139386]
 -10.240548    -3.2356892   -2.1585608   -6.7156496   -4.715823
  -3.378687    -1.758606    -5.9764223  -10.739283    -9.337885
  -2.3313074   -0.12238693  -0.30625153  -2.4616756   -1.8882685
  -1.5058403  -8.2600765  -4.220624   -1.0864143 -25.496603 ]
 -7.338373   -2.1657295  -6.8643684  -2.4755964  -4.937008   -3.0234394
  -3.022852    -1.5808249   -1.358674    -6.651102   -11.360223
  3.2947364   0.16491346  4.915871    0.6689465 ]
  -5.009121    -7.3739996   -2.665557    -2.6723413   -4.564213
  8.483324    8.104992    1.713503  ]
 -10.742912    -4.2166004   -0.68

  -5.5202627   -1.1168747   -9.457413    -2.9806347   -0.9948349
  0.9106688   2.0285149   1.8492594   7.24627     6.255788    2.9155848
 -10.326979    -3.721034    -3.7174673   -5.6384583   -6.72013
  2.673202    3.431318    3.6106422 ]
  -6.072176    -4.428423    -1.5824728  -11.8542185   -2.7750692 ]
  -2.136404   -2.921237   -9.81543    -0.533865  -25.580328 ]
 6.806544  1.3960066]
  9.486283    6.8086076   0.3836644   1.059116    3.7909586   7.251332
  -3.6261826   -3.252882    -2.949728    -4.6556063   -1.7121458
  -8.062035   -3.5169878  -4.3498     -5.865796   -0.5482521]
  -2.537857    -7.306598    -7.143366    -0.3657112   -2.3897896
  7.4042025  6.451962   7.111528   6.735451  11.484743   7.0682216
  5.4810305   0.85000163]
 -0.26555824]
 1.7205589e+01 4.3249979e+00 3.1903386e+00 1.0087342e+01 3.5488167e+00
  4.1263824   3.9824193   1.0075569   2.0498984   1.1895283 ]
  -1.9597263   -1.3561077   -4.8431215   -2.533722    -7.018839
  0.11844238  0.18046348  2.6484172   4.9658

  -6.1056128  -1.3038349]
 4.6405268e+00 1.6727484e+01 3.1227262e+00 1.2146295e+00 4.8067722e+00]
  -0.5461483 ]
 -0.9477329]
 -7.4695396  -1.2084589  -3.885459   -2.625023   -0.18899727]
  -5.6642017   -1.5515842   -5.8316793   -6.6948643   -0.48670387]
  5.8552194  7.2696915  1.8578887]
 10.126894  13.306892   2.7025216  6.761234   8.875846   8.187259
 -1.6611118e+00 -7.4789715e+00 -6.3357658e+00 -3.3220863e-01]
 1.2096401e+01 1.4219388e+00 1.1774877e+01 9.8130150e+00 2.0328543e+01
 -5.477868   -0.95635605 -6.0158405  -2.5374956  -6.914036   -6.7160683
 -0.08209801 -0.12508774 -1.835743   -3.4420834  -4.171712   -0.03342438
 -1.3814373]
  2.062518    1.3364015  14.65159     9.08707     0.94324374]
  5.388815   8.952984   2.9531841 16.882961   1.8412931]
  -8.888272    -6.5680466   -3.701415    -1.455533    -9.822699
  -3.8913994   -7.0122824  -10.028587    -0.48544312]
 -10.746498   -7.1400146  -1.0441685  -5.600193   -8.396435   -2.0887604
 15.559281    0.08743872  3.1425385   5.273

  0.5986785]
 10.477107   1.1412686]
  3.5178845  5.329686   2.0021877  5.9804735  7.586887   9.4178705
  3.6283166  11.262617    2.5141249   1.058684    4.110889    3.080746
  -1.8128166   -3.4253798   -0.82492447  -9.663272    -0.55937576]
 -5.075115   -3.2936678  -4.9664946  -3.320736   -6.5918465  -5.564374
  0.8840624]
  -5.377719   -12.615002    -4.5667315   -0.11057472  -0.4320984
 -10.843271   -0.4967289]
  0.37470475  6.384845    1.422855    9.290574    2.0296566   6.7984824
 -4.419962   -0.71611404]
 -10.988888    -3.0008755   -5.897064    -7.3964186   -7.1940312
 14.505463   1.6377563]
 10.0689945  7.1805     2.6305585  5.5401034  4.307911   7.3244934]
  9.703091    4.890276    9.326998    5.5556946   6.3112874   4.257718
  -1.891756   -5.334381   -0.5097084 -25.51617  ]
  7.9082575  17.541319    1.8406767   1.8745022   2.919345    8.415634
N-V_2_14
 -4.058529   -5.038966   -1.2877903 ]
  1.7985202  13.040227    0.71996343  2.1775548   5.263708    0.6512639 ]
  -2.1510324  -

  -0.24796486]
  -3.2785244   -2.6077824   -3.4096136   -4.0621223   -2.112955  ]
 14.111378   2.808646   4.5290504  9.5612335  0.8864702  7.041066
 -0.12944603 -3.990203   -4.008977  ]
 -1.4840183  -0.85218143 -5.789776   -6.9027042  -0.9889908  -6.006139
  1.610147   10.443093    4.583921    9.422927    4.7031145   7.869918
  2.7055745   8.02574     0.34776816 11.755793   12.289705    1.26011
  -8.683262   -3.1039696  -2.0809507  -3.847865   -3.2857695  -6.6622896]
  -1.2609892 ]
  -3.3753414   -4.428585    -4.020789    -1.7241869  -10.232788  ]
  4.047265    5.1928744   0.18227962 11.196241    3.1732862   3.5068846
  -1.778511    -1.9480209   -4.0500593   -1.7785997   -3.1743526
  4.7259483  2.385087 ]
  -5.4420443   -2.0860538   -1.3366804   -9.651083    -1.2489328 ]
  1.441344 ]
  2.0564241 ]
  -1.9392471  -3.5393744  -4.2315407  -2.0219755  -7.671626   -1.603858
  -6.2918587  -9.20444    -2.8210354  -0.3780403]
 -12.949769   -5.625412   -0.6550636]
 -4.075201  -3.5657673 -3.15256

  1.313726  15.379795   0.286669   8.19757    8.986043   1.2378651]
 -6.6929455e+00 -3.0751352e+00 -3.4935274e+00 -3.7968569e+00
  4.540514    9.876343    0.52624214  2.8508682 ]
 -3.757248   -0.17983246 -5.6145477  -0.7874794  -5.961521   -0.47288895]
  -2.5708466   -0.95705414]
  1.9876626   1.9913445  12.215829    0.38563466 12.295175    5.005771
  -1.8655062   -4.8162384  -11.379772    -3.5552464   -3.1029367
  0.76137954  8.644371    2.849959    7.2038      3.8144307  13.270776
  3.0922096 ]
 16.345276    3.4106472   2.2502966   7.814681    1.8822786   6.672554
  6.3874946   6.1378655   8.716327    0.29667142 13.044467    5.774268
  -4.872426    -0.42185783  -4.654545    -0.89696693]
 4.1940293  3.297389   3.696225   4.229579   0.92224807]
 -6.425996   -0.7819271 ]
  -2.745407   -7.0939856  -6.7987165  -5.7915754 -11.666769   -1.5926018
 -1.1570187 ]
  7.703711    3.8025062   1.77082    11.450957    6.3248587   1.100495
  7.7475696   2.1039038   4.81447     2.8895767   4.5352182  

  9.224462  19.908031   6.2789297  5.0238514  1.9286987  3.431592
 -4.7608633  -0.2774582  -2.8040829  -2.027382   -1.9027491  -4.163969
  4.5105424  10.909425   11.322601    3.7869892   1.4762963   1.4867501
  9.257363   1.4487818  3.3871996  1.1063465  3.5546033  4.681907
  -9.4048     -6.5483656  -1.9234467  -1.2171211 -11.8469715  -6.8714333
  -2.3264942   -5.560787    -9.592948    -7.162876    -1.1444769
  -3.846653    -6.961894    -1.8333664   -3.2553911   -4.3513966
 -10.438853  -11.124243   -3.0034523  -0.8110237]
 -1.4458599e+00 -6.1203270e+00 -7.8590813e+00 -8.2206726e-04
 13.443188  11.451726  10.811817  11.374263   3.5902603  9.089289
  5.783105   14.152169  ]
  -5.6352425   -5.7306185   -4.367612    -2.3416271   -1.086978
 -4.2586384  -0.6540184  -0.6177292  -3.5976982  -4.97832    -6.385874
  -4.895916    -6.486474    -8.933832    -0.1616211   -4.1477575
  1.3672897]
  -2.895153    -3.3782854  -11.351393    -3.260601    -0.39187813]
  -7.0300493  -5.795149   -0.9471779]
 

  5.1694283   5.601145    0.03640528  6.965154    4.995184    5.0530252
  1.8446996   2.8020334   7.6857657   1.3111022   6.6768975   8.3839
 -5.6902943  -0.44519997]
 -2.0520144  -2.1468248  -2.212905   -3.6581173  -3.1218166  -1.4607391
  -8.423253    -4.750204    -9.398247    -0.9144411 ]
  2.0151138 15.421626   1.0512846]
  -6.773428    -1.1775227 ]
 -3.9837055  -5.7285385  -7.3106384  -0.86376953 -5.9831142  -0.22649574
 -4.2596035 -4.056362  -1.0881195 -6.3847036 -4.802618  -5.2247677
  7.909474    3.9863763   9.805386    5.662175    6.991473    0.98227966
  8.46984     6.3573756   5.457828    4.194464   10.202597    7.485399
  8.670827    1.7717818  10.661749    8.06856     2.1326191  18.89769
  -0.8969097 ]
  -3.1668043   -2.4931622   -0.4337864 ]
  -4.32325     -0.40512466  -5.240923    -2.8945112   -2.0367804
  7.5058703 10.563037   2.285355   1.3212079 10.8841    10.913246
 12.158373   7.089707  11.934051   2.0902112]
 -12.153924    -7.5033083   -6.8803263   -3.350565    -8.

  -6.635703    -2.40621     -3.632412    -4.0420933   -6.5815153
  7.343094    4.006062    6.5992594  14.583319    0.44246045  1.4081306
  7.037379    2.6009858   3.3339758   8.318626    0.78474164  5.860621
  -5.1860485  -4.8330173  -9.233139  -11.781347   -1.4083805]
 -11.721571   -10.495987    -0.04399109  -3.916891    -3.489809  ]
  -5.0078077   -4.5106964   -2.6881409   -2.0788784   -5.6990213
  6.502079   2.1823387  4.3393836  1.3703412  2.158107   1.1492623]
  -6.7390184   -4.135889    -4.528819    -2.073886    -4.162445
  1.7519474   2.3228362   0.6822428   0.9910493   0.45202267  2.9492342
  6.289152    1.0107434   3.1419442   4.5160456   0.50010085]
  -1.6777916   -8.859524    -2.3347073   -5.239255    -0.04455948
  1.4129351   7.758034    1.1926515 ]
 -12.354277   -5.89649    -8.418645   -1.804698   -7.15275   -13.17663
 19.645803   2.7900732 11.742582   5.411715   0.9531333]
  -6.627964    -4.6738415   -6.989272    -0.55420685]
  -1.2895174   -3.408019    -3.6819897   -7.32

 -1.2350483e+00 -5.8241062e+00 -5.6359615e+00 -9.6983356e+00
  -1.1785259   -0.38741112  -7.077114    -3.5082731   -0.83675575
 13.579415  16.703283   2.8701208]
  3.934198   10.120896    2.8780196   4.8689084   3.992785    1.0762756
 -9.14361    -0.33175278]
  4.710506    1.9557179   0.02723656  6.3507733   2.6500008   8.188424
  1.4557478]
  -0.94145393 -11.721463    -0.11276817  -7.322068    -3.5823784
 -2.2694693e+00 -6.1300392e+00 -5.1707420e+00 -2.5242424e+00
 4.9790573e+00 8.9890594e+00 8.5536118e+00 6.2997384e+00 2.0038359e+00
 -4.3833456e+00 -2.6617832e+00]
 -0.58836555]
  -1.4448824  -4.179264   -2.9019642  -8.850694   -1.9694653  -1.251894
  7.068399    0.9334806   8.89996     1.868608    6.784186    2.1184342
  5.3469834   3.0474706   2.8978114   4.531501    8.642236    9.139768
 -9.252993   -4.3681526  -0.63970757]
  -5.6835213   -0.5169716   -6.877205    -1.41078     -7.333257
  -1.3077669   -7.320634    -0.827343    -3.3065472   -1.006958
  4.0666027  13.418089    5.0462

  9.863513  10.969658   4.428977   6.394187   8.096461   1.6596422
 -1.3052979]
  -4.285656   -0.6414089]
 -0.16682625 -0.53016853 -8.376915   -1.5176468  -3.9705143  -0.43150902]
  -3.9042358   -1.9172173   -7.7901516   -5.5539093   -8.379138
  0.94540095]
  -6.063383    -3.0917568   -0.6710987   -4.6584272   -8.122915
  -9.456317   -12.605513  ]
 -13.112419    -7.324706    -3.9657345   -2.705433    -4.6029615 ]
 -1.5197487  -4.995516   -0.39300537]
  2.9614203   5.498164    1.9656888   5.493682   12.232842    8.979894
 -13.49245     -0.18083763  -9.485636   -11.551602   -11.395832
  -4.359004   -8.023694   -5.2929115  -6.768276   -7.981058   -1.0077801]
 10.277022   13.652919    6.301216    3.362503   11.343732    0.47340906
 -0.9254799  -4.3767557  -5.235834   -2.0990543  -6.172674   -3.9704895 ]
 1.442805   3.991702   0.87848735 5.665686   0.9646604 ]
 3.5660672e-01 3.0042920e+00 1.0771485e+01 3.1648279e-04 1.2766090e+01
  -2.6250477 -10.1708975  -3.329875   -4.0784693 -12.7783575 

 -5.356188   -7.5212317  -1.8396282  -0.49815273 -9.220961   -0.18402481
  -2.0281525   -0.870924  ]
  -5.9295073  -2.6183968  -5.298954   -1.9408188]
  6.6138544  2.467951   4.0723305 14.108743   5.0526867  4.8842
  -0.87763596  -0.47697735  -4.6271334   -1.0812683 ]
 -1.9565392  -0.30112267 -6.368411   -5.935198   -1.4367905  -0.47582245]
  5.97792   12.857164   0.0606039  6.9861827  6.630838   7.98743
  2.091027    6.3602347  12.334642    0.18569174  7.7706027   3.56137
  -6.4733114  -9.879092   -3.5582705  -3.5465765  -4.6157446  -9.389186
  -2.0469189   -2.026947    -6.7454453   -2.8554525   -1.2103386
  -0.6963463  -6.45212    -6.2541943  -1.3838854  -5.77582    -1.9058399]
  9.769356  11.4824     9.029861  10.4409485 13.5064745  5.072686
 -13.845286    -3.62879     -1.7605143   -0.33019066  -1.3616238
  -4.3494167   -2.0048141   -3.9063702   -3.5470772   -1.8057995
 16.217789    0.9029831   8.675355    8.04501    11.77477     1.0809948 ]
  -0.9531374   -3.3331146   -6.8745117   

KeyboardInterrupt: 

In [6]:
import statsmodels.api as sm
from statsmodels.formula.api import ols, glm, logit
import itertools
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
import statsmodels.api as sm

from scipy import stats
models = ['lstm_gulordava', 'GPT2']
for model in models:
    dist_data = dist_data_models[model]
    for data_type in data_types:
        if data_type != 'N-V':
            categories = [data_type.split('-')[1]]
        else:
            categories = data_type.split('-')
        for cat in categories:
            dist_values = dist_data[data_type]
            if data_type == 'N-V':
                dist_values = dist_values[dist_values.category == cat]
            pre = ['yes' if 'pre' in i else 'no' for i in dist_values.prompt_type]
            post = ['yes' if 'post' in i else 'no' for i in dist_values.prompt_type]
            dist_values['pre'] = pre
            dist_values['post'] = post
            cue = ['yes' if 'post' in i or 'pre' in i else 'no' for i in dist_values.prompt_type]
            n_cues = []
            for i in dist_values.prompt_type:
                n = 0
                if 'post' in i:
                    n+= 1
                if 'pre' in i:
                    n+= 1
                n_cues.append(n)
            dist_values['cue'] = cue 
            dist_values['n_cues'] = n_cues
            dist_values['P'] = dist_values['P(' + data_type.split('-')[1] + ')']
            RSS_list, R_squared_list, feature_list = [],[], []
            numb_features = []
            df = dist_values[['pre', 'post', 'stimulus_id', 'cue', 'n_cues','P']]

            if data_type != 'N-V':
                df.to_csv('../results/regression_data/' + data_type + '-' + model + '.csv', index = False)
            else:
                df.to_csv('../results/regression_data/' + data_type + '-'+ cat + '-' + model + '.csv', index = False)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dist_values['pre'] = pre
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dist_values['post'] = post
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dist_values['cue'] = cue
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See 