# Import

In [1]:
%matplotlib widget

In [2]:
import os
import pickle5 as pickle
import copy

import pandas as pd
import seaborn as sns
import numpy
import torch
import scipy
import scipy.stats

import pyro
import pyro.infer
import pyro.infer.mcmc
import pyro.distributions as dist
import torch.distributions.constraints as constraints
from tqdm.auto import tqdm

from pprint import pprint

import matplotlib.pyplot as plot
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import warnings
warnings.filterwarnings('ignore')
sns.set(style="whitegrid")

In [3]:
repo = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))

# Define

In [4]:
def sigmoid(x):
    return 1./(1.+torch.exp(-x))

def icc_best_deriv(alpha, beta, theta, model_names, gamma=None, col='mean', target=None):
    '''
    Method to calculate the locally estimated headroom (LEH) score, defined as
    the derivative of the item characteristic curve w.r.t. the best performing model.
    
    Args:
        alpha:       DataFrame of discrimination parameter statistics for each item.
        beta:        DataFrame of difficulty parameter statistics for each item.
        theta:       DataFrame of ability parameter statistics for each responder.
        model_names: List of responder names.
        gamma:       DataFrame of guessing parameter statistics for each item.
        col:         DataFrame column name to use for calculating LEH scores.
    
    Returns:
        scores:      LEH scores for each item.    
    '''
    if target:
        idx = model_names.index(target)
        best_value = theta[col].iloc[idx]
        print('Local Grad for Target Model')
        print(f'Target model: {target}\n{best_value}')
    else:
        best_idx, best_value = theta[col].argmax(), theta[col].max()
        print(f'Best model: {model_names[best_idx]}\n{best_value}')
    
    a, b = torch.tensor(alpha[col].values), torch.tensor(beta[col].values)
    
    logits = (a*(best_value-b))
    sigmoids = sigmoid(logits)
    scores = sigmoids*(1.-sigmoids)*a
    
    print(f'No gamma: {scores.mean()}')
    if not gamma is None:
        g = torch.tensor(gamma[col].apply(lambda x: x.item()).values)
        scores = (1.-g)*scores
        print(f'With gamma: {scores.mean()}')
    
    return scores      
    
    

In [5]:
def get_model_guide(alpha_dist, theta_dist, alpha_transform, theta_transform):
    model = lambda obs: irt_model(obs, alpha_dist, theta_dist, alpha_transform = alpha_transform, theta_transform = theta_transform)
    guide = lambda obs: vi_posterior(obs, alpha_dist, theta_dist)
    
    return model, guide

In [6]:
def get_data_accuracies(data, verbose = False, get_cols = False):
    '''
    Method to reformat `data` and calculate item and responder accuracies.
    
    Args:
        data:                DataFrame of item responses.
        verbose:             Boolean value of whether to print statements.
        get_cols:            Boolean value of whether to return original column
                             values of `data`.
        
    Returns:
        new_data:            Reformatted `data`, dropping first column.
        accuracies:          Accuracy for each responder across examples.
        example_accuracies:  Accuracy for each example across responders.
        data.columns.values: Returns only if `get_cols` is True. Original column
                             values of `data`.
    '''
    new_data = numpy.array(data)
    new_data = new_data[:,1:]
    
    model_names = dict(data['userid'])
    accuracies = new_data.mean(-1)
    example_accuracies = new_data.mean(0)
    
    if verbose:
        print('\n'.join([f'{name}: {acc}' for name, acc in zip(model_names.values(),accuracies)]))
    
    if get_cols:
        return new_data, accuracies, example_accuracies, data.columns.values
    else:
        return new_data, accuracies, example_accuracies

In [7]:
def get_stats_CI(params, p=0.95, dist='normal'):
    '''
    Method to calculate lower and upper quantiles defined by `p`, mean, and variance of `param`
    
    Args:
        params: Dictionary of distribution parameters for each item keyed according to the 
                parametric distribution defined by `dist`.
        p:      Percent of distribution covered by the lower and upper interval values for each
                parameter.
        dist:   Name of parametric distribution
    
    Returns:
        return: {
            'lower': Lower interval values of each parameter,
            'upper': Upper interval values of each parameter,
            'mean' : Mean of each parameter,
            'var'  : Variance of each parameter
        }
    '''
    stats = {}
    if dist == 'normal':
        L,U = scipy.stats.norm.interval(p,loc=params['mu'], scale=torch.exp(params['logstd']))
        M,V = scipy.stats.norm.stats(loc=params['mu'], scale=torch.exp(params['logstd']))
    elif dist == 'log-normal':
        L,U = scipy.stats.lognorm.interval(p, s=torch.exp(params['logstd']), scale=torch.exp(params['mu']))
        M,V = scipy.stats.lognorm.stats(s=torch.exp(params['logstd']), scale=torch.exp(params['mu']))
    elif dist == 'beta':
        L,U = scipy.stats.beta.interval(p,a=params['alpha'], b=params['beta'])
        M,V = scipy.stats.beta.stats(a=params['alpha'], b=params['beta'])
    else:
        raise TypeError(f'Distribution type {dist} not supported.')
    
    return {
        'lower':[L],
        'upper':[U],
        'mean':[M],
        'var':[V],
    }

