# Import

In [1]:
%matplotlib widget

In [2]:
import os
import pickle5 as pickle
import copy

import pandas as pd
import seaborn as sns
import numpy
import torch
import scipy
import scipy.stats

import pyro
import pyro.infer
import pyro.infer.mcmc
import pyro.distributions as dist
import torch.distributions.constraints as constraints
from tqdm.auto import tqdm

import matplotlib.pyplot as plot
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import warnings
warnings.filterwarnings('ignore')
sns.set(style="whitegrid")

In [3]:
repo = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))

# Define

In [4]:
def sigmoid(x):
    return 1./(1.+torch.exp(-x))

def icc_best_deriv(alpha, beta, theta, model_names, gamma=None, col='mean'):
    '''
    Method to calculate the locally estimated headroom (LEH) score, defined as
    the derivative of the item characteristic curve w.r.t. the best performing model.
    
    Args:
        alpha:       DataFrame of discrimination parameter statistics for each item.
        beta:        DataFrame of difficulty parameter statistics for each item.
        theta:       DataFrame of ability parameter statistics for each responder.
        model_names: List of responder names.
        gamma:       DataFrame of guessing parameter statistics for each item.
        col:         DataFrame column name to use for calculating LEH scores.
    
    Returns:
        scores:      LEH scores for each item.    
    '''
    best_idx, best_value = theta[col].argmax(), theta[col].max()
    print(f'Best model: {model_names[best_idx]}\n{best_value}')
    
    a, b = torch.tensor(alpha[col].values), torch.tensor(beta[col].values)
    
    logits = (a*(best_value-b))
    sigmoids = sigmoid(logits)
    scores = sigmoids*(1.-sigmoids)*a
    
    print(f'No gamma: {scores.mean()}')
    if not gamma is None:
        g = torch.tensor(gamma[col].apply(lambda x: x.item()).values)
        scores = (1.-g)*scores
        print(f'With gamma: {scores.mean()}')
    
    return scores      
    
    

In [5]:
def get_model_guide(alpha_dist, theta_dist, alpha_transform, theta_transform):
    model = lambda obs: irt_model(obs, alpha_dist, theta_dist, alpha_transform = alpha_transform, theta_transform = theta_transform)
    guide = lambda obs: vi_posterior(obs, alpha_dist, theta_dist)
    
    return model, guide

In [6]:
def get_data_accuracies(data, verbose = False, get_cols = False):
    '''
    Method to reformat `data` and calculate item and responder accuracies.
    
    Args:
        data:                DataFrame of item responses.
        verbose:             Boolean value of whether to print statements.
        get_cols:            Boolean value of whether to return original column
                             values of `data`.
        
    Returns:
        new_data:            Reformatted `data`, dropping first column.
        accuracies:          Accuracy for each responder across examples.
        example_accuracies:  Accuracy for each example across responders.
        data.columns.values: Returns only if `get_cols` is True. Original column
                             values of `data`.
    '''
    new_data = numpy.array(data)
    new_data = new_data[:,1:]
    
    model_names = dict(data['userid'])
    accuracies = new_data.mean(-1)
    example_accuracies = new_data.mean(0)
    
    if verbose:
        print('\n'.join([f'{name}: {acc}' for name, acc in zip(model_names.values(),accuracies)]))
    
    if get_cols:
        return new_data, accuracies, example_accuracies, data.columns.values
    else:
        return new_data, accuracies, example_accuracies

In [7]:
def get_stats_CI(params, p=0.95, dist='normal'):
    '''
    Method to calculate lower and upper quantiles defined by `p`, mean, and variance of `param`
    
    Args:
        params: Dictionary of distribution parameters for each item keyed according to the 
                parametric distribution defined by `dist`.
        p:      Percent of distribution covered by the lower and upper interval values for each
                parameter.
        dist:   Name of parametric distribution
    
    Returns:
        return: {
            'lower': Lower interval values of each parameter,
            'upper': Upper interval values of each parameter,
            'mean' : Mean of each parameter,
            'var'  : Variance of each parameter
        }
    '''
    stats = {}
    if dist == 'normal':
        L,U = scipy.stats.norm.interval(p,loc=params['mu'], scale=torch.exp(params['logstd']))
        M,V = scipy.stats.norm.stats(loc=params['mu'], scale=torch.exp(params['logstd']))
    elif dist == 'log-normal':
        L,U = scipy.stats.lognorm.interval(p, s=torch.exp(params['logstd']), scale=torch.exp(params['mu']))
        M,V = scipy.stats.lognorm.stats(s=torch.exp(params['logstd']), scale=torch.exp(params['mu']))
    elif dist == 'beta':
        L,U = scipy.stats.beta.interval(p,a=params['alpha'], b=params['beta'])
        M,V = scipy.stats.beta.stats(a=params['alpha'], b=params['beta'])
    else:
        raise TypeError(f'Distribution type {dist} not supported.')
    
    return {
        'lower':[L],
        'upper':[U],
        'mean':[M],
        'var':[V],
    }

