# Import

In [1]:
%matplotlib widget

In [2]:
import os
import pickle
import copy

import pandas as pd
import seaborn as sns
import numpy
import torch
import scipy
import scipy.stats

import pyro
import pyro.infer
import pyro.infer.mcmc
import pyro.distributions as dist
import torch.distributions.constraints as constraints
from tqdm.auto import tqdm

from sklearn.decomposition import FastICA

import matplotlib.pyplot as plot
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import warnings
warnings.filterwarnings('ignore')
sns.set(style="whitegrid")

In [3]:
# Notebooks have no __file__, so the repository root is taken to be the
# parent of the current working directory.
repo = os.path.dirname(os.getcwd())

# Define

In [4]:
def sigmoid(x):
    return 1./(1.+torch.exp(-x))

def icc_best_deriv(alpha, beta, theta, model_names, gamma=None, col='mean'):
    '''
    Method to calculate the locally estimated headroom (LEH) score, defined as
    the derivative of the item characteristic curve w.r.t. the best performing model.
    
    NOTE(review): `icc_best_deriv` is defined again further down in this notebook;
    once that cell runs, this version is shadowed.
    
    Args:
        alpha:       DataFrame of discrimination parameter statistics for each item.
        beta:        DataFrame of difficulty parameter statistics for each item.
        theta:       DataFrame of ability parameter statistics for each responder.
        model_names: List of responder names.
        gamma:       DataFrame of guessing parameter statistics for each item.
        col:         DataFrame column name to use for calculating LEH scores.
    
    Returns:
        scores:      LEH scores for each item.    
    '''
    # The "best" responder is the one with the largest `col` ability value.
    # NOTE(review): if theta[col] holds one array per responder (multi-dimensional
    # ability), argmax/max compare arrays and the selection is ill-defined — confirm.
    best_idx, best_value = theta[col].argmax(), theta[col].max()
    print(f'Best model: {model_names[best_idx]}\n{best_value}')
    
    # NOTE(review): torch.tensor on `.values` needs a numeric array; if each cell is
    # itself an array (object dtype) this conversion likely fails — confirm input format.
    a, b = torch.tensor(alpha[col].values), torch.tensor(beta[col].values)
    
    #logits = (a*(best_value-b))
    # Slope-intercept form a·theta + b: b is added here, not subtracted as in the
    # commented-out line above. NOTE(review): confirm which form the fitted model uses.
    logits = (torch.matmul(a,best_value.T) + b).T
    sigmoids = sigmoid(logits)
    # Derivative of sigmoid(logit) w.r.t. ability: sigmoid * (1 - sigmoid) * a.
    scores = sigmoids*(1.-sigmoids)*a
    
    print(f'No gamma: {scores.mean()}')
    if not gamma is None:
        # Scale by (1 - g); presumably the ICC has the form g + (1 - g) * sigmoid(...),
        # whose derivative carries this factor.
        g = torch.tensor(gamma[col].apply(lambda x: x.item()).values)
        scores = (1.-g)*scores
        print(f'With gamma: {scores.mean()}')
    
    return scores      
    
    

In [5]:
def get_model_guide(alpha_dist, theta_dist, alpha_transform, theta_transform):
    '''
    Build the (model, guide) pair for one parameterisation.

    Both returned callables take only the observed responses and forward them,
    together with the settings captured here, to `irt_model` / `vi_posterior`
    (defined outside this notebook cell).
    '''
    def model(obs):
        return irt_model(
            obs,
            alpha_dist,
            theta_dist,
            alpha_transform=alpha_transform,
            theta_transform=theta_transform,
        )

    def guide(obs):
        return vi_posterior(obs, alpha_dist, theta_dist)

    return model, guide

In [6]:
def get_data_accuracies(data, verbose = False, get_cols = False):
    '''
    Method to reformat `data` and calculate item and responder accuracies.
    
    Args:
        data:                DataFrame of item responses. The first column is the
                             responder id ('userid'); every other column is one item.
        verbose:             Boolean value of whether to print each responder's accuracy.
        get_cols:            Boolean value of whether to return original column
                             values of `data`.
        
    Returns:
        new_data:            Response matrix without the first column, as a float array
                             of shape (n_responders, n_items).
        accuracies:          Accuracy for each responder across examples.
        example_accuracies:  Accuracy for each example across responders.
        data.columns.values: Returns only if `get_cols` is True. Original column
                             values of `data`.
    '''
    # Drop the first (id) column positionally and convert to float, so the means
    # below are numeric arrays instead of object-dtype arrays built from a mix of
    # strings and ints.
    new_data = data.iloc[:, 1:].to_numpy(dtype=float)
    
    accuracies = new_data.mean(-1)
    example_accuracies = new_data.mean(0)
    
    if verbose:
        # Iterate the column directly: going through dict(Series) collapsed duplicate
        # index labels and misaligned names with accuracies.
        print('\n'.join(f'{name}: {acc}' for name, acc in zip(data['userid'], accuracies)))
    
    if get_cols:
        return new_data, accuracies, example_accuracies, data.columns.values
    return new_data, accuracies, example_accuracies

In [7]:
def get_stats_CI(params, p=0.95, dist='normal'):
    '''
    Method to calculate lower and upper quantiles defined by `p`, mean, and variance of `params`.
    
    Args:
        params: Dictionary of distribution parameters keyed according to `dist`:
                'normal' / 'log-normal' -> 'mu' and 'logstd' (log of the standard
                deviation; for 'log-normal' they describe log(X));
                'beta' -> 'alpha' and 'beta'.
                Values may be tensors or plain numbers.
        p:      Probability mass covered by the central [lower, upper] interval.
        dist:   Name of the parametric distribution: 'normal', 'log-normal' or 'beta'.
    
    Returns:
        Dict with keys 'lower', 'upper', 'mean' and 'var'. Each value is a one-element
        list holding the scipy result (a scalar, or an array for vector-valued
        parameters), so pd.DataFrame.from_dict builds a single row from it.

    Raises:
        TypeError: if `dist` is not a supported distribution name.
    '''
    if dist == 'normal':
        loc = params['mu']
        scale = torch.exp(torch.as_tensor(params['logstd']))
        L, U = scipy.stats.norm.interval(p, loc=loc, scale=scale)
        M, V = scipy.stats.norm.stats(loc=loc, scale=scale)
    elif dist == 'log-normal':
        # scipy's lognorm: shape s = std of log(X), scale = exp(mean of log(X)).
        s = torch.exp(torch.as_tensor(params['logstd']))
        scale = torch.exp(torch.as_tensor(params['mu']))
        L, U = scipy.stats.lognorm.interval(p, s=s, scale=scale)
        M, V = scipy.stats.lognorm.stats(s=s, scale=scale)
    elif dist == 'beta':
        a, b = params['alpha'], params['beta']
        L, U = scipy.stats.beta.interval(p, a=a, b=b)
        M, V = scipy.stats.beta.stats(a=a, b=b)
    else:
        raise TypeError(f'Distribution type {dist} not supported.')
    
    return {
        'lower':[L],
        'upper':[U],
        'mean':[M],
        'var':[V],
    }