In [8]:
def get_plot_stats(exp_dir, alpha_dist, theta_dist, transforms, p = 0.95):
    '''
    Method to return plotting statistics for 3 parameter IRT model parameters.
    
    Args:
        exp_dir:          Path to 3 parameter IRT parameters and responses.
        alpha_dist:       Name of the item discrimination [a] distribution.
        theta_dist:       Name of the responder ability [t] distribution.
        transforms:       Dictionary of transformations to apply to each parameter type
                          where keys are parameter names and values are functions.
        p:                Percent of distribution covered by the lower and upper interval 
                          values for each parameter.
    
    Returns:
        param_plot_stats: Dictionary of parameter plot statistics where keys are parameter
                          names and values are plot statistics dictionaries as defined by
                          get_stats_CI().
    '''
    param_dists = {
        'a':alpha_dist,
        'b':'normal',
        'g':'normal',
        't':theta_dist,
    }

    dist_params = {
        'normal':['mu', 'logstd'],
        'log-normal':['mu', 'logstd'],
        'beta':['alpha', 'beta'],
    }

    pyro.clear_param_store()
    pyro.get_param_store().load(os.path.join(exp_dir, 'params.p'))

    with torch.no_grad():
        pyro_param_dict = dict(pyro.get_param_store().named_parameters())
    
    # get stats for plotting
    param_plot_stats = {}

    for param, param_dist in param_dists.items():
        temp_params = dist_params[param_dist]

        for idx, (p1_orig, p2_orig) in enumerate(zip(pyro_param_dict[f'{param} {temp_params[0]}'], pyro_param_dict[f'{param} {temp_params[1]}'])):
            p1, p2 = p1_orig.detach(), p2_orig.detach()
            
            temp_stats_df = pd.DataFrame.from_dict(
                get_stats_CI(
                    params = {
                        temp_params[0]:p1,
                        temp_params[1]:p2,
                    },
                    p=p,
                    dist = param_dist,
                )
            )
            
            temp_stats_df = temp_stats_df.applymap(transforms[param])
        
            if idx == 0:
                param_plot_stats[param] = temp_stats_df
            else:
                param_plot_stats[param] = param_plot_stats[param].append(temp_stats_df, ignore_index = True)
    
    return param_plot_stats

In [9]:
def sign_mult(df1, df2):
    newdf = copy.deepcopy(df2)
    
    for idx, row in df1.iterrows():
        if numpy.sign(row['mean']) < 0:
            newdf.loc[idx,'mean'] = -1*newdf.loc[idx,'mean']
            newdf.loc[idx,'lower'] = -1*newdf.loc[idx,'upper']
            newdf.loc[idx,'upper'] = -1*newdf.loc[idx,'lower']
    
    return newdf

In [10]:
def get_diff_by_set(diffs, item_ids):
    diff_by_set = {}
    id_split = '_'

    max_diff = -1e6
    min_diff = 1e6
    
    for idx, diff in enumerate(diffs):
        set_name = item_ids[idx].split(id_split)[0]

        if set_name in diff_by_set.keys():
            diff_by_set[set_name].append(diff)
        else:
            diff_by_set[set_name] = [diff]
            
        if diff < min_diff:
            min_diff = diff
            
        if diff > max_diff:
            max_diff = diff
    
    return diff_by_set, min_diff, max_diff

# Get Tasks

In [11]:
from variational_irt import *

In [12]:
datasets="boolq,cb,commonsenseqa,copa,cosmosqa,hellaswag,adversarial-nli,rte,snli,wic,qamr,arct,mcscript,mctaco,mutual,mutual-plus,quoref,socialiqa,squad-v2,wsc,mnli,mrqa-nq,newsqa,abductive-nli,arc-easy,arc-challenge,piqa,quail,winogrande,anli"
data_names, responses, n_items = get_files(
    os.path.join(repo, 'data'),
    "csv",
    set(datasets.split(','))
)

In [13]:
task_metadata = pd.read_csv('task_metadata.csv')
task_metadata.set_index("jiant_name", inplace=True)
task_list = [x for x in task_metadata.index if x in data_names]

In [14]:
total = 0
task_name = []
task_format = []

for tname, size in zip(data_names, n_items):
    name = task_metadata.loc[tname]['taskname']
    total += size
    task_name += [name for _ in range(size)]
    task_format += [task_metadata.loc[tname]['format'] for _ in range(size)]
    
task_name = pd.DataFrame(task_name, columns=['task_name'])
task_format = pd.DataFrame(task_format, columns=['format'])
task_name_format = pd.concat([task_name, task_format], axis=1)

In [15]:
len(data_names)

28

# Get Params and Order

In [16]:
exp_dir = os.path.join(repo, 'params', f'alpha-lognormal-identity_theta-normal-identity_nosubsample_1.00_0.30')
p = 0.95

with open(os.path.join(exp_dir, 'responses.p'), 'rb') as f:
    combined_responses = pickle.load(f).reset_index()

In [17]:
# Check accuracy of roberta-large models

extractmodel = 'roberta-large_best'
tie_break = 0

acc_by_dataset = {}

roberta_rp = combined_responses.loc[combined_responses['userid']==extractmodel, :]
if roberta_rp.shape[0] > 1:
    roberta_rp = roberta_rp.iloc[tie_break, :]