In [8]:
def get_plot_stats(exp_dir, alpha_dist, theta_dist, transforms, p = 0.95):
    '''
    Method to return plotting statistics for 3 parameter IRT model parameters.
    
    Args:
        exp_dir:          Path to 3 parameter IRT parameters and responses.
        alpha_dist:       Name of the item discrimination [a] distribution.
        theta_dist:       Name of the responder ability [t] distribution.
        transforms:       Dictionary of transformations to apply to each parameter type
                          where keys are parameter names and values are functions.
        p:                Percent of distribution covered by the lower and upper interval 
                          values for each parameter.
    
    Returns:
        param_plot_stats: Dictionary of parameter plot statistics where keys are parameter
                          names and values are plot statistics dictionaries as defined by
                          get_stats_CI().
    '''
    param_dists = {
        'a':alpha_dist,
        'b':'normal',
        'g':'normal',
        't':theta_dist,
    }

    dist_params = {
        'normal':['mu', 'logstd'],
        'log-normal':['mu', 'logstd'],
        'beta':['alpha', 'beta'],
    }

    pyro.clear_param_store()
    pyro.get_param_store().load(os.path.join(exp_dir, 'params.p'))

    with torch.no_grad():
        pyro_param_dict = dict(pyro.get_param_store().named_parameters())
    
    # get stats for plotting
    param_plot_stats = {}

    for param, param_dist in param_dists.items():
        temp_params = dist_params[param_dist]

        for idx, (p1_orig, p2_orig) in enumerate(zip(pyro_param_dict[f'{param} {temp_params[0]}'], pyro_param_dict[f'{param} {temp_params[1]}'])):
            p1, p2 = p1_orig.detach(), p2_orig.detach()
            
            temp_stats_df = pd.DataFrame.from_dict(
                get_stats_CI(
                    params = {
                        temp_params[0]:p1,
                        temp_params[1]:p2,
                    },
                    p=p,
                    dist = param_dist,
                )
            )
            
            temp_stats_df = temp_stats_df.applymap(transforms[param])
        
            if idx == 0:
                param_plot_stats[param] = temp_stats_df
            else:
                param_plot_stats[param] = param_plot_stats[param].append(temp_stats_df, ignore_index = True)
    
    return param_plot_stats

In [9]:
def sign_mult(df1, df2):
    newdf = copy.deepcopy(df2)
    
    for idx, row in df1.iterrows():
        if numpy.sign(row['mean']) < 0:
            newdf.loc[idx,'mean'] = -1*newdf.loc[idx,'mean']
            newdf.loc[idx,'lower'] = -1*newdf.loc[idx,'upper']
            newdf.loc[idx,'upper'] = -1*newdf.loc[idx,'lower']
    
    return newdf

In [10]:
def get_diff_by_set(diffs, item_ids):
    diff_by_set = {}
    id_split = '_'

    max_diff = -1e6
    min_diff = 1e6
    
    for idx, diff in enumerate(diffs):
        set_name = item_ids[idx].split(id_split)[0]

        if set_name in diff_by_set.keys():
            diff_by_set[set_name].append(diff)
        else:
            diff_by_set[set_name] = [diff]
            
        if diff < min_diff:
            min_diff = diff
            
        if diff > max_diff:
            max_diff = diff
    
    return diff_by_set, min_diff, max_diff

# Load Trimmed

## Get Tasks

In [11]:
from variational_irt import *