In [8]:
def get_plot_stats(exp_dir, alpha_dist, theta_dist, transforms, p = 0.95):
    '''
    Method to return plotting statistics for 3 parameter IRT model parameters.
    
    Args:
        exp_dir:          Path to 3 parameter IRT parameters and responses
                          (reads `<exp_dir>/params.p` into the pyro param store).
        alpha_dist:       Name of the item discrimination [a] distribution.
        theta_dist:       Name of the responder ability [t] distribution.
        transforms:       Dictionary of transformations to apply to each parameter type
                          where keys are parameter names and values are functions.
        p:                Percent of distribution covered by the lower and upper interval 
                          values for each parameter.
    
    Returns:
        param_plot_stats: Dictionary of parameter plot statistics where keys are parameter
                          names and values are DataFrames with one row per item/responder
                          and the columns produced by get_stats_CI().
    '''
    param_dists = {
        'a':alpha_dist,
        'b':'normal',
        'g':'normal',
        't':theta_dist,
    }

    dist_params = {
        'normal':['mu', 'logstd'],
        'log-normal':['mu', 'logstd'],
        'beta':['alpha', 'beta'],
    }

    pyro.clear_param_store()
    pyro.get_param_store().load(os.path.join(exp_dir, 'params.p'))

    with torch.no_grad():
        pyro_param_dict = dict(pyro.get_param_store().named_parameters())
    
    param_plot_stats = {}

    for param, param_dist in param_dists.items():
        first_name, second_name = dist_params[param_dist]
        first_values = pyro_param_dict[f'{param} {first_name}']
        second_values = pyro_param_dict[f'{param} {second_name}']

        # One single-row stats frame per entry along the first axis of the tensors.
        frames = []
        for p1_orig, p2_orig in zip(first_values, second_values):
            stats = get_stats_CI(
                params = {
                    first_name: p1_orig.detach(),
                    second_name: p2_orig.detach(),
                },
                p=p,
                dist = param_dist,
            )
            frames.append(_map_elementwise(pd.DataFrame.from_dict(stats), transforms[param]))

        # Concatenate once. Growing the frame with DataFrame.append copied it for every
        # row (quadratic in the number of items), and that method was removed in
        # pandas 2.0. Parameters with no entries stay absent from the result.
        if frames:
            param_plot_stats[param] = pd.concat(frames, ignore_index=True)
    
    return param_plot_stats


def _map_elementwise(df, fn):
    '''Apply `fn` to every cell of `df` (DataFrame.map on pandas >= 2.1, applymap before).'''
    if hasattr(pd.DataFrame, 'map'):
        return df.map(fn)
    return df.applymap(fn)

In [9]:
def sign_mult(df1, df2):
    '''
    Flip the rows of `df2` whose matching row (same index label) in `df1` has a negative mean.

    For a flipped row, 'mean' is negated and the interval is negated and swapped
    (new lower = -old upper, new upper = -old lower), so lower <= upper still holds.

    Args:
        df1: DataFrame whose 'mean' column decides, per row label, whether to flip.
        df2: DataFrame with 'mean', 'lower' and 'upper' columns to transform.

    Returns:
        A deep copy of `df2` with the affected rows flipped; `df2` itself is unchanged.
    '''
    newdf = copy.deepcopy(df2)
    
    for idx, row in df1.iterrows():
        if numpy.sign(row['mean']) < 0:
            # Read both bounds before writing either. Writing 'lower' first and then
            # reading it back for 'upper' left 'upper' unchanged.
            old_lower = newdf.loc[idx, 'lower']
            old_upper = newdf.loc[idx, 'upper']
            newdf.loc[idx, 'mean'] = -1*newdf.loc[idx, 'mean']
            newdf.loc[idx, 'lower'] = -1*old_upper
            newdf.loc[idx, 'upper'] = -1*old_lower
    
    return newdf

In [10]:
def get_diff_by_set(diffs, item_ids):
    '''
    Group per-item values by dataset and track their overall range.

    An item id has the form '<dataset>_<index>', where the dataset name may itself
    contain underscores (e.g. 'arc_easy_12' -> 'arc_easy'). Only the last '_'-separated
    segment is stripped, so datasets such as 'arc_easy' / 'arc_challenge' or
    'mutual' / 'mutual_plus' are not merged.

    Args:
        diffs:    Sequence of per-item values (e.g. difficulties), aligned with `item_ids`.
        item_ids: Indexable sequence of item id strings.

    Returns:
        diff_by_set: Dict mapping dataset name -> list of values, in encounter order.
        min_diff:    Smallest value seen (1e6 if `diffs` is empty).
        max_diff:    Largest value seen (-1e6 if `diffs` is empty).
    '''
    diff_by_set = {}
    id_split = '_'

    max_diff = -1e6
    min_diff = 1e6
    
    for idx, diff in enumerate(diffs):
        set_name = item_ids[idx].rsplit(id_split, 1)[0]
        diff_by_set.setdefault(set_name, []).append(diff)
            
        if diff < min_diff:
            min_diff = diff
            
        if diff > max_diff:
            max_diff = diff
    
    return diff_by_set, min_diff, max_diff

# Get Tasks

In [11]:
from multi_virt_v2 import *

In [12]:
datasets="boolq,cb,commonsenseqa,copa,cosmosqa,hellaswag,rte,snli,wic,qamr,arct,mcscript,mctaco,mutual,mutual-plus,quoref,socialiqa,squad-v2,wsc,mnli,mrqa-nq,newsqa,abductive-nli,arc-easy,arc-challenge,piqa,quail,winogrande,anli"
data_names, responses, n_items = get_files(
    os.path.join(repo, 'data'),
    "csv",
    set(datasets.split(','))
)

In [13]:
task_metadata = pd.read_csv('../irt_scripts/task_metadata.csv')
task_metadata.set_index("jiant_name", inplace=True)
task_list = [x for x in task_metadata.index if x in data_names]

In [14]:
# Expand the per-task metadata into one row per item, following the task order
# of `data_names` / `n_items`, so it can be concatenated column-wise with
# per-item parameter tables.
total = 0
task_name = []
task_format = []

for tname, size in zip(data_names, n_items):
    meta = task_metadata.loc[tname]
    name = meta['taskname']
    total += size
    task_name.extend([name] * size)
    task_format.extend([meta['format']] * size)

task_name = pd.DataFrame(task_name, columns=['task_name'])
task_format = pd.DataFrame(task_format, columns=['format'])
task_name_format = pd.concat([task_name, task_format], axis=1)

In [15]:
len(data_names)

28

# Get Params and Order

In [16]:
file_name='alpha-lognormal-identity-dim3_theta-normal-identity_nosubsample_1.00_0.20'
exp_dir = os.path.join(repo, 'params_mvirt', file_name)
p = 0.95

combined_responses = pd.read_pickle(os.path.join(exp_dir, 'responses.p')).reset_index()

In [17]:
# Check accuracy of roberta-large models
extractmodel = 'roberta-large_best'
tie_break = 0

acc_by_dataset = {}

roberta_rp = combined_responses.loc[combined_responses['userid']==extractmodel, :]
if roberta_rp.shape[0] == 0:
    raise ValueError(f'No responses found for userid {extractmodel!r}')
if roberta_rp.shape[0] > 1:
    # Use a list indexer so a one-row DataFrame remains. A scalar indexer returns a
    # Series, which breaks the DataFrame-style indexing below.
    roberta_rp = roberta_rp.iloc[[tie_break], :]

cols = combined_responses.columns.values

# Item columns are '<dataset>_<index>'; accumulate correct/total per dataset.
for item in cols[1:]:
    data_name = '_'.join(item.split('_')[:-1])
    resp = roberta_rp[item].item()
    
    if data_name in acc_by_dataset:
        acc_by_dataset[data_name]['correct'] += resp
        acc_by_dataset[data_name]['total'] += 1
    else:
        acc_by_dataset[data_name] = {'correct': resp, 'total': 1}

print(extractmodel)
print('='*90)
print(f'Overall acc: {roberta_rp.iloc[0, 1:].sum()/(roberta_rp.shape[1]-1):.4f}')        

for data_name, acc_dict in acc_by_dataset.items():
    print(f'{data_name} acc: {acc_dict["correct"]/acc_dict["total"]:.4f}')