cols = combined_responses.columns.values

for item in cols[1:]:
    data_name = '_'.join(item.split('_')[:-1])
    resp = roberta_rp[item].item()
    
    if data_name in acc_by_dataset:
        acc_by_dataset[data_name]['correct'] += resp
        acc_by_dataset[data_name]['total'] += 1
    else:
        acc_by_dataset[data_name] = {'correct': resp, 'total': 1}

print(extractmodel)
print('='*90)
print(f'Overall acc: {roberta_rp.iloc[0, 1:].sum()/(roberta_rp.shape[1]-1):.4f}')        

for data_name, acc_dict in acc_by_dataset.items():
    print(f'{data_name} acc: {acc_dict["correct"]/acc_dict["total"]:.4f}')

roberta-large_best
Overall acc: 0.7403
abductive_nli acc: 0.8564
adversarial_nli acc: 0.4938
arc_challenge acc: 0.3319
arc_easy acc: 0.6296
arct acc: 0.8604
boolq acc: 0.8367
cb acc: 0.8571
commonsenseqa acc: 0.6759
copa acc: 0.8400
cosmosqa acc: 0.7984
hellaswag acc: 0.8417
mcscript acc: 0.9183
mctaco acc: 0.5360
mnli acc: 0.8991
mrqa_natural_questions acc: 0.6941
mutual_plus acc: 0.7314
mutual acc: 0.8668
newsqa acc: 0.5542
piqa acc: 0.7617
qamr acc: 0.7303
quail acc: 0.6691
quoref acc: 0.8023
rte acc: 0.8345
snli acc: 0.9203
socialiqa acc: 0.7738
squad_v2 acc: 0.4326
wic acc: 0.7085
winogrande acc: 0.7697
wsc acc: 0.6154


In [18]:
# set to False if run for the first time
# note that this will take sometimes to run if the datasets are big
load_from_cache = True

In [19]:
# distribution and transformation
alpha_dist = 'log-normal'
alpha_transf = 'standard'
theta_dist = 'normal'
theta_transf = 'standard'

exp_dir = os.path.join(repo, 'params', f'alpha-lognormal-identity_theta-normal-identity_nosubsample_1.00_0.30')
p = 0.95

with open(os.path.join(exp_dir, 'responses.p'), 'rb') as f:
    combined_responses = pickle.load(f).reset_index()
data, accuracies, example_accuracies = get_data_accuracies(combined_responses)
column_names = combined_responses.columns[1:]
select_ts = {
    'standard':lambda x:x,
    'positive':lambda x:torch.log(1+torch.exp(torch.tensor(x))),
    'sigmoid':lambda x:sigmoid(torch.tensor(x)),
}

transforms = {
    'a':select_ts[alpha_transf],
    'b':select_ts['standard'],
    'g':select_ts['sigmoid'],
    't':select_ts[theta_transf],
}

if load_from_cache:
    param_plot_stats = {}

    for key in transforms.keys():
        with open(os.path.join('plot_stats_pickles', f'{key}.p'), 'rb') as f:
            param_plot_stats[key] = pickle.load(f)
else:
    param_plot_stats = get_plot_stats(
        exp_dir,
        alpha_dist,
        theta_dist,
        transforms,
        p = 0.95
    )
    
    os.makedirs('plot_stats_pickles', exist_ok=True)
    for key, value in param_plot_stats.items():
        with open(os.path.join('plot_stats_pickles', f'{key}.p'), 'wb') as f:
            pickle.dump(value, f)

In [20]:
combined_responses

Unnamed: 0,userid,abductive_nli_0,abductive_nli_1,abductive_nli_2,abductive_nli_3,abductive_nli_4,abductive_nli_5,abductive_nli_6,abductive_nli_7,abductive_nli_8,...,wsc_42,wsc_43,wsc_44,wsc_45,wsc_46,wsc_47,wsc_48,wsc_49,wsc_50,wsc_51
0,roberta-base-10M-1_best,1,0,1,1,1,1,1,0,1,...,1,0,1,1,0,1,0,1,0,1
1,roberta-base-10M-1_1,0,0,0,1,1,0,1,0,1,...,1,0,1,0,0,1,0,0,0,1
2,roberta-base-10M-1_25,1,1,1,1,1,1,1,1,1,...,0,0,0,1,0,1,1,1,1,1
3,roberta-base-10M-1_50,1,0,1,1,1,1,1,0,1,...,1,0,0,1,0,1,1,1,0,1
4,roberta-base-10M-1_10,1,0,1,1,1,1,1,0,1,...,1,0,0,1,0,1,1,1,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85,xlm-roberta-large_best,1,1,0,1,1,1,1,1,1,...,1,1,0,1,0,1,1,1,0,0
86,xlm-roberta-large_1,1,1,0,1,1,0,1,0,1,...,0,0,1,0,1,0,0,0,1,1
87,xlm-roberta-large_25,1,1,0,1,1,1,0,1,1,...,1,1,0,1,0,1,1,1,0,0
88,xlm-roberta-large_50,1,1,1,1,1,1,1,1,1,...,1,1,0,1,0,1,1,1,0,0