In [12]:
datasets="boolq,cb,commonsenseqa,copa,cosmosqa,hellaswag,adversarial-nli,rte,snli,wic,qamr,arct,mcscript,mctaco,mutual,mutual-plus,quoref,socialiqa,squad-v2,wsc,mnli,mrqa-nq,newsqa,abductive-nli,arc-easy,arc-challenge,piqa,quail,winogrande,anli"
data_names, responses, n_items = get_files(
    os.path.join(repo, 'data_trimmed_item'),
    "csv",
    set(datasets.split(','))
)

In [13]:
task_metadata = pd.read_csv('task_metadata.csv')
task_metadata.set_index("jiant_name", inplace=True)
task_list = [x for x in task_metadata.index if x in data_names]

In [14]:
total = 0
task_name = []
task_format = []

for tname, size in zip(data_names, n_items):
    name = task_metadata.loc[tname]['taskname']
    total += size
    task_name += [name for _ in range(size)]
    task_format += [task_metadata.loc[tname]['format'] for _ in range(size)]
    
task_name = pd.DataFrame(task_name, columns=['task_name'])
task_format = pd.DataFrame(task_format, columns=['format'])
task_name_format_trimmed = pd.concat([task_name, task_format], axis=1)

In [15]:
len(data_names)

29

## Get Params and Order

In [16]:
exp_dir = os.path.join(repo, 'params_trimmed_item', f'alpha-lognormal-identity_theta-normal-identity_nosubsample_1.00_0.30')
p = 0.95

with open(os.path.join(exp_dir, 'responses.p'), 'rb') as f:
    combined_responses = pickle.load(f).reset_index()

In [17]:
# Check accuracy of roberta-large models

extractmodel = 'roberta-large_best'
tie_break = 0

acc_by_dataset = {}

roberta_rp = combined_responses.loc[combined_responses['userid']==extractmodel, :]
if roberta_rp.shape[0] > 1:
    roberta_rp = roberta_rp.iloc[tie_break, :]

cols = combined_responses.columns.values

for item in cols[1:]:
    data_name = '_'.join(item.split('_')[:-1])
    resp = roberta_rp[item].item()
    
    if data_name in acc_by_dataset:
        acc_by_dataset[data_name]['correct'] += resp
        acc_by_dataset[data_name]['total'] += 1
    else:
        acc_by_dataset[data_name] = {'correct': resp, 'total': 1}

print(extractmodel)
print('='*90)
print(f'Overall acc: {roberta_rp.iloc[0, 1:].sum()/(roberta_rp.shape[1]-1):.4f}')        

for data_name, acc_dict in acc_by_dataset.items():
    print(f'{data_name} acc: {acc_dict["correct"]/acc_dict["total"]:.4f}')

roberta-large_best
Overall acc: 0.7692
abductive_nli acc: 0.8564
adversarial_nli acc: 0.4995
arc_challenge acc: 0.3319
arc_easy acc: 0.6299
arct acc: 0.8604
boolq acc: 0.8217
cb acc: 0.8571
commonsenseqa acc: 0.6759
copa acc: 0.8400
cosmosqa acc: 0.8000
hellaswag acc: 0.8420
mcscript acc: 0.9183
mctaco acc: 0.6010
mnli acc: 0.8995
mrqa_natural_questions acc: 0.7489
mutual_plus acc: 0.7314
mutual acc: 0.8668
newsqa acc: 0.6608
piqa acc: 0.7617
qamr acc: 0.7944
quail acc: 0.6691
quoref acc: 0.8241
rte acc: 0.8345
snli acc: 0.9192
socialiqa acc: 0.7738
squad_v2 acc: 0.4395
wic acc: 0.7085
winogrande acc: 0.7697
wsc acc: 0.6154


In [18]:
# set to False if run for the first time
# note that this will take sometimes to run if the datasets are big
load_from_cache = True

In [19]:
# distribution and transformation
alpha_dist = 'log-normal'
alpha_transf = 'standard'
theta_dist = 'normal'
theta_transf = 'standard'

exp_dir = os.path.join(repo, 'params_trimmed_item', f'alpha-lognormal-identity_theta-normal-identity_nosubsample_1.00_0.30')
p = 0.95

with open(os.path.join(exp_dir, 'responses.p'), 'rb') as f:
    combined_responses = pickle.load(f).reset_index()
data, accuracies, example_accuracies = get_data_accuracies(combined_responses)
column_names = combined_responses.columns[1:]
select_ts = {
    'standard':lambda x:x,
    'positive':lambda x:torch.log(1+torch.exp(torch.tensor(x))),
    'sigmoid':lambda x:sigmoid(torch.tensor(x)),
}