roberta-large_best
Overall acc: 0.7521
abductive_nli acc: 0.8564
arc_challenge acc: 0.3319
arc_easy acc: 0.6296
arct acc: 0.8604
boolq acc: 0.8367
cb acc: 0.8571
commonsenseqa acc: 0.6759
copa acc: 0.8400
cosmosqa acc: 0.7984
hellaswag acc: 0.8417
mcscript acc: 0.9183
mctaco acc: 0.5360
mnli acc: 0.8991
mrqa_natural_questions acc: 0.6941
mutual_plus acc: 0.7314
mutual acc: 0.8668
newsqa acc: 0.5542
piqa acc: 0.7617
qamr acc: 0.7303
quail acc: 0.6691
quoref acc: 0.8023
rte acc: 0.8345
snli acc: 0.9203
socialiqa acc: 0.7738
squad_v2 acc: 0.4301
wic acc: 0.7085
winogrande acc: 0.7697
wsc acc: 0.6154


In [18]:
# Set to False the first time this notebook is run, when no cached plot stats exist yet.
# Note that recomputing the stats can take some time if the datasets are big.
load_from_cache = True

In [19]:
# distribution and transformation
# Distribution families of the fitted discrimination (alpha) and ability (theta)
# parameters; 'standard' selects the identity transform in `select_ts` below.
alpha_dist = 'log-normal'
alpha_transf = 'standard'
theta_dist = 'normal'
theta_transf = 'standard'

exp_dir = os.path.join(repo, 'params_mvirt', file_name)

# (Re)load the responses and derive the response matrix plus per-responder and
# per-item accuracies.
combined_responses = pd.read_pickle(os.path.join(exp_dir, 'responses.p')).reset_index()
data, accuracies, example_accuracies = get_data_accuracies(combined_responses)
column_names = combined_responses.columns[1:]
# Element-wise transforms applied to every plotting statistic of a parameter.
# NOTE(review): 'positive' computes softplus as log(1 + exp(x)), which overflows to
# inf for large x; torch.nn.functional.softplus is the numerically stable form.
select_ts = {
    'standard':lambda x:x,
    'positive':lambda x:torch.log(1+torch.exp(torch.tensor(x))),
    'sigmoid':lambda x:sigmoid(torch.tensor(x)),
}

# Transform per parameter key: a = discrimination, b = difficulty,
# g = guessing (mapped through a sigmoid), t = ability.
transforms = {
    'a':select_ts[alpha_transf],
    'b':select_ts['standard'],
    'g':select_ts['sigmoid'],
    't':select_ts[theta_transf],
}

# Load the cached per-parameter stats DataFrames, or compute them from the saved
# pyro param store and write one pickle per parameter key for next time.
if load_from_cache:
    param_plot_stats = {}

    for key in transforms.keys():
        with open(os.path.join(exp_dir, 'plot_stats_pickles', f'{key}.p'), 'rb') as f:
            param_plot_stats[key] = pickle.load(f)
else:
    # NOTE(review): passes p = 0.95 literally rather than the `p` defined earlier.
    param_plot_stats = get_plot_stats(
        exp_dir,
        alpha_dist,
        theta_dist,
        transforms,
        p = 0.95
    )
    
    os.makedirs(os.path.join(exp_dir, 'plot_stats_pickles'), exist_ok=True)
    for key, value in param_plot_stats.items():
        with open(os.path.join(exp_dir, 'plot_stats_pickles', f'{key}.p'), 'wb') as f:
            pickle.dump(value, f)

In [20]:
combined_responses

Unnamed: 0,userid,abductive_nli_0,abductive_nli_1,abductive_nli_2,abductive_nli_3,abductive_nli_4,abductive_nli_5,abductive_nli_6,abductive_nli_7,abductive_nli_8,...,wsc_42,wsc_43,wsc_44,wsc_45,wsc_46,wsc_47,wsc_48,wsc_49,wsc_50,wsc_51
0,roberta-base-10M-1_best,1,0,1,1,1,1,1,0,1,...,1,0,1,1,0,1,0,1,0,1
1,roberta-base-10M-1_1,0,0,0,1,1,0,1,0,1,...,1,0,1,0,0,1,0,0,0,1
2,roberta-base-10M-1_25,1,1,1,1,1,1,1,1,1,...,0,0,0,1,0,1,1,1,1,1
3,roberta-base-10M-1_50,1,0,1,1,1,1,1,0,1,...,1,0,0,1,0,1,1,1,0,1
4,roberta-base-10M-1_10,1,0,1,1,1,1,1,0,1,...,1,0,0,1,0,1,1,1,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85,xlm-roberta-large_best,1,1,0,1,1,1,1,1,1,...,1,1,0,1,0,1,1,1,0,0
86,xlm-roberta-large_1,1,1,0,1,1,0,1,0,1,...,0,0,1,0,1,0,0,0,1,1
87,xlm-roberta-large_25,1,1,0,1,1,1,0,1,1,...,1,1,0,1,0,1,1,1,0,0
88,xlm-roberta-large_50,1,1,1,1,1,1,1,1,1,...,1,1,0,1,0,1,1,1,0,0


In [21]:
# userids look like '<model>[-<run>]_<checkpoint>' (e.g. 'roberta-base-10M-1_25').
# Split each into a model name (run suffix dropped) and a checkpoint label.
model_names = []
model_levels = []
for m in combined_responses['userid']:
    parts = m.split('_')
    mname, mlevel = parts[0], parts[-1]
    # Runs of the same model share one name.
    if mname.endswith(('-1', '-2', '-3')):
        mname = mname[:-2]
    model_names.append(mname)
    
    # Every checkpoint other than 'best' is a percentage.
    model_levels.append(mlevel if mlevel == 'best' else mlevel + '%')

In [30]:
# Add a `log_mean` column (element-wise log of `mean`) to every parameter's stats.
# The plots below use log-means only for discrimination ('a'). For parameters whose
# means can be non-positive (e.g. 't', see its table below) the log is NaN there.
for param_key, param_stat in param_plot_stats.items():
    param_stat['log_mean'] = param_stat['mean'].apply(lambda x: numpy.log(x))
    
    # Multi-dimensional parameters: split each per-row array into scalar columns
    # log_mean_<k> and mean_<k>. The dimension count is taken from the first row.
    if param_key in ['a', 't', 'b']:
        dimensions =  param_stat['log_mean'][0].shape[0]
        for dim in range(dimensions):
            param_stat['log_mean_'+str(dim)]= param_stat['log_mean'].apply(lambda x: x[dim])
            param_stat['mean_'+str(dim)]= param_stat['mean'].apply(lambda x: x[dim])
        
    # NOTE(review): when cells hold arrays, isnull() tests the cell object rather than
    # its elements, so NaNs inside the arrays are not counted (the 't' table has NaN
    # log-means while this prints 0).
    print(param_key, param_stat['log_mean'].isnull().sum())

a 0
b 0
g 0
t 0


In [31]:
param_a = pd.concat([param_plot_stats['a'], task_name_format], axis=1)
param_b = pd.concat([param_plot_stats['b'], task_name_format], axis=1)

task_order = [task_metadata.loc[x]['taskname'] for x in task_list]

In [32]:
param_plot_stats['a']['log_mean'][0]

array([-0.00951589,  0.03023954,  0.07342558])

In [34]:
param_plot_stats['t']