In [21]:
model_names = []
model_levels = []
for m in combined_responses['userid']:
    mname = m.split('_')[0]
    mlevel = m.split('_')[-1]
    if mname.endswith('-1') or mname.endswith('-2') or mname.endswith('-3'):
        mname = mname[:-2]
    model_names.append(mname)
    
    mlevel_append = '' if mlevel == 'best' else r'%'
    model_levels.append(mlevel+mlevel_append)

In [22]:
#  we will only use log mean for discriminative parameter
for param_key, param_stat in param_plot_stats.items():
    param_stat['log_mean'] = numpy.log(param_stat['mean'])
    print(param_key, param_stat['log_mean'].isnull().sum())

a 0
b 38590
g 0
t 47


In [23]:
param_a = pd.concat([param_plot_stats['a'], task_name_format], axis=1)
param_b = pd.concat([param_plot_stats['b'], task_name_format], axis=1)

task_order = [task_metadata.loc[x]['taskname'] for x in task_list]

In [24]:
leh_scores = icc_best_deriv(
    param_plot_stats['a'],
    param_plot_stats['b'],
    param_plot_stats['t'],
    model_names,
    gamma = param_plot_stats['g'],
)

leh_scores_plot = pd.DataFrame(pd.Series(leh_scores), columns = ['mean'])
print(leh_scores_plot)

Best model: albert-xxlarge-v2
1.6489769220352173
No gamma: 0.14222308635772687
With gamma: 0.11293618947661707
           mean
0      0.170596
1      0.230350
2      0.063763
3      0.014811
4      0.090399
...         ...
82751  0.058886
82752  0.140021
82753  0.090254
82754  0.130726
82755  0.097754

[82756 rows x 1 columns]


In [25]:
leh_scores_plot = pd.concat([leh_scores_plot, task_name_format], axis=1)

In [26]:
leh_75 = leh_scores_plot.groupby(by='task_name').quantile(q=0.75).reset_index()

with open(os.path.join('plot_stats_pickles', f'LEH_75qtile.p'), 'wb') as f:
    pickle.dump(leh_75, f)

In [27]:
task_metadata.set_index("taskname", inplace=True)

# Follow-ups

## % all incorrect/correct

In [28]:
pcent_correct = pd.DataFrame(
    combined_responses.iloc[:,1:].sum(axis=0)/combined_responses.shape[0], columns=['pcent_correct']
)

In [29]:
pcent_correct['all_correct'] = pcent_correct['pcent_correct'].eq(1)
pcent_correct['all_wrong'] = pcent_correct['pcent_correct'].eq(0)
pcent_correct['all_either'] = pcent_correct['all_correct'] + pcent_correct['all_wrong']
pcent_correct['dataset'] = list(pd.Series(pcent_correct.index.values).apply(lambda x: '_'.join(x.split('_')[:-1])))

In [30]:
pcent_correct

Unnamed: 0,pcent_correct,all_correct,all_wrong,all_either,dataset
abductive_nli_0,0.466667,False,False,False,abductive_nli
abductive_nli_1,0.388889,False,False,False,abductive_nli
abductive_nli_2,0.744444,False,False,False,abductive_nli
abductive_nli_3,0.844444,False,False,False,abductive_nli
abductive_nli_4,0.644444,False,False,False,abductive_nli
...,...,...,...,...,...
wsc_47,0.777778,False,False,False,wsc
wsc_48,0.544444,False,False,False,wsc
wsc_49,0.633333,False,False,False,wsc
wsc_50,0.311111,False,False,False,wsc


In [31]:
print('exclude', param_plot_stats['a'].loc[~np.array(pcent_correct['all_either']),:].mean())
print('')
print('only', param_plot_stats['a'].loc[np.array(pcent_correct['all_either']),:].mean())
print('')
print('only correct', param_plot_stats['a'].loc[np.array(pcent_correct['all_correct']),:].mean())
print('')
print('only wrong', param_plot_stats['a'].loc[np.array(pcent_correct['all_wrong']),:].mean())

exclude lower       0.857838
upper       2.089375
mean        1.368802
var         0.108964
log_mean    0.285911
dtype: float64

only lower       1.013472
upper       2.422414
mean        1.601511
var         0.136384
log_mean    0.457826
dtype: float64

only correct lower       0.610682
upper       1.927486
mean        1.128396
var         0.121926
log_mean    0.113834
dtype: float64

only wrong lower       1.077908
upper       2.501590
mean        1.677196
var         0.138696
log_mean    0.512855
dtype: float64


In [32]:
print('exclude', param_plot_stats['g'].loc[~np.array(pcent_correct['all_either']),:].mean())
print('')
print('only', param_plot_stats['g'].loc[np.array(pcent_correct['all_either']),:].mean())
print('')
print('only correct', param_plot_stats['g'].loc[np.array(pcent_correct['all_correct']),:].mean())
print('')
print('only wrong', param_plot_stats['g'].loc[np.array(pcent_correct['all_wrong']),:].mean())

exclude lower       0.175066
upper       0.422752
mean        0.280274
var         0.550505
log_mean   -1.574811
dtype: float64

only lower       0.121289
upper       0.213446
mean        0.160410
var         0.563151
log_mean   -2.880551
dtype: float64

only correct lower       0.765192
upper       0.968812
mean        0.910426
var         0.583914
log_mean   -0.093919
dtype: float64