transforms = {
    'a':select_ts[alpha_transf],
    'b':select_ts['standard'],
    'g':select_ts['sigmoid'],
    't':select_ts[theta_transf],
}

if load_from_cache:
    param_plot_stats_trimmed = {}

    for key in transforms.keys():
        with open(os.path.join('plot_stats_pickles_trimmed_item', f'{key}.p'), 'rb') as f:
            param_plot_stats_trimmed[key] = pickle.load(f)
else:
    param_plot_stats = get_plot_stats(
        exp_dir,
        alpha_dist,
        theta_dist,
        transforms,
        p = 0.95
    )
    
    os.makedirs('plot_stats_pickles_trimmed_item', exist_ok=True)
    for key, value in param_plot_stats.items():
        with open(os.path.join('plot_stats_pickles_trimmed_item', f'{key}.p'), 'wb') as f:
            pickle.dump(value, f)

In [20]:
combined_responses

Unnamed: 0,userid,abductive_nli_0,abductive_nli_1,abductive_nli_2,abductive_nli_3,abductive_nli_4,abductive_nli_5,abductive_nli_6,abductive_nli_7,abductive_nli_8,...,wsc_42,wsc_43,wsc_44,wsc_45,wsc_46,wsc_47,wsc_48,wsc_49,wsc_50,wsc_51
0,roberta-base-10M-1_best,1,0,1,1,1,1,1,0,1,...,1,0,1,1,0,1,0,1,0,1
1,roberta-base-10M-1_1,0,0,0,1,1,0,1,0,1,...,1,0,1,0,0,1,0,0,0,1
2,roberta-base-10M-1_25,1,1,1,1,1,1,1,1,1,...,0,0,0,1,0,1,1,1,1,1
3,roberta-base-10M-1_50,1,0,1,1,1,1,1,0,1,...,1,0,0,1,0,1,1,1,0,1
4,roberta-base-10M-1_10,1,0,1,1,1,1,1,0,1,...,1,0,0,1,0,1,1,1,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85,xlm-roberta-large_best,1,1,0,1,1,1,1,1,1,...,1,1,0,1,0,1,1,1,0,0
86,xlm-roberta-large_1,1,1,0,1,1,0,1,0,1,...,0,0,1,0,1,0,0,0,1,1
87,xlm-roberta-large_25,1,1,0,1,1,1,0,1,1,...,1,1,0,1,0,1,1,1,0,0
88,xlm-roberta-large_50,1,1,1,1,1,1,1,1,1,...,1,1,0,1,0,1,1,1,0,0


In [21]:
model_names = []
model_levels = []
for m in combined_responses['userid']:
    mname = m.split('_')[0]
    mlevel = m.split('_')[-1]
    if mname.endswith('-1') or mname.endswith('-2') or mname.endswith('-3'):
        mname = mname[:-2]
    model_names.append(mname)
    
    mlevel_append = '' if mlevel == 'best' else r'%'
    model_levels.append(mlevel+mlevel_append)

In [22]:
#  we will only use log mean for discriminative parameter
for param_key, param_stat in param_plot_stats_trimmed.items():
    param_stat['log_mean'] = numpy.log(param_stat['mean'])
    print(param_key, param_stat['log_mean'].isnull().sum())

a 0
b 35469
g 0
t 46


In [23]:
param_a = pd.concat([param_plot_stats_trimmed['a'], task_name_format_trimmed], axis=1)
param_b = pd.concat([param_plot_stats_trimmed['b'], task_name_format_trimmed], axis=1)

task_order = [task_metadata.loc[x]['taskname'] for x in task_list]

In [24]:
leh_scores = icc_best_deriv(
    param_plot_stats_trimmed['a'],
    param_plot_stats_trimmed['b'],
    param_plot_stats_trimmed['t'],
    model_names,
    gamma = param_plot_stats_trimmed['g'],
)

leh_scores_plot = pd.DataFrame(pd.Series(leh_scores), columns = ['mean'])
print(leh_scores_plot)

Best model: albert-xxlarge-v2
1.7979713678359985
No gamma: 0.14499040325361917
With gamma: 0.11655480116768933
           mean
0      0.166683
1      0.221677
2      0.087377
3      0.014670
4      0.098858
...         ...
78716  0.048649
78717  0.106998
78718  0.084177
78719  0.188978
78720  0.081421