Unnamed: 0,lower,upper,mean,var,log_mean,log_mean_0,mean_0,log_mean_1,mean_1,log_mean_2,mean_2
0,"[-1.442148606546365, -1.7268539335048176, -1.0...","[1.1482866809444057, 1.6715730841672398, 0.715...","[-0.14693096280097961, -0.02764042466878891, -...","[0.43670615553855896, 0.7516224384307861, 0.19...","[nan, nan, nan]",,-0.146931,,-0.027640,,-0.144212
1,"[-2.357598511022056, -2.5945949808610234, -2.5...","[0.38045987252565694, 0.3165095106614382, 0.20...","[-0.9885693192481995, -1.1390427350997925, -1....","[0.48789823055267334, 0.5515176653862, 0.49179...","[nan, nan, nan]",,-0.988569,,-1.139043,,-1.167357
2,"[-0.9174365772970607, -1.242486372299513, -1.7...","[1.192494906402438, 0.9288059023097364, 1.2459...","[0.1375291645526886, -0.1568402349948883, -0.2...","[0.28972136974334717, 0.3068176805973053, 0.58...","[-1.9839192785290387, nan, nan]",-1.983919,0.137529,,-0.156840,,-0.254278
3,"[-1.2110584852718678, -1.6518333638367848, -1....","[0.983906793358263, 0.7339577758965687, 0.6595...","[-0.11357584595680237, -0.45893779397010803, -...","[0.31354445219039917, 0.37043213844299316, 0.3...","[nan, nan, nan]",,-0.113576,,-0.458938,,-0.417190
4,"[-2.1191745897778382, -1.5956859827241952, -1....","[1.4682521125325072, 1.1102622091970498, 0.992...","[-0.3254612386226654, -0.2427118867635727, -0....","[0.8375483751296997, 0.4765218198299408, 0.564...","[nan, nan, nan]",,-0.325461,,-0.242712,,-0.480240
...,...,...,...,...,...,...,...,...,...,...,...
85,"[-1.1783743309990695, -1.5428807687355954, -1....","[2.117602925302201, 2.2496449899270017, 1.7660...","[0.46961429715156555, 0.3533821105957031, 0.37...","[0.7069883346557617, 0.936053991317749, 0.5021...","[-0.7558435655493219, -1.0402053412162369, -0....",-0.755844,0.469614,-1.040205,0.353382,-0.974983,0.377199
86,"[-2.6400108241973657, -2.2896078091174914, -2....","[0.30422162055356705, 0.5832053881198715, 0.91...","[-1.1678946018218994, -0.8532012104988098, -0....","[0.5641414523124695, 0.5371042490005493, 0.929...","[nan, nan, nan]",,-1.167895,,-0.853201,,-0.978130
87,"[-1.6181951851560215, -1.7749836801764172, -1....","[1.4899627656652072, 1.6650538056370419, 1.265...","[-0.0641162097454071, -0.05496493726968765, 0....","[0.6287094354629517, 0.7701408863067627, 0.381...","[nan, nan, -2.9092914732792736]",,-0.064116,,-0.054965,-2.909291,0.054514
88,"[-0.9851609115381976, -1.1955144114032166, -1....","[1.1744053755780002, 1.8110531277194397, 1.344...","[0.09462223201990128, 0.3077693581581116, 0.12...","[0.30351272225379944, 0.5882822871208191, 0.38...","[-2.357862819753815, -1.1784046136761395, -2.0...",-2.357863,0.094622,-1.178405,0.307769,-2.091735,0.123473


In [35]:
param_plot_stats['b']['mean'][0]

array([0.92876601, 0.85530853, 0.95257145])

In [36]:
def icc_best_deriv(alpha, beta, theta, model_names, gamma=None, col='mean'):
    '''
    Method to calculate the locally estimated headroom (LEH) score, defined as
    the derivative of the item characteristic curve w.r.t. the best performing model.
    
    This redefinition replaces the earlier `icc_best_deriv` in this notebook. It uses
    the a * (theta - b) logit form.
    
    Args:
        alpha:       DataFrame of discrimination parameter statistics for each item.
        beta:        DataFrame of difficulty parameter statistics for each item.
        theta:       DataFrame of ability parameter statistics for each responder.
        model_names: List of responder names.
        gamma:       DataFrame of guessing parameter statistics for each item.
        col:         DataFrame column name to use for calculating LEH scores.
    
    Returns:
        scores:      LEH scores for each item.    
    '''
    # The "best" responder is the one with the largest `col` ability value.
    # NOTE(review): with one array per responder (multi-dimensional ability),
    # argmax/max compare arrays and the selection is ill-defined — confirm.
    best_idx, best_value = theta[col].argmax(), theta[col].max()
    print(f'Best model: {model_names[best_idx]}\n{best_value}')
    
    # NOTE(review): torch.tensor is given a pandas Series here (not `.values`); if the
    # cells are numpy arrays this conversion likely fails — confirm input format.
    a, b = torch.tensor(alpha[col]), torch.tensor(beta[col])
    
    logits = (a*(best_value-b))
    sigmoids = sigmoid(logits)
    # sigmoid * (1 - sigmoid) * a is the derivative of sigmoid(a * (theta - b)) w.r.t. theta.
    # NOTE(review): if `logits` is (n_items, n_dims), which happens when a, b and best_value
    # are all per-dimension, `.unsqueeze(1)` makes it (n_items, 1, n_dims) and the
    # product with `a` broadcasts to (n_items, n_items, n_dims). The shapes only fit if
    # the logits were summed over dimensions first; check against the model definition.
    scores = (sigmoids*(1.-sigmoids)).unsqueeze(1) * a
    
    print(f'No gamma: {scores.mean()}')
    if not gamma is None:
        # Scale by (1 - g), one value per item, broadcast across score columns;
        # presumably from an ICC of the form g + (1 - g) * sigmoid(...).
        g = torch.tensor(gamma[col].apply(lambda x: x.item()).values)
        scores = (1.-g).unsqueeze(1) *scores
        print(f'With gamma: {scores.mean()}')
    
    return scores      

In [37]:
'''
leh_scores = icc_best_deriv(
    param_plot_stats['a'],
    param_plot_stats['b'],
    param_plot_stats['t'],
    model_names,
    gamma = param_plot_stats['g'],
)
leh_scores_plot = pd.DataFrame(pd.Series(leh_scores), columns = ['mean'])
print(leh_scores_plot)
'''

"\nleh_scores = icc_best_deriv(\n    param_plot_stats['a'],\n    param_plot_stats['b'],\n    param_plot_stats['t'],\n    model_names,\n    gamma = param_plot_stats['g'],\n)\nleh_scores_plot = pd.DataFrame(pd.Series(leh_scores), columns = ['mean'])\nprint(leh_scores_plot)\n"

In [38]:
#leh_scores_plot = pd.concat([leh_scores_plot, task_name_format], axis=1)

# Plots

- Distribution: discrimination (alpha) is log-normal; ability (theta) is normal
- Constraint: none

In [39]:
task_metadata.set_index("taskname", inplace=True)

## Set fonts

In [40]:
font_label = 12
font_legend = 10
font_legendtitle = font_legend + 4
font_xtick = 12
font_title = 14
marker_scale = 1.5

plot.rc('axes', labelsize=font_label)
plot.rc('axes', titlesize=font_title)
plot.rc('legend', fontsize=font_legend)

## Plot