only wrong lower       0.018283
upper       0.092608
mean        0.040428
var         0.559830
log_mean   -3.326336
dtype: float64


In [33]:
temp = pd.concat([pcent_correct.reset_index(), task_name_format], axis=1)

dataset2format = {'all':'all'}
format2dataset = {'all':'all'}

for dset in temp['dataset'].unique():
    fmat = temp.loc[temp['dataset'] == dset, 'format'].unique()[0]
    dataset2format[dset] = fmat
    format2dataset[fmat] = dset

In [34]:
item_excludes = {}

for dataset in pcent_correct['dataset'].unique():
    temp = pcent_correct.loc[pcent_correct['dataset'] == dataset, :]
    
    try:
        assert np.not_equal(temp.index.values, np.array(sorted(temp.index.values, key=lambda x: int(x.split('_')[-1])))).sum() == 0, dataset
    except:
        print(dataset)
        pass
    
    item_excludes[dataset] = temp['all_either']

adversarial_nli


In [35]:
keeps = ['all_correct', 'all_wrong', 'all_either', 'dataset']
grouped_pcent_correct = pcent_correct[keeps].groupby(by='dataset')

In [36]:
combined_all_pcent = {}
combined_all_count = {}

for col in ['all_correct', 'all_wrong', 'all_either']:
    combined_all_count[col] = pcent_correct[col].sum()
    combined_all_pcent[col] = combined_all_count[col]/combined_responses.shape[1]
    
combined_all_count['total_count'] = combined_responses.shape[1]

In [37]:
grouped_pcent_either = grouped_pcent_correct.sum()/grouped_pcent_correct.count()
summary_pcent = grouped_pcent_either.append(pd.DataFrame.from_dict({'all':combined_all_pcent}, orient='index'))

grouped_count_either = grouped_pcent_correct.sum()
grouped_count_either['total_count'] = grouped_pcent_correct.count()['all_either']
summary_count = grouped_count_either.append(pd.DataFrame.from_dict({'all':combined_all_count}, orient='index'))


summary_pcent['format'] = list(pd.Series(summary_pcent.index.values).apply(lambda x: dataset2format[x]))
summary_count['format'] = list(pd.Series(summary_count.index.values).apply(lambda x: dataset2format[x]))

In [38]:
summary_pcent.loc[summary_pcent['all_either'] > 0, :]

Unnamed: 0,all_correct,all_wrong,all_either,format
adversarial_nli,0.000313,0.011875,0.012188,classification
arc_easy,0.0,0.000421,0.000421,MC-par
boolq,0.095413,0.002446,0.097859,MC-par
cosmosqa,0.0,0.002009,0.002009,MC-par
hellaswag,0.0,0.000398,0.000398,MC-par
mctaco,0.0,0.108108,0.108108,MC-sent
mnli,0.000102,0.000407,0.000509,classification
mrqa_natural_questions,0.0,0.073076,0.073076,span selection
newsqa,0.0,0.161426,0.161426,span selection
qamr,0.0,0.080714,0.080714,span selection


In [39]:
summary_count.loc[summary_count['all_either'] > 0, :]

Unnamed: 0,all_correct,all_wrong,all_either,total_count,format
adversarial_nli,1,38,39,3200,classification
arc_easy,0,1,1,2376,MC-par
boolq,156,4,160,1635,MC-par
cosmosqa,0,3,3,1493,MC-par
hellaswag,0,2,2,5021,MC-par
mctaco,0,144,144,1332,MC-sent
mnli,1,4,5,9824,classification
mrqa_natural_questions,0,469,469,6418,span selection
newsqa,0,693,693,4293,span selection
qamr,0,1515,1515,18770,span selection


In [40]:
summary_by_format = summary_count.groupby(by='format')
summary_by_format = summary_by_format.sum()

In [41]:
summary_by_format.loc[summary_by_format['all_either'] > 0,:]

Unnamed: 0_level_0,all_correct,all_wrong,all_either,total_count
format,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
MC-par,156,10,166,17193
MC-sent,0,161,161,10853
all,470,2938,3408,82757
classification,314,58,372,23015
span selection,0,2709,2709,30690


In [42]:
print('Break Down of All Either')
summary_by_format['all_either']/summary_by_format.loc['all', 'all_either']

Break Down of All Either


format
MC-par            0.048709
MC-sent           0.047242
all               1.000000
classification    0.109155
span selection    0.794894
Name: all_either, dtype: float64

In [43]:
print('Breakdown of All Either by Task Format')
summary_by_format['all_either']/summary_by_format['total_count']

Breakdown of All Either by Task Format


format
MC-par            0.009655
MC-sent           0.014835
all               0.041181
classification    0.016163
span selection    0.088270
dtype: float64

## Trim Responses

In [44]:
data_dir = os.path.join('..','data')
files = [f for f in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, f))]

### Trim Model Responses

In [45]:
trimmed_model_dir = os.path.join('..','data_trimmed_model')
os.makedirs(trimmed_model_dir, exist_ok=True)

In [46]:
exclude_models = ['albert-xxlarge-v2', 'xlm-roberta-large', 'roberta-large', 'roberta-base']