[78721 rows x 1 columns]


In [25]:
leh_scores_plot_trimmed = pd.concat([leh_scores_plot, task_name_format_trimmed], axis=1)

In [26]:
task_metadata.set_index("taskname", inplace=True)

# Load Full

## Get Tasks

In [27]:
from variational_irt import *

In [28]:
datasets="boolq,cb,commonsenseqa,copa,cosmosqa,hellaswag,adversarial-nli,rte,snli,wic,qamr,arct,mcscript,mctaco,mutual,mutual-plus,quoref,socialiqa,squad-v2,wsc,mnli,mrqa-nq,newsqa,abductive-nli,arc-easy,arc-challenge,piqa,quail,winogrande,anli"
data_names, responses, n_items = get_files(
    os.path.join(repo, 'data'),
    "csv",
    set(datasets.split(','))
)

In [29]:
task_metadata = pd.read_csv('task_metadata.csv')
task_metadata.set_index("jiant_name", inplace=True)
task_list = [x for x in task_metadata.index if x in data_names]

In [30]:
total = 0
task_name = []
task_format = []

for tname, size in zip(data_names, n_items):
    name = task_metadata.loc[tname]['taskname']
    total += size
    task_name += [name for _ in range(size)]
    task_format += [task_metadata.loc[tname]['format'] for _ in range(size)]
    
task_name = pd.DataFrame(task_name, columns=['task_name'])
task_format = pd.DataFrame(task_format, columns=['format'])
task_name_format = pd.concat([task_name, task_format], axis=1)

In [31]:
len(data_names)

29

## Get Params and Order

In [32]:
exp_dir = os.path.join(repo, 'params_trimmed_item', f'alpha-lognormal-identity_theta-normal-identity_nosubsample_1.00_0.30')
p = 0.95

with open(os.path.join(exp_dir, 'responses.p'), 'rb') as f:
    combined_responses = pickle.load(f).reset_index()

In [33]:
# Check accuracy of roberta-large models

extractmodel = 'roberta-large_best'
tie_break = 0

acc_by_dataset = {}

roberta_rp = combined_responses.loc[combined_responses['userid']==extractmodel, :]
if roberta_rp.shape[0] > 1:
    roberta_rp = roberta_rp.iloc[tie_break, :]

cols = combined_responses.columns.values

for item in cols[1:]:
    data_name = '_'.join(item.split('_')[:-1])
    resp = roberta_rp[item].item()
    
    if data_name in acc_by_dataset:
        acc_by_dataset[data_name]['correct'] += resp
        acc_by_dataset[data_name]['total'] += 1
    else:
        acc_by_dataset[data_name] = {'correct': resp, 'total': 1}

print(extractmodel)
print('='*90)
print(f'Overall acc: {roberta_rp.iloc[0, 1:].sum()/(roberta_rp.shape[1]-1):.4f}')        

for data_name, acc_dict in acc_by_dataset.items():
    print(f'{data_name} acc: {acc_dict["correct"]/acc_dict["total"]:.4f}')

roberta-large_best
Overall acc: 0.7692
abductive_nli acc: 0.8564
adversarial_nli acc: 0.4995
arc_challenge acc: 0.3319
arc_easy acc: 0.6299
arct acc: 0.8604
boolq acc: 0.8217
cb acc: 0.8571
commonsenseqa acc: 0.6759
copa acc: 0.8400
cosmosqa acc: 0.8000
hellaswag acc: 0.8420
mcscript acc: 0.9183
mctaco acc: 0.6010
mnli acc: 0.8995
mrqa_natural_questions acc: 0.7489
mutual_plus acc: 0.7314
mutual acc: 0.8668
newsqa acc: 0.6608
piqa acc: 0.7617
qamr acc: 0.7944
quail acc: 0.6691
quoref acc: 0.8241
rte acc: 0.8345
snli acc: 0.9192
socialiqa acc: 0.7738
squad_v2 acc: 0.4395
wic acc: 0.7085
winogrande acc: 0.7697
wsc acc: 0.6154


In [34]:
# set to False if run for the first time
# note that this will take sometimes to run if the datasets are big
load_from_cache = True

In [35]:
# distribution and transformation
alpha_dist = 'log-normal'
alpha_transf = 'standard'
theta_dist = 'normal'
theta_transf = 'standard'

exp_dir = os.path.join(repo, 'params', f'alpha-lognormal-identity_theta-normal-identity_nosubsample_1.00_0.30')
p = 0.95

with open(os.path.join(exp_dir, 'responses.p'), 'rb') as f:
    combined_responses = pickle.load(f).reset_index()
data, accuracies, example_accuracies = get_data_accuracies(combined_responses)
column_names = combined_responses.columns[1:]
select_ts = {
    'standard':lambda x:x,
    'positive':lambda x:torch.log(1+torch.exp(torch.tensor(x))),
    'sigmoid':lambda x:sigmoid(torch.tensor(x)),
}

transforms = {
    'a':select_ts[alpha_transf],
    'b':select_ts['standard'],
    'g':select_ts['sigmoid'],
    't':select_ts[theta_transf],
}

if load_from_cache:
    param_plot_stats = {}

    for key in transforms.keys():
        with open(os.path.join('plot_stats_pickles', f'{key}.p'), 'rb') as f:
            param_plot_stats[key] = pickle.load(f)
else:
    param_plot_stats = get_plot_stats(
        exp_dir,
        alpha_dist,
        theta_dist,
        transforms,
        p = 0.95
    )
    
    os.makedirs('plot_stats_pickles', exist_ok=True)
    for key, value in param_plot_stats.items():
        with open(os.path.join('plot_stats_pickles', f'{key}.p'), 'wb') as f:
            pickle.dump(value, f)

In [36]:
combined_responses

Unnamed: 0,userid,abductive_nli_0,abductive_nli_1,abductive_nli_2,abductive_nli_3,abductive_nli_4,abductive_nli_5,abductive_nli_6,abductive_nli_7,abductive_nli_8,...,wsc_42,wsc_43,wsc_44,wsc_45,wsc_46,wsc_47,wsc_48,wsc_49,wsc_50,wsc_51
0,roberta-base-10M-1_best,1,0,1,1,1,1,1,0,1,...,1,0,1,1,0,1,0,1,0,1
1,roberta-base-10M-1_1,0,0,0,1,1,0,1,0,1,...,1,0,1,0,0,1,0,0,0,1
2,roberta-base-10M-1_25,1,1,1,1,1,1,1,1,1,...,0,0,0,1,0,1,1,1,1,1
3,roberta-base-10M-1_50,1,0,1,1,1,1,1,0,1,...,1,0,0,1,0,1,1,1,0,1
4,roberta-base-10M-1_10,1,0,1,1,1,1,1,0,1,...,1,0,0,1,0,1,1,1,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85,xlm-roberta-large_best,1,1,0,1,1,1,1,1,1,...,1,1,0,1,0,1,1,1,0,0
86,xlm-roberta-large_1,1,1,0,1,1,0,1,0,1,...,0,0,1,0,1,0,0,0,1,1
87,xlm-roberta-large_25,1,1,0,1,1,1,0,1,1,...,1,1,0,1,0,1,1,1,0,0
88,xlm-roberta-large_50,1,1,1,1,1,1,1,1,1,...,1,1,0,1,0,1,1,1,0,0


In [37]:
model_names = []
model_levels = []
for m in combined_responses['userid']:
    mname = m.split('_')[0]
    mlevel = m.split('_')[-1]
    if mname.endswith('-1') or mname.endswith('-2') or mname.endswith('-3'):
        mname = mname[:-2]
    model_names.append(mname)
    
    mlevel_append = '' if mlevel == 'best' else r'%'
    model_levels.append(mlevel+mlevel_append)

In [38]:
#  we will only use log mean for discriminative parameter
for param_key, param_stat in param_plot_stats.items():
    param_stat['log_mean'] = numpy.log(param_stat['mean'])
    print(param_key, param_stat['log_mean'].isnull().sum())

a 0
b 36860
g 0
t 49


In [39]:
param_a = pd.concat([param_plot_stats['a'], task_name_format], axis=1)
param_b = pd.concat([param_plot_stats['b'], task_name_format], axis=1)

task_order = [task_metadata.loc[x]['taskname'] for x in task_list]

In [40]:
leh_scores = icc_best_deriv(
    param_plot_stats['a'],
    param_plot_stats['b'],
    param_plot_stats['t'],
    model_names,
    gamma = param_plot_stats['g'],
)

leh_scores_plot = pd.DataFrame(pd.Series(leh_scores), columns = ['mean'])
print(leh_scores_plot)

Best model: albert-xxlarge-v2
1.582378625869751
No gamma: 0.14226206233809996
With gamma: 0.11365353401378117
           mean
0      0.193985
1      0.178925
2      0.057262
3      0.014374
4      0.105861
...         ...
82228  0.059020
82229  0.123432
82230  0.083337
82231  0.107942
82232  0.082005

[82233 rows x 1 columns]


In [41]:
leh_scores_plot = pd.concat([leh_scores_plot, task_name_format], axis=1)

In [42]:
task_metadata.set_index("taskname", inplace=True)

# Compare

## LEH

In [53]:
trimmed = leh_scores_plot_trimmed.groupby(by='task_name').quantile(q=0.75).rename(columns={'mean':'Trimmed'})
trimmed

Unnamed: 0_level_0,Trimmed
task_name,Unnamed: 1_level_1
ANLI,0.21752
ARC-C,0.236766
ARC-E,0.22507
ARCT,0.146903
AbductNLI,0.183463
BoolQ,0.139165
CB,0.091976
COPA,0.183826
CSQA,0.238214
CosmosQA,0.208724


In [54]:
full = leh_scores_plot.groupby(by='task_name').quantile(q=0.75).rename(columns={'mean':'Full'})
full

Unnamed: 0_level_0,Full
task_name,Unnamed: 1_level_1
ANLI,0.198573
ARC-C,0.215639
ARC-E,0.214382
ARCT,0.145347
AbductNLI,0.174404
BoolQ,0.122382
CB,0.108013
COPA,0.17037
CSQA,0.231926
CosmosQA,0.199577


In [55]:
combined = pd.concat([trimmed, full], axis=1)
combined['diff'] = combined['Trimmed'] - combined['Full']
combined['rel_diff'] = (combined['Trimmed'] - combined['Full'])/combined['Full']

print('diff', combined['diff'].median())
print('rel_diff', combined['rel_diff'].median())

diff 0.009146576121313699
rel_diff 0.04985242828403232


In [64]:
combined.loc[:,['Trimmed','Full']].corr(method='spearman')

Unnamed: 0,Trimmed,Full
Trimmed,1.0,0.985714
Full,0.985714,1.0


## Discr

In [57]:
print(param_plot_stats_trimmed['a'].shape)

(78721, 5)


In [58]:
print(task_name_format_trimmed.shape)

(78721, 2)


In [59]:
trimmed = pd.concat([param_plot_stats_trimmed['a'], task_name_format_trimmed], axis=1) 
trimmed = trimmed.groupby(by='task_name').quantile(q=0.75).rename(columns={'log_mean':'Trimmed'})
trimmed

Unnamed: 0_level_0,lower,upper,mean,var,Trimmed
task_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
ANLI,0.910985,2.254991,1.458205,0.128729,0.377206
ARC-C,0.829711,2.111609,1.348144,0.116188,0.298729
ARC-E,0.832761,2.090126,1.337743,0.112663,0.290984
ARCT,0.797911,2.125889,1.332243,0.123776,0.286864
AbductNLI,0.831509,2.223219,1.385403,0.136846,0.325991
BoolQ,0.833742,2.14684,1.338929,0.125501,0.29187
CB,0.891985,2.186326,1.360933,0.116663,0.308167
COPA,0.691181,2.033327,1.211465,0.131899,0.191828
CSQA,0.851922,2.168921,1.400281,0.123441,0.336673
CosmosQA,0.873289,2.182176,1.411812,0.119116,0.344874


In [60]:
full = pd.concat([param_plot_stats['a'], task_name_format], axis=1)
full = full.groupby(by='task_name').quantile(q=0.75).rename(columns={'log_mean':'Full'})
full

Unnamed: 0_level_0,lower,upper,mean,var,Full
task_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
ANLI,0.896719,2.221928,1.439807,0.126375,0.364509
ARC-C,0.817265,2.082838,1.323434,0.112265,0.28023
ARC-E,0.824985,2.064168,1.317038,0.109276,0.275386
ARCT,0.786051,2.102383,1.291031,0.122855,0.255441
AbductNLI,0.823848,2.216965,1.391689,0.135476,0.330518
BoolQ,0.801109,2.106907,1.306342,0.123751,0.267231
CB,0.954628,2.059321,1.427322,0.101397,0.355759
COPA,0.704257,1.885803,1.176215,0.104034,0.162301
CSQA,0.84866,2.09339,1.343194,0.111521,0.29505
CosmosQA,0.85838,2.143313,1.384106,0.121368,0.325055