In [41]:
def plot_box(df, order, param_type, task_metadata, dim = 0, xsize=12, ysize=3.5, width=0.6, rotation=65, ylim=None, ystep=None):
    '''
    Draw one box per task for an IRT parameter, colour each box by the task's
    format, and save the figure to ../plots/<param_type>_box.png.

    Args:
        df:            DataFrame with a 'task_name' column and the column selected
                       by `param_type` / `dim`.
        order:         Task names, in the order the boxes are drawn.
        param_type:    One of 'discriminative', 'difficulty', 'disc-diff',
                       'disc-diff_pos', 'disc-diff_minmax', 'irt-score'.
        task_metadata: DataFrame indexed by task name with a 'format' column.
        dim:           Parameter dimension to plot (selects the '_<dim>' columns);
                       None plots the un-suffixed 'log_mean' / 'mean' columns.
        xsize, ysize:  Figure size in inches.
        width:         Box width.
        rotation:      Rotation of the x tick labels in degrees.
        ylim, ystep:   Optional y-limits and tick step; applied only when both are given.
    '''
    from matplotlib.colors import to_rgba

    # Suffixes for the chosen dimension. With dim=None there is no suffix, instead of
    # a literal '-DimNone' in the axis label.
    label_suffix = '' if dim is None else '-Dim' + str(dim)
    column_suffix = '' if dim is None else '_' + str(dim)

    param2label = {
        'discriminative': r'$\log$ Discrimination ($\log$ $\alpha$)' + label_suffix,
        'difficulty': r'Difficulty ($\beta$)' + label_suffix,
        "disc-diff": "IRT Score",
        "disc-diff_pos": "IRT Score",
        "disc-diff_minmax": "IRT Score",
        "irt-score": "LEH Score",
    }

    param2yname = {
        'discriminative': "log_mean" + column_suffix,
        'difficulty': "mean" + column_suffix,
        "disc-diff": 0,
        "disc-diff_pos": 0,
        "disc-diff_minmax": 0,
        "irt-score": "mean",
    }

    my_pal = {"MC-par": "r",
              "MC-sent": "b",
              "classification": "g",
              "span selection": "grey"}

    sns.set_style("whitegrid")
    sns.set_context("paper")
    _, ax = plot.subplots(figsize=(xsize, ysize))

    ax = sns.boxplot(x="task_name", y=param2yname[param_type], data=df, order=order, width=width, ax=ax)

    # Box patches live in ax.artists on older matplotlib/seaborn and in ax.patches on
    # newer releases, where ax.artists is empty and indexing it raises IndexError.
    boxes = list(ax.artists) or list(ax.patches)
    for task, box in zip(order, boxes):
        # Colour by task format, at 60% opacity.
        box.set_facecolor(to_rgba(my_pal[task_metadata.loc[task]['format']], alpha=.6))

    sns.despine()
    plot.xticks(range(len(order)), order, rotation=rotation, fontsize=font_xtick)

    if ylim is not None and ystep is not None:
        plot.ylim(ylim)
        plot.yticks(numpy.arange(ylim[0], ylim[1]+ystep, ystep))

    plot.xlabel(None)
    plot.ylabel(param2label[param_type], fontsize=font_label)

    os.makedirs('../plots', exist_ok=True)
    plot.savefig('../plots/' + param_type + "_box.png",
                 format='png', dpi=300,
                 bbox_inches = 'tight',
                 pad_inches = .1)