In [47]:
for file in files:
    temp = pd.read_csv(os.path.join(data_dir, file))
    temp_models = temp['userid'].apply(lambda x:''.join(x.split('_')[:-1]))
    
    exclude_indexes = temp_models.eq(exclude_models[0])
    for i in range(1,len(exclude_models)):
        exclude_indexes = exclude_indexes + temp_models.eq(exclude_models[i])
    
    trimmed = temp.loc[~exclude_indexes, :]
    
    check = set(temp.loc[~exclude_indexes, :]['userid'].apply(lambda x:''.join(x.split('_')[:-1])).unique())
    for model in exclude_models:
        assert not model in check, f'{file}, {model}'
    
    trimmed.to_csv(os.path.join(trimmed_model_dir, file), index=False)

### Trim Item Responses

In [48]:
trimmed_item_dir = os.path.join('..','data_trimmed_item')
os.makedirs(trimmed_item_dir, exist_ok=True)

In [49]:
ending = '_irt_all_coded.csv'

for file in files:
    dataset_name = '_'.join(file[:-len(ending)].split('-'))
    if dataset_name == 'mrqa_nq':
        dataset_name = 'mrqa_natural_questions'
    
    if not dataset_name in item_excludes.keys():
        print(dataset_name)  
    
    temp = pd.read_csv(os.path.join(data_dir, file))
    trimmed = temp.loc[:,~np.array([False] + list(item_excludes[dataset_name]))]
    
    trimmed.to_csv(os.path.join(trimmed_item_dir, file), index=False)

# LEH at BERT-large

### From Total

In [50]:
target = 'bert-large-cased_best'

leh_scores_bert = icc_best_deriv(
    param_plot_stats['a'],
    param_plot_stats['b'],
    param_plot_stats['t'],
    list(combined_responses['userid']),
    gamma = param_plot_stats['g'],
    target = target
)

leh_scores_plot_bert = pd.DataFrame(pd.Series(leh_scores_bert), columns = ['mean'])
print(leh_scores_plot_bert)

Local Grad for Target Model
Target model: bert-large-cased_best
0.8104567527770996
No gamma: 0.15963196480051325
With gamma: 0.12148421911600042
           mean
0      0.176043
1      0.208972
2      0.060574
3      0.036004
4      0.092786
...         ...
82751  0.094904
82752  0.182747
82753  0.141831
82754  0.043311
82755  0.055791

[82756 rows x 1 columns]


### From Trimmed BERT

In [51]:
# set to False if run for the first time
# note that this will take sometimes to run if the datasets are big
load_from_cache_trimmed = False

In [52]:
# distribution and transformation
alpha_dist = 'log-normal'
alpha_transf = 'standard'
theta_dist = 'normal'
theta_transf = 'standard'

exp_dir_trimmed = os.path.join(repo, 'params_trimmed_model', f'alpha-lognormal-identity_theta-normal-identity_nosubsample_1.00_0.30')
p = 0.95

with open(os.path.join(exp_dir_trimmed, 'responses.p'), 'rb') as f:
    combined_responses_trimmed = pickle.load(f).reset_index()
data_trimmed, accuracies_trimmed, example_accuracies_trimmed = get_data_accuracies(combined_responses_trimmed)
column_names_trimmed = combined_responses_trimmed.columns[1:]
select_ts = {
    'standard':lambda x:x,
    'positive':lambda x:torch.log(1+torch.exp(torch.tensor(x))),
    'sigmoid':lambda x:sigmoid(torch.tensor(x)),
}

transforms = {
    'a':select_ts[alpha_transf],
    'b':select_ts['standard'],
    'g':select_ts['sigmoid'],
    't':select_ts[theta_transf],
}

if load_from_cache_trimmed:
    param_plot_stats_trimmed = {}

    for key in transforms.keys():
        with open(os.path.join('plot_stats_pickles_trimmed_model', f'{key}.p'), 'rb') as f:
            param_plot_stats_trimmed[key] = pickle.load(f)
else:
    param_plot_stats_trimmed = get_plot_stats(
        exp_dir_trimmed,
        alpha_dist,
        theta_dist,
        transforms,
        p = p
    )
    
    os.makedirs('plot_stats_pickles_trimmed_model', exist_ok=True)
    for key, value in param_plot_stats_trimmed.items():
        with open(os.path.join('plot_stats_pickles_trimmed_model', f'{key}.p'), 'wb') as f:
            pickle.dump(value, f)

In [53]:
target = 'bert-large-cased_best'

leh_scores_bert_trimmed = icc_best_deriv(
    param_plot_stats_trimmed['a'],
    param_plot_stats_trimmed['b'],
    param_plot_stats_trimmed['t'],
    list(combined_responses_trimmed['userid']),
    gamma = param_plot_stats_trimmed['g'],
    target = target
)

leh_scores_plot_bert_trimmed = pd.DataFrame(pd.Series(leh_scores_bert_trimmed), columns = ['mean'])
print(leh_scores_plot_bert_trimmed)

Local Grad for Target Model
Target model: bert-large-cased_best
0.9805275201797485
No gamma: 0.1567958442231057
With gamma: 0.12096604941079794
           mean
0      0.158494
1      0.150325
2      0.039995
3      0.052717
4      0.152572
...         ...
82751  0.084873
82752  0.150056
82753  0.117921
82754  0.098012
82755  0.059681

[82756 rows x 1 columns]


## Differences

In [54]:
leh_bert_combined = pd.concat([leh_scores_plot_bert, leh_scores_plot_bert_trimmed], axis = 1)

In [55]:
diff = (leh_scores_plot_bert['mean'] - leh_scores_plot_bert_trimmed['mean'])

print('mean', diff.mean())
print('std', diff.std())
print('median', diff.median())

IQR = diff.quantile(0.75) - diff.quantile(0.25)
upper_whisker = diff.quantile(0.75) + 1.5*IQR
lower_whisker = diff.quantile(0.25) - 1.5*IQR

print('IQR', IQR)
print('upper_whisker', upper_whisker)
print('lower_whisker', lower_whisker)

print('pcent above lower', (diff > upper_whisker).sum()/diff.shape[0])
print('pcent below lower', (diff < lower_whisker).sum()/diff.shape[0])

mean 0.0005181697052024826
std 0.059380400067841184
median 0.0020297530195709183
IQR 0.049340527744277426
upper_whisker 0.09977841616189373
lower_whisker -0.09758369481521596
pcent above lower 0.04359804727149693
pcent below lower 0.03825704480641887


In [56]:
fig, ax = plot.subplots()
sns.boxplot(diff)
ax.set_xlabel(r'$\Delta$ LEH')
ax.set_title('LEH at BERT: Difference')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'LEH at BERT: Difference')

In [57]:
fig, ax = plot.subplots()
sns.boxplot(diff, showfliers=False)
ax.set_xlabel(r'$\Delta$ LEH')
ax.set_title('LEH at BERT: Difference')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'LEH at BERT: Difference')

In [58]:
rel_diff = (leh_scores_plot_bert['mean'] - leh_scores_plot_bert_trimmed['mean'])/leh_scores_plot_bert_trimmed['mean']

print('mean', rel_diff.mean())
print('std', rel_diff.std())
print('median', rel_diff.median())

IQR = rel_diff.quantile(0.75) - rel_diff.quantile(0.25)
upper_whisker = rel_diff.quantile(0.75) + 1.5*IQR
lower_whisker = rel_diff.quantile(0.25) - 1.5*IQR

print('IQR', IQR)
print('upper_whisker', upper_whisker)
print('lower_whisker', lower_whisker)

print('pcent above lower', (rel_diff > upper_whisker).sum()/rel_diff.shape[0])
print('pcent below lower', (rel_diff < lower_whisker).sum()/rel_diff.shape[0])

mean 0.07837129239456639
std 0.5712089759066041
median 0.025084227491001372
IQR 0.5199853736242595
upper_whisker 1.0685743083792318
lower_whisker -1.0113671861178057
pcent above lower 0.049615737831698
pcent below lower 0.0


In [59]:
fig, ax = plot.subplots()
sns.boxplot(rel_diff)
ax.set_xlabel(r'$\Delta$ LEH / LEH$_{BERT}$')
ax.set_title('LEH at BERT: Relative Difference')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'LEH at BERT: Relative Difference')

In [60]:
fig, ax = plot.subplots()
sns.boxplot(rel_diff, showfliers=False)
ax.set_xlabel(r'$\Delta$ LEH / LEH$_{BERT}$')
ax.set_title('LEH at BERT: Relative Difference')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'LEH at BERT: Relative Difference')

## Boxplots

### Methods

In [61]:
font_label = 12
font_legend = 10
font_legendtitle = font_legend + 4
font_xtick = 12
font_title = 14
marker_scale = 1.5

plot.rc('axes', labelsize=font_label)
plot.rc('axes', titlesize=font_title)
plot.rc('legend', fontsize=font_legend)

In [62]:
def plot_box(df, order, param_type, task_metadata, xsize=12, ysize=3, width=0.6, rotation=90, ylim=None, ystep=None):
    
    param2label = {
        'discriminative':r'$\log$ Discrimination ($\log$ $\alpha$)',
        'difficulty': r'Difficulty ($\beta$)',
        "disc-diff": "IRT Score", # "discrimination - difficulty product\n", # + r"($ n(\alpha) x n(\beta$) )"
        "disc-diff_pos": "IRT Score", # "discrimination - difficulty product\n", # + r"($ n(\alpha) x n(\beta$) )"
        "disc-diff_minmax": "IRT Score", # "discrimination - difficulty product\n", # + r"($ n(\alpha) x n(\beta$) )"
        "irt-score": "LEH Score",
    }
    
    param2yname = {
        'discriminative': "log_mean",
        'difficulty': "mean",
        "disc-diff": 0,
        "disc-diff_pos": 0,
        "disc-diff_minmax": 0,
        "irt-score": "mean",
    }
    
    sns.set_style("whitegrid")
    sns.set_context("paper")
    f, ax = plot.subplots(figsize=(xsize, ysize))
    
    my_pal = {"MC-par": "r",
              "MC-sent": "b",
              "classification":"g",
              "span selection": "grey"}    
    
    ax = sns.boxplot(x="task_name", y=param2yname[param_type], data=df, order=order, width=width)
    
    for i, task in enumerate(order):
        # Select which box you want to change    
        mybox = ax.artists[i]
        
        # Change the appearance of that box
        skill = task_metadata.loc[task]['format']
        mybox.set_facecolor(my_pal[skill])
    
    # Add transparency to colors
    for patch in ax.artists:
         r, g, b, a = patch.get_facecolor()
         patch.set_facecolor((r, g, b, .6))

    sns.despine()
    plot.xticks(range(len(order)), order, rotation=rotation, fontsize=font_xtick)
    
    if not ylim is None and not ystep is None:
        plot.ylim(ylim)
        plot.yticks(numpy.arange(ylim[0], ylim[1]+ystep, ystep))
    
    plot.xlabel(None)
    plot.ylabel(param2label[param_type], fontsize=font_label)
    
    return f
#     plot.savefig('../plots/' + param_type + "_box.png",
#                 format='png', dpi=300,
#                 bbox_inches = 'tight',
#                 pad_inches = .1)

### Plots

In [63]:
leh_bert_combined.corr()

Unnamed: 0,mean,mean.1
mean,1.0,0.739129
mean,0.739129,1.0


In [64]:
leh_scores_plot_bert = pd.concat([leh_scores_plot_bert, task_name_format], axis=1)
leh_scores_plot_bert_trimmed = pd.concat([leh_scores_plot_bert_trimmed, task_name_format], axis=1)

In [65]:
leh_scores_plot_bert

Unnamed: 0,mean,task_name,format
0,0.176043,AbductNLI,MC-sent
1,0.208972,AbductNLI,MC-sent
2,0.060574,AbductNLI,MC-sent
3,0.036004,AbductNLI,MC-sent
4,0.092786,AbductNLI,MC-sent
...,...,...,...
82751,0.094904,,
82752,0.182747,,
82753,0.141831,,
82754,0.043311,,


In [66]:
f = plot_box(
    leh_scores_plot_bert,
    task_order,
    "irt-score",
    task_metadata
)

f.suptitle('All Models')
f.savefig(os.path.join('..', 'plots_LEH_BERT', 'all_models.png'),
         format='png', dpi=300,
                bbox_inches = 'tight',
                pad_inches = .1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [67]:
ftrimmed = plot_box(
    leh_scores_plot_bert_trimmed,
    task_order,
    "irt-score",
    task_metadata
)

ftrimmed.suptitle('MiniBERTas + BERT')
ftrimmed.savefig(os.path.join('..', 'plots_LEH_BERT', 'trimmed_models.png'),
         format='png', dpi=300,
                bbox_inches = 'tight',
                pad_inches = .1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [68]:
leh_tasks = pd.concat([
    leh_scores_plot_bert.rename(columns={'mean':'Full'}),
    pd.DataFrame(leh_scores_plot_bert_trimmed['mean']).rename(columns={'mean':'BERT'})
], axis=1)

In [69]:
leh_tasks_qtile = leh_tasks.groupby(by='task_name').quantile(q=0.75)
leh_tasks_qtile['diff'] = leh_tasks_qtile['Full'] - leh_tasks_qtile['BERT']
leh_tasks_qtile['rel_diff'] = (leh_tasks_qtile['Full'] - leh_tasks_qtile['BERT'])/leh_tasks_qtile['BERT']

print('diff', leh_tasks_qtile['diff'].median())
print('rel diff', leh_tasks_qtile['rel_diff'].median())

diff 0.004662516241732889
rel diff 0.027159661919406792


In [70]:
print('median diff (magnitude)', leh_tasks_qtile['diff'].abs().median())
print('median rel diff (magnitude)', leh_tasks_qtile['rel_diff'].abs().median())

median diff (magnitude) 0.010413355061485552
median rel diff (magnitude) 0.06581982948185636


In [71]:
leh_tasks_qtile.rename(columns={'BERT':'Up to BERT', 'diff':'Difference', 'rel_diff':'Relative Difference'}).sort_values(by='Relative Difference')

Unnamed: 0_level_0,Full,Up to BERT,Difference,Relative Difference
task_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
WiC,0.007902,0.231662,-0.22376,-0.965891
Winogrande,0.007903,0.221559,-0.213656,-0.964329
WSC,0.008322,0.223147,-0.214825,-0.962706
ARC-C,0.165677,0.183855,-0.018178,-0.098869
CB,0.098872,0.105602,-0.006729,-0.063725
SNLI,0.080718,0.086122,-0.005405,-0.062757
ANLI,0.156625,0.162685,-0.006059,-0.037246
PiQA,0.144227,0.146788,-0.002562,-0.017451
COPA,0.140743,0.141083,-0.00034,-0.002408
MC-TACO,0.177377,0.177229,0.000148,0.000838


In [72]:
leh_tasks_qtile.loc[leh_tasks_qtile['diff'] < 0, :]

Unnamed: 0_level_0,Full,BERT,diff,rel_diff
task_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ANLI,0.156625,0.162685,-0.006059,-0.037246
ARC-C,0.165677,0.183855,-0.018178,-0.098869
CB,0.098872,0.105602,-0.006729,-0.063725
COPA,0.140743,0.141083,-0.00034,-0.002408
PiQA,0.144227,0.146788,-0.002562,-0.017451
SNLI,0.080718,0.086122,-0.005405,-0.062757
WSC,0.008322,0.223147,-0.214825,-0.962706
WiC,0.007902,0.231662,-0.22376,-0.965891
Winogrande,0.007903,0.221559,-0.213656,-0.964329