In [61]:
combined_a = pd.concat([trimmed, full], axis=1)
combined_a['diff'] = combined_a['Trimmed'] - combined_a['Full']
combined_a['rel_diff'] = (combined_a['Trimmed'] - combined_a['Full'])/combined_a['Full']

print('diff', combined_a['diff'].median())
print('rel_diff', combined_a['rel_diff'].median())

diff 0.01320694910671591
rel_diff 0.04832417383542685


In [62]:
combined_a

Unnamed: 0_level_0,lower,upper,mean,var,Trimmed,lower,upper,mean,var,Full,diff,rel_diff
task_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
ANLI,0.910985,2.254991,1.458205,0.128729,0.377206,0.896719,2.221928,1.439807,0.126375,0.364509,0.012698,0.034835
ARC-C,0.829711,2.111609,1.348144,0.116188,0.298729,0.817265,2.082838,1.323434,0.112265,0.28023,0.018499,0.066014
ARC-E,0.832761,2.090126,1.337743,0.112663,0.290984,0.824985,2.064168,1.317038,0.109276,0.275386,0.015598,0.056642
ARCT,0.797911,2.125889,1.332243,0.123776,0.286864,0.786051,2.102383,1.291031,0.122855,0.255441,0.031423,0.123015
AbductNLI,0.831509,2.223219,1.385403,0.136846,0.325991,0.823848,2.216965,1.391689,0.135476,0.330518,-0.004527,-0.013696
BoolQ,0.833742,2.14684,1.338929,0.125501,0.29187,0.801109,2.106907,1.306342,0.123751,0.267231,0.024639,0.0922
CB,0.891985,2.186326,1.360933,0.116663,0.308167,0.954628,2.059321,1.427322,0.101397,0.355759,-0.047592,-0.133777
COPA,0.691181,2.033327,1.211465,0.131899,0.191828,0.704257,1.885803,1.176215,0.104034,0.162301,0.029526,0.181922
CSQA,0.851922,2.168921,1.400281,0.123441,0.336673,0.84866,2.09339,1.343194,0.111521,0.29505,0.041622,0.141069
CosmosQA,0.873289,2.182176,1.411812,0.119116,0.344874,0.85838,2.143313,1.384106,0.121368,0.325055,0.019819,0.060972


In [63]:
combined_a.corr(method='spearman')

Unnamed: 0,lower,upper,mean,var,Trimmed,lower.1,upper.1,mean.1,var.1,Full,diff,rel_diff
lower,1.0,0.950246,0.963547,0.594581,0.963547,0.982266,0.871921,0.966995,0.591626,0.966995,-0.283251,-0.518719
upper,0.950246,1.0,0.969458,0.753695,0.969458,0.936946,0.929064,0.970443,0.731527,0.970443,-0.302463,-0.537438
mean,0.963547,0.969458,1.0,0.67734,1.0,0.95665,0.93202,0.966995,0.687192,0.966995,-0.255665,-0.490148
var,0.594581,0.753695,0.67734,1.0,0.67734,0.575862,0.762562,0.642857,0.801478,0.642857,-0.204433,-0.34532
Trimmed,0.963547,0.969458,1.0,0.67734,1.0,0.95665,0.93202,0.966995,0.687192,0.966995,-0.255665,-0.490148
lower,0.982266,0.936946,0.95665,0.575862,0.95665,1.0,0.863054,0.979803,0.578818,0.979803,-0.376355,-0.589655
upper,0.871921,0.929064,0.93202,0.762562,0.93202,0.863054,1.0,0.906404,0.86798,0.906404,-0.333005,-0.538916
mean,0.966995,0.970443,0.966995,0.642857,0.966995,0.979803,0.906404,1.0,0.661084,1.0,-0.41133,-0.629064
var,0.591626,0.731527,0.687192,0.801478,0.687192,0.578818,0.86798,0.661084,1.0,0.661084,-0.346798,-0.474384
Full,0.966995,0.970443,0.966995,0.642857,0.966995,0.979803,0.906404,1.0,0.661084,1.0,-0.41133,-0.629064