In [42]:
plot_box(
    param_a,
    task_order,
    "discriminative",
    task_metadata
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [43]:
plot_box(
    param_a,
    task_order,
    "discriminative",
    task_metadata,
    1
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [44]:
plot_box(
    param_a,
    task_order,
    "discriminative",
    task_metadata,
    2
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [45]:
alpha_arr = np.vstack(param_a['log_mean'])
ica = FastICA(n_components=2, random_state=0)
S_alpha_ = ica.fit_transform(alpha_arr)  # Reconstruct signals
A_alpha_ = ica.mixing_  # Get estimated mixing 
S_alpha_.shape

(79033, 2)

In [46]:
param_a_ica = param_a.copy()
alpha_arr = np.vstack(param_a_ica['log_mean'])

In [47]:
param_a_ica['log_mean']

0        [-0.009515893665189453, 0.03023954024362936, 0...
1        [-0.1513421920939463, 0.18606777582159761, 0.0...
2        [-0.014001359841734198, -0.08325552196657199, ...
3        [-0.03889330837442285, 0.13329149986081698, 0....
4        [0.09827703336900333, 0.08355058098296675, -0....
                               ...                        
79028    [-0.05183591913926985, -0.07838336067062691, 0...
79029    [-0.08501122609610069, -0.11224314457738692, 0...
79030    [-0.0015630414200946258, -0.06388216360960638,...
79031    [-0.03182185703978923, 0.21183932087775573, 0....
79032    [0.061542532675543424, -0.0005882920398324119,...
Name: log_mean, Length: 79033, dtype: object

In [48]:
param_a_ica['log_mean'] = [np.array(x) for x in S_alpha_.tolist()]
dimensions =  param_a_ica['log_mean'][0].shape[0]
for dim in range(dimensions):
    param_a_ica['log_mean_'+str(dim)]= param_a_ica['log_mean'].apply(lambda x: x[dim])


In [49]:
%matplotlib widget
import matplotlib.pyplot as plt

plt.figure()

models = [alpha_arr, S_alpha_]
names = ['Alpha (mixed signal)',
         'ICA recovered Alpha signals']
colors = ['red', 'steelblue']

for ii, (model, name) in enumerate(zip(models, names), 1):
    plt.subplot(2, 1, ii)
    plt.title(name)
    for sig, color in zip(model.T, colors):
        plt.plot(sig, color=color)

plt.subplots_adjust(0.09, 0.04, 0.94, 0.94, 0.26, 0.46)
plt.show()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [50]:
plot_box(
    param_a_ica,
    task_order,
    "discriminative",
    task_metadata
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [51]:
plot_box(
    param_a_ica,
    task_order,
    "discriminative",
    task_metadata,
    1
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [52]:
print(numpy.corrcoef(param_a_ica['log_mean_0'], param_a_ica['log_mean_1']))

[[ 1.00000000e+00 -8.85986298e-16]
 [-8.85986298e-16  1.00000000e+00]]


In [53]:
plot_box(
    param_a,
    task_order,
    "discriminative",
    task_metadata,
    2
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [54]:
plot_box(
    param_b,
    task_order,
    "difficulty",
    task_metadata,
    0
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [55]:
plot_box(
    param_b,
    task_order,
    "difficulty",
    task_metadata,
    1
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [56]:
plot_box(
    param_b,
    task_order,
    "difficulty",
    task_metadata,
    2
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [58]:
#plot_box(
#    leh_scores_plot,
#    task_order,
#    "irt-score",
#    task_metadata
#)

In [59]:
from matplotlib.colors import to_rgba

dim=0
xsize = 12
ysize = 7
width = 0.6
rotation = 65
order = task_order


param2label = {
    'discriminative':r'$\log$ Discrimination ($\log$ $\alpha$)-Dim'+str(dim),
    'difficulty': r'Difficulty ($\beta$)-Dim'+str(dim),
    "disc-diff": "IRT Score", # "discrimination - difficulty product\n", # + r"($ n(\alpha) x n(\beta$) )"
    "disc-diff_pos": "IRT Score", # "discrimination - difficulty product\n", # + r"($ n(\alpha) x n(\beta$) )"
    "disc-diff_minmax": "IRT Score", # "discrimination - difficulty product\n", # + r"($ n(\alpha) x n(\beta$) )"
}

param2yname = {
    'discriminative': "log_mean_"+str(dim),
    'difficulty': "mean_"+str(dim),
    "disc-diff": 0,
    "disc-diff_pos": 0,
    "disc-diff_minmax": 0,
}

sns.set_style("whitegrid")
sns.set_context("paper")

f, (ax1, ax2) = plot.subplots(nrows=2, ncols=1, squeeze=True, sharex=True, figsize=(xsize, ysize))

my_pal = {"MC-par": "r",
          "MC-sent": "b",
          "classification":"g",
          "span selection": "grey"}


def _color_boxes_by_format(box_ax):
    '''Fill the box of each task in `order` with its format colour at 60% opacity.'''
    # Box patches live in ax.artists on older matplotlib/seaborn and in ax.patches on
    # newer releases, where ax.artists is empty and indexing it raises IndexError.
    boxes = list(box_ax.artists) or list(box_ax.patches)
    for task, box in zip(order, boxes):
        box.set_facecolor(to_rgba(my_pal[task_metadata.loc[task]['format']], alpha=.6))


# Top panel: discrimination.
param_type = "discriminative"
ax = sns.boxplot(x="task_name", y=param2yname["discriminative"], data=param_a, order=order, width=width, ax=ax1)
_color_boxes_by_format(ax)

ax1.set_ylabel(param2label[param_type], fontsize=font_label)
ax1.set_xlabel("")

# Bottom panel: difficulty.
param_type = "difficulty"
ax = sns.boxplot(x="task_name", y=param2yname["difficulty"], data=param_b, order=order, width=width, ax=ax2)
_color_boxes_by_format(ax)

sns.despine()
plot.xticks(range(len(order)), order, rotation=rotation, fontsize=font_xtick)
plot.xlabel(None)
ax2.set_ylabel(param2label[param_type], fontsize=font_label)

plot.subplots_adjust(hspace=0.1)

os.makedirs('../plots', exist_ok=True)
plot.savefig('../plots/disc_diff_combined_box.png',
            format='png', dpi=300,
            bbox_inches = 'tight',
            pad_inches = .1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Per dataset analysis - single scatter plots

In [60]:
def single_scatter_plot(param1, param2, label1, label2, ymin, ymax, step, taskname, order, dim=0):
    '''
    Per-task scatter plot of one IRT parameter against another, saved to
    ../plots/single_<label1>_<label2>.png.

    Args:
        param1:   Stats DataFrame for the x-axis parameter; its 'log_mean_<dim>' column is plotted.
        param2:   Stats DataFrame for the y-axis parameter; its 'mean_<dim>' column is plotted.
        label1:   Label key of the x parameter: 'alpha', 'beta' or 'gamma'.
        label2:   Label key of the y parameter.
        ymin, ymax, step: y-axis limits and tick step.
        taskname: Two-column DataFrame (task name, task format), row-aligned with
                  `param1` / `param2`.
        order:    Task names in facet order.
        dim:      Parameter dimension to plot.
    '''
    param2label = {
        'alpha':r'$\log$ Discrimination ($\log$ $\alpha$)-Dim',
        'beta': r'Difficulty ($\beta$)-Dim',
        'gamma': r'Guessing ($\gamma$)'
    }

    # The stray sns.despine() that used to sit here ran before the grid existed, so it
    # only restyled whichever figure was current. FacetGrid despines its own axes.
    sns.set_theme(style="whitegrid")

    x_name = param2label[label1] + str(dim)
    y_name = param2label[label2] + str(dim)

    # Select the two plotted columns by name, rather than relabelling every column
    # positionally. The positional version depended on the exact number and order of
    # columns in both stats frames.
    combined_data = pd.concat(
        [
            param1[f'log_mean_{dim}'].rename(x_name),
            param2[f'mean_{dim}'].rename(y_name),
            taskname.set_axis(['task', 'format'], axis=1),
        ],
        axis=1,
    )

    # sns.set_palette returns None, so pass an explicit palette (one colour per format).
    colors = ['b', 'r', 'g', 'grey']
    custom_palette = sns.color_palette(colors)
    grid = sns.FacetGrid(combined_data, col="task", hue="format",
                         col_order=order,
                         palette=custom_palette,
                         col_wrap=7, height=2)

    grid.map(plot.scatter, x_name, y_name, marker=".", s=1, alpha=0.75)
    grid.set_titles(size=font_title)

    # Adjust the tick positions and labels
    grid.set(xticks=numpy.arange(-1, 1.5, 0.5), yticks=numpy.arange(ymin, ymax, step),
             xlim=(-1, 1), ylim=(ymin, ymax))

    # Replace per-facet axis labels with one shared label per axis.
    grid.set_axis_labels('', '')

    grid.fig.text(
        x=-0.01, y=0.5,
        verticalalignment='center',
        s=y_name,
        size=font_label,
        rotation=90,
    )

    grid.fig.text(
        x=0.5, y=-0.01,
        horizontalalignment='center',
        s=x_name,
        size=font_label,
    )

    # Facet titles: show just the task name.
    for ax in grid.axes.flatten():
        ax.set_title(ax.get_title().replace('task = ', ''))

    grid.fig.tight_layout()
    plot.subplots_adjust(hspace=0.2)

    os.makedirs('../plots', exist_ok=True)
    plot.savefig('../plots/single_' + label1 + '_' + label2 + ".png",
                 format='png', dpi=300, bbox_inches = 'tight',
                 pad_inches = .1)

In [61]:
ymin, ymax, step = -2, 4, 1
single_scatter_plot(param_plot_stats['a'], param_plot_stats['b'], 'alpha', 'beta', ymin, ymax, step, task_name_format, task_order)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [62]:
ymin, ymax, step = -2, 4, 1
single_scatter_plot(param_plot_stats['a'], param_plot_stats['b'], 'alpha', 'beta', ymin, ymax, step, task_name_format, task_order, 1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [63]:
ymin, ymax, step = -2, 4, 1
single_scatter_plot(param_plot_stats['a'], param_plot_stats['b'], 'alpha', 'beta', ymin, ymax, step, task_name_format, task_order, 2)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [64]:
#ymin, ymax, step = 0, 1.0, 0.2
#single_scatter_plot(param_plot_stats['a'], param_plot_stats['g'], 'alpha', 'gamma', ymin, ymax, step, task_name_format, task_order)

In [65]:
# Column names for the model identity and the checkpoint level; they also
# appear as the hue/style titles in the scatter-plot legends.
model_name, checkpoint_name = 'Model', 'Checkpoint'

In [66]:
# Map raw model identifiers to the display names used in plot legends.
model2plot = {
    'roberta-med-small-1M': r'RoBERTa-Med-Small-1M',
    'roberta-base-10M': r'RoBERTa-Base-10M',
    'roberta-base-100M': r'RoBERTa-Base-100M',
    'roberta-base-1B': r'RoBERTa-Base-1B',
    'bert-base-cased': r'BERT-Base',
    'bert-large-cased': r'BERT-Large',
    'roberta-base': r'RoBERTa-Base',
    'roberta-large': r'RoBERTa-Large',
    'xlm-roberta-large': r'XLM-R-Large',
    'albert-xxlarge-v2': r'ALBERT-XXL-v2',
}
# Names that are already display names (e.g. when this cell is re-run) are kept
# as-is, so the conversion is idempotent; any other unknown name still raises
# KeyError, as before.
plot_name_set = set(model2plot.values())
model_names = [name if name in plot_name_set else model2plot[name] for name in model_names]

In [67]:
# Single-column frames (accuracy, model display name, checkpoint level) that
# are later concatenated into the plotting table.
df_acc, df_model, df_mlevel = (
    pd.DataFrame(values, columns=[column])
    for values, column in (
        (accuracies, 'acc.'),
        (model_names, model_name),
        (model_levels, checkpoint_name),
    )
)

In [70]:
# Number of ability dimensions: the length of the first entry of the theta
# statistics' 'mean' column.
first_theta_mean = param_plot_stats['t']['mean'][0]
dimensions = first_theta_mean.shape[0]
dimensions

3

In [71]:
# Display the 'mean_0' column (ability/theta, dimension 0) of the theta statistics.
param_plot_stats['t']['mean_0']

0    -0.146931
1    -0.988569
2     0.137529
3    -0.113576
4    -0.325461
        ...   
85    0.469614
86   -1.167895
87   -0.064116
88    0.094622
89   -0.052684
Name: mean_0, Length: 90, dtype: float64

In [74]:
# Display the accuracy/theta/model/checkpoint table used for the scatter plots.
# NOTE(review): the captured output shows theta only in rows 0-2 (NaN elsewhere),
# which looks stale relative to how combined_data is built later — re-run to confirm.
combined_data

Unnamed: 0,acc.,theta,Model,Checkpoint
0,0.54166,-0.146931,RoBERTa-Base-10M,best
1,0.212025,-0.027640,RoBERTa-Base-10M,1%
2,0.498184,-0.144212,RoBERTa-Base-10M,25%
3,0.531221,,RoBERTa-Base-10M,50%
4,0.396973,,RoBERTa-Base-10M,10%
...,...,...,...,...
85,0.709678,,XLM-R-Large,best
86,0.202042,,XLM-R-Large,1%
87,0.676376,,XLM-R-Large,25%
88,0.696823,,XLM-R-Large,50%


In [76]:
# Scatter average model accuracy against ability (theta) dimension 0: colour
# encodes the model, marker encodes the checkpoint level. Saves
# ../plots/acc_theta.png.
font_label = 14
font_legend = 14
font_legendtitle = font_legend + 4
font_xtick = 14
font_title = 14
marker_scale = 1.5

sns.set_theme(style="whitegrid")

keys = ['acc.', 'theta', model_name, checkpoint_name]
combined_data = pd.concat([df_acc, pd.DataFrame(param_plot_stats['t']['mean_0']), df_model, df_mlevel], axis=1)
combined_data = combined_data.set_axis(keys, axis=1)
hue_order = ['roberta-med-small-1M', 'roberta-base-10M', 'roberta-base-100M', 'roberta-base-1B', 'bert-base-cased', 'bert-large-cased',
            'roberta-base', 'roberta-large', 'xlm-roberta-large', 'albert-xxlarge-v2']
hue_order = [model2plot[name] for name in hue_order]
style_order = [r'1%', r'10%', r'25%', r'50%', 'best']

level2marker = {
    r'1%':'o',
    r'10%':'s',
    r'25%':'P',
    r'50%':'X',
    'best':'^',
}

# (sizes/size_order are not used by the plot below, which uses a fixed s=200.)
sizes = [5, 50, 100, 200, 400]
size_order = [r'1%', r'10%', r'25%', r'50%', 'best']

f, ax = plot.subplots(figsize=(14, 4))
sns.despine()

x_lbl, y_lbl = 'theta', 'acc.'
prefix = ''

# One colour per model, assigned in hue_order.
colors = ['r', 'b', 'g', 'm', 'grey', 'orange', 'olive', 'teal', 'skyblue', 'navy']
# sns.set_palette() returns None, so build the palette explicitly and pass it
# to scatterplot; it is still installed as the global palette as before.
customPalette = sns.color_palette(colors)
sns.set_palette(customPalette)

ax = sns.scatterplot(data=combined_data, x=x_lbl, y=y_lbl,
                     ax=ax,
                     alpha=0.7,
                     palette=customPalette,
                     linewidth=0,
                     markers=level2marker,
                     style=checkpoint_name,
                     style_order=style_order,
                     s=200,
                     hue_order=hue_order,
                     hue=model_name)


plot.xlabel(prefix + r'Ability ($\theta$)-Dim0', fontsize=font_label)
plot.ylabel(r'Average Model Accuracy', fontsize=font_label)
# "right center" is not a valid matplotlib legend location (recent versions
# raise ValueError); the valid spelling is "center right".
plot.legend(
    borderaxespad=0,
    loc="center right",
    ncol=2,
    bbox_to_anchor=(1,1),
    fontsize=font_legend,
    title_fontsize=font_legendtitle,
    markerscale=marker_scale,
)
f.tight_layout()

plot.savefig("../plots/acc_" + prefix + "theta.png",
                format='png',dpi=300,bbox_inches = 'tight',
                pad_inches = .1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [77]:
# Scatter average model accuracy against ability (theta) dimension 1: colour
# encodes the model, marker encodes the checkpoint level. Saves
# ../plots/acc_theta1.png.
font_label = 14
font_legend = 14
font_legendtitle = font_legend + 4
font_xtick = 14
font_title = 14
marker_scale = 1.5

sns.set_theme(style="whitegrid")

keys = ['acc.', 'theta', model_name, checkpoint_name]
combined_data = pd.concat([df_acc, pd.DataFrame(param_plot_stats['t']['mean_1']), df_model, df_mlevel], axis=1)
combined_data = combined_data.set_axis(keys, axis=1)
hue_order = ['roberta-med-small-1M', 'roberta-base-10M', 'roberta-base-100M', 'roberta-base-1B', 'bert-base-cased', 'bert-large-cased',
            'roberta-base', 'roberta-large', 'xlm-roberta-large', 'albert-xxlarge-v2']
hue_order = [model2plot[name] for name in hue_order]
style_order = [r'1%', r'10%', r'25%', r'50%', 'best']

level2marker = {
    r'1%':'o',
    r'10%':'s',
    r'25%':'P',
    r'50%':'X',
    'best':'^',
}

# (sizes/size_order are not used by the plot below, which uses a fixed s=200.)
sizes = [5, 50, 100, 200, 400]
size_order = [r'1%', r'10%', r'25%', r'50%', 'best']

f, ax = plot.subplots(figsize=(14, 4))
sns.despine()

x_lbl, y_lbl = 'theta', 'acc.'
prefix = ''

# One colour per model, assigned in hue_order.
colors = ['r', 'b', 'g', 'm', 'grey', 'orange', 'olive', 'teal', 'skyblue', 'navy']
# sns.set_palette() returns None, so build the palette explicitly and pass it
# to scatterplot; it is still installed as the global palette as before.
customPalette = sns.color_palette(colors)
sns.set_palette(customPalette)

ax = sns.scatterplot(data=combined_data, x=x_lbl, y=y_lbl,
                     ax=ax,
                     alpha=0.7,
                     palette=customPalette,
                     linewidth=0,
                     markers=level2marker,
                     style=checkpoint_name,
                     style_order=style_order,
                     s=200,
                     hue_order=hue_order,
                     hue=model_name)


plot.xlabel(prefix + r'Ability ($\theta$)-Dim1', fontsize=font_label)
plot.ylabel(r'Average Model Accuracy', fontsize=font_label)
# "right center" is not a valid matplotlib legend location (recent versions
# raise ValueError); the valid spelling is "center right".
plot.legend(
    borderaxespad=0,
    loc="center right",
    ncol=2,
    bbox_to_anchor=(1,1),
    fontsize=font_legend,
    title_fontsize=font_legendtitle,
    markerscale=marker_scale,
)
f.tight_layout()

plot.savefig("../plots/acc_" + prefix + "theta1.png",
                format='png',dpi=300,bbox_inches = 'tight',
                pad_inches = .1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [78]:
# Scatter average model accuracy against ability (theta) dimension 2: colour
# encodes the model, marker encodes the checkpoint level. Saves
# ../plots/acc_theta2.png.
font_label = 14
font_legend = 14
font_legendtitle = font_legend + 4
font_xtick = 14
font_title = 14
marker_scale = 1.5

sns.set_theme(style="whitegrid")

keys = ['acc.', 'theta', model_name, checkpoint_name]
combined_data = pd.concat([df_acc, pd.DataFrame(param_plot_stats['t']['mean_2']), df_model, df_mlevel], axis=1)
combined_data = combined_data.set_axis(keys, axis=1)
hue_order = ['roberta-med-small-1M', 'roberta-base-10M', 'roberta-base-100M', 'roberta-base-1B', 'bert-base-cased', 'bert-large-cased',
            'roberta-base', 'roberta-large', 'xlm-roberta-large', 'albert-xxlarge-v2']
hue_order = [model2plot[name] for name in hue_order]
style_order = [r'1%', r'10%', r'25%', r'50%', 'best']

level2marker = {
    r'1%':'o',
    r'10%':'s',
    r'25%':'P',
    r'50%':'X',
    'best':'^',
}

# (sizes/size_order are not used by the plot below, which uses a fixed s=200.)
sizes = [5, 50, 100, 200, 400]
size_order = [r'1%', r'10%', r'25%', r'50%', 'best']

f, ax = plot.subplots(figsize=(14, 4))
sns.despine()

x_lbl, y_lbl = 'theta', 'acc.'
prefix = ''

# One colour per model, assigned in hue_order.
colors = ['r', 'b', 'g', 'm', 'grey', 'orange', 'olive', 'teal', 'skyblue', 'navy']
# sns.set_palette() returns None, so build the palette explicitly and pass it
# to scatterplot; it is still installed as the global palette as before.
customPalette = sns.color_palette(colors)
sns.set_palette(customPalette)

ax = sns.scatterplot(data=combined_data, x=x_lbl, y=y_lbl,
                     ax=ax,
                     alpha=0.7,
                     palette=customPalette,
                     linewidth=0,
                     markers=level2marker,
                     style=checkpoint_name,
                     style_order=style_order,
                     s=200,
                     hue_order=hue_order,
                     hue=model_name)


plot.xlabel(prefix + r'Ability ($\theta$)-Dim2', fontsize=font_label)
plot.ylabel(r'Average Model Accuracy', fontsize=font_label)
# "right center" is not a valid matplotlib legend location (recent versions
# raise ValueError); the valid spelling is "center right".
plot.legend(
    borderaxespad=0,
    loc="center right",
    ncol=2,
    bbox_to_anchor=(1,1),
    fontsize=font_legend,
    title_fontsize=font_legendtitle,
    markerscale=marker_scale,
)
f.tight_layout()

# Dimension 2 gets its own file; saving to "theta1.png" overwrote the Dim1 plot.
plot.savefig("../plots/acc_" + prefix + "theta2.png",
                format='png',dpi=300,bbox_inches = 'tight',
                pad_inches = .1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# Display the plotting table built by the most recently executed scatter cell.
combined_data

In [None]:
# Scatter average model accuracy against ability (theta): colour encodes the
# model, marker encodes the checkpoint level. Saves ../plots/acc_theta.png.
font_label = 14
font_legend = 14
font_legendtitle = font_legend + 4
font_xtick = 14
font_title = 14
marker_scale = 1.5

sns.set_theme(style="whitegrid")

# NOTE(review): keys has 8 names, so set_axis only succeeds if
# param_plot_stats['t'] has exactly 5 columns. The per-dimension table
# (mean, mean_0, ...) looks different — this probably raises a length-mismatch
# error now; confirm before re-running.
keys = ['acc.', 'lower', 'upper', 'theta', 'var', 'log_mean', model_name, checkpoint_name]
combined_data = pd.concat([df_acc, param_plot_stats['t'], df_model, df_mlevel], axis=1)
combined_data = combined_data.set_axis(keys, axis=1)
hue_order = ['roberta-med-small-1M', 'roberta-base-10M', 'roberta-base-100M', 'roberta-base-1B', 'bert-base-cased', 'bert-large-cased',
            'roberta-base', 'roberta-large', 'xlm-roberta-large', 'albert-xxlarge-v2']
hue_order = [model2plot[name] for name in hue_order]
style_order = [r'1%', r'10%', r'25%', r'50%', 'best']

level2marker = {
    r'1%':'o',
    r'10%':'s',
    r'25%':'P',
    r'50%':'X',
    'best':'^',
}

# (sizes/size_order are not used by the plot below, which uses a fixed s=200.)
sizes = [5, 50, 100, 200, 400]
size_order = [r'1%', r'10%', r'25%', r'50%', 'best']

f, ax = plot.subplots(figsize=(14, 4))
sns.despine()

x_lbl, y_lbl = 'theta', 'acc.'
prefix = ''

# One colour per model, assigned in hue_order.
colors = ['r', 'b', 'g', 'm', 'grey', 'orange', 'olive', 'teal', 'skyblue', 'navy']
# sns.set_palette() returns None, so build the palette explicitly and pass it
# to scatterplot; it is still installed as the global palette as before.
customPalette = sns.color_palette(colors)
sns.set_palette(customPalette)

ax = sns.scatterplot(data=combined_data, x=x_lbl, y=y_lbl,
                     ax=ax,
                     alpha=0.7,
                     palette=customPalette,
                     linewidth=0,
                     markers=level2marker,
                     style=checkpoint_name,
                     style_order=style_order,
                     s=200,
                     hue_order=hue_order,
                     hue=model_name)


plot.xlabel(prefix + r'Ability ($\theta$)', fontsize=font_label)
plot.ylabel(r'Average Model Accuracy', fontsize=font_label)
# "right center" is not a valid matplotlib legend location (recent versions
# raise ValueError); the valid spelling is "center right".
plot.legend(
    borderaxespad=0,
    loc="center right",
    ncol=2,
    bbox_to_anchor=(1,1),
    fontsize=font_legend,
    title_fontsize=font_legendtitle,
    markerscale=marker_scale,
)
f.tight_layout()

plot.savefig("../plots/acc_" + prefix + "theta.png",
                format='png',dpi=300,bbox_inches = 'tight',
                pad_inches = .1)

In [None]:
# Load per-task model accuracies, reshape them to long format
# (task, model, accuracy), and keep three models under their display names.
selected_models = ['roberta-large', 'roberta-med-small-1M', 'albert-xxlarge-v2']
model_renames = {
    "roberta-large": "RoBERTa-Large",
    "roberta-med-small-1M": "RoBERTa-Med-Small-1M-2",
    "albert-xxlarge-v2": "ALBERT-XXL-v2",
}
df_model_perf = (
    pd.read_csv('model_performance.csv')
    .melt(id_vars=["Task"])
    .set_axis(["task", "model", "accuracy"], axis=1)
)
df = df_model_perf[df_model_perf.model.isin(selected_models)].replace(model_renames)
df

In [None]:
# Grouped bar chart of per-task accuracy for the models in `df`, with the
# legend above the axes. Saves ../plots/results.png.
font_label = 24
font_legend = 24
font_legendtitle = font_legend + 4
font_xtick = 24
font_title = 24
marker_scale = 1.5

sns.set_theme(style="whitegrid")

# Install the colour cycle as the global seaborn palette; with palette=None the
# bar plot draws its hue colours from it in order. (sns.set_palette returns
# None, so there is no palette object worth keeping.)
colors = ['navy', 'b', 'skyblue', 'olive', 'teal', 'g', 'navy']
sns.set_palette(sns.color_palette(colors))


f, ax = plot.subplots(figsize=(30, 7))
order = task_order
sns.barplot(x="task", y="accuracy", hue="model", data=df, order=order, ax=ax)
plot.xticks(range(len(order)), order, rotation=45, fontsize=font_xtick)
sns.despine()

# Leave the x-axis untitled; the task names on the ticks label it.
plot.xlabel(None)
plot.ylabel('Model Performance', fontsize=font_label)

# Shrink current axis's height by 10% on the bottom.
# NOTE(review): f.tight_layout() below may reset subplot positions and undo
# this shrink — confirm it still has a visible effect.
box = ax.get_position()
ax.set_position([box.x0, box.y0 + box.height * 0.1,
                 box.width, box.height * 0.9])

# Put the legend above the axes: its upper-centre corner is anchored at 120%
# of the axes height.
ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.2),
          fancybox=True, shadow=True, ncol=5, prop=dict(size=font_legend))

f.tight_layout()

plot.savefig("../plots/results.png",
                format='png',dpi=300,bbox_inches = 'tight',
                pad_inches = .1)