In [1]:
%%javascript
// Replace the output area's `_should_scroll` hook with a function that always
// answers false, whatever `lines` value it is given.
// NOTE(review): presumably this keeps the classic Notebook frontend from
// collapsing long cell outputs into a scrollable box — confirm against the
// frontend in use (the `IPython` global may not exist outside classic Notebook).
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}


<IPython.core.display.Javascript object>

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams
from pprint import pprint

In [3]:
%matplotlib inline

In [4]:
# Terminal styling escape sequences keyed by style name. Every value has the
# form ESC + '[' + <SGR code> + 'm'; 'END' (code 0) resets all styling.
TEXT_COLOUR = {
    name: f'\033[{code}m'
    for name, code in (
        ('PURPLE', 95),
        ('CYAN', 96),
        ('DARKCYAN', 36),
        ('BLUE', 94),
        ('GREEN', 92),
        ('YELLOW', 93),
        ('RED', 91),
        ('BOLD', 1),
        ('UNDERLINE', 4),
        ('END', 0),
    )
}

def print_bold(*msgs):
    """Print ``msgs`` between the BOLD and END escape codes.

    ``msgs`` is forwarded unchanged to a single ``print`` call. Each escape
    code is emitted by its own ``print`` call and is therefore followed by a
    newline, so the message sits on its own line between the two code lines.
    """
    opening, closing = TEXT_COLOUR['BOLD'], TEXT_COLOUR['END']
    for args in ((opening,), msgs, (closing,)):
        print(*args)

def print_green(*msgs):
    """Print ``msgs`` between the GREEN and END escape codes.

    ``msgs`` is forwarded unchanged to a single ``print`` call. Each escape
    code is emitted by its own ``print`` call and is therefore followed by a
    newline, so the message sits on its own line between the two code lines.
    """
    opening, closing = TEXT_COLOUR['GREEN'], TEXT_COLOUR['END']
    for args in ((opening,), msgs, (closing,)):
        print(*args)

def print_error(*msgs):
    """Print ``msgs`` between the RED and END escape codes.

    ``msgs`` is forwarded unchanged to a single ``print`` call. Each escape
    code is emitted by its own ``print`` call and is therefore followed by a
    newline, so the message sits on its own line between the two code lines.
    """
    opening, closing = TEXT_COLOUR['RED'], TEXT_COLOUR['END']
    for args in ((opening,), msgs, (closing,)):
        print(*args)

def wrap_green(msg):
    """Return ``msg`` with the GREEN code prepended and the END code appended.

    ``msg`` must already be a ``str``; other types raise ``TypeError``.
    """
    pieces = (TEXT_COLOUR['GREEN'], msg, TEXT_COLOUR['END'])
    return ''.join(pieces)

def wrap_red(msg):
    """Return ``msg`` with the RED code prepended and the END code appended.

    ``msg`` must already be a ``str``; other types raise ``TypeError``.
    """
    pieces = (TEXT_COLOUR['RED'], msg, TEXT_COLOUR['END'])
    return ''.join(pieces)

def up_down_str(val):
    """Return ``str(val)``, coloured by the sign of ``val``.

    Values greater than zero are passed through ``wrap_green``, values less
    than zero through ``wrap_red``; anything else (e.g. zero) is returned as
    the plain string.
    """
    text = str(val)
    if val > 0:
        return wrap_green(text)
    if val < 0:
        return wrap_red(text)
    return text

In [5]:
# Task names, in the order they are processed and printed below.
tasks = [
    "CoLA", "SST-2", "MRPC",
    "STS-B", "QQP", "MNLI",
    "QNLI", "RTE", "WNLI",
]

# Metric names available for each task. "MNLI-MM" has an entry here but is
# not listed in `tasks`, so loops over `tasks` never reach it.
metrics = {
    "CoLA": ["mcc"],
    "MNLI": ["acc"],
    "MNLI-MM": ["acc"],
    "MRPC": ["acc", "f1", "acc_and_f1"],
    "QNLI": ["acc"],
    "QQP": ["acc", "f1", "acc_and_f1"],
    "RTE": ["acc"],
    "SST-2": ["acc"],
    "STS-B": ["pearson", "spearmanr", "corr"],
    "WNLI": ["acc"],
}

# Show the metric list of every task, one task per line.
for task in tasks:
    print(metrics[task])

['mcc']
['acc']
['acc', 'f1', 'acc_and_f1']
['pearson', 'spearmanr', 'corr']
['acc', 'f1', 'acc_and_f1']
['acc']
['acc']
['acc']
['acc']


In [6]:
def find_best(exp, base_dir="../exp_results/baseline", keep_ratio=0.9):
    """Pick the best-scoring setting index for every task/metric of one experiment.

    For each task in the module-level ``tasks`` and each metric in
    ``metrics[task]``, the files
    ``{base_dir}/{exp}/{task}/{i}e-5_{metric}.txt`` for ``i`` in 1..5 are read
    (``i`` presumably being the learning-rate multiplier, going by the file
    names). Every non-blank line must be tab-separated with a float in its
    second field; values equal to 0 are discarded. From the remaining values
    the maximum is kept together with every other value strictly greater than
    ``keep_ratio`` times that maximum, and their mean becomes the score of
    ``i``. A file with no non-zero values scores 0.

    One pair of progress lines is printed for every file that yields a score,
    and one line for each metric's final choice.

    Args:
        exp: Experiment directory name under ``base_dir``.
        base_dir: Root directory holding the per-experiment result folders.
        keep_ratio: Fraction of the best value a run must exceed to be
            included in the average.

    Returns:
        dict: ``{task: {metric: "<best i> - <best score>"}}``. On ties the
        smallest ``i`` wins (``np.argmax`` returns the first maximum).

    Raises:
        FileNotFoundError: If an expected result file does not exist.
        IndexError, ValueError: If a non-blank line has no second
            tab-separated field or that field is not a float.
    """
    results = {}

    for task in tasks:
        task_results = {}
        for metric in metrics[task]:
            per_lr = []
            for i in range(1, 6):
                path = f"{base_dir}/{exp}/{task}/{i}e-5_{metric}.txt"
                # The context manager closes the handle even if reading fails.
                with open(path, "r") as f:
                    lines = f.read().splitlines()

                reported = []
                for line in lines:
                    if not line.strip():
                        # A blank line carries no score; skip it instead of failing.
                        continue
                    val = float(line.split('\t')[1])
                    if val != 0:
                        reported.append(val)

                if not reported:
                    per_lr.append(0)
                    continue

                reported.sort(reverse=True)
                best = reported[0]
                # The best run is always kept; the others only if close enough to it.
                candidates = [best] + [v for v in reported[1:] if v > keep_ratio * best]
                mean_score = np.mean(candidates)

                per_lr.append(mean_score)
                print(task, metric, i, '(', len(candidates), ')', [round(cand, 5) for cand in candidates])
                print('\t', mean_score, candidates[0] - candidates[-1])

            # Position k in per_lr belongs to i == k + 1.
            task_results[metric] = str(np.argmax(per_lr) + 1) + f" - {np.max(per_lr)}"
            print(task_results[metric])
        results[task] = task_results

    return results

# BERT

In [7]:
# Experiment whose result files are summarised in this cell.
exp = 'bert-base'

# find_best prints its own progress lines; pprint then shows the mapping it returns.
pprint(find_best(exp))

# Per-task reference scores (named as paper-reported values; the paper is not
# identified in this notebook). No code visible here reads this dict.
# NOTE(review): the same numbers are assigned in every model cell of this
# notebook — confirm they are the right reference for this model.
reported_in_paper = {
    "CoLA": 0.521,
    "MNLI": 0.846,
    "MNLI-MM": 0.834,
    "MRPC": 0.889,
    "QNLI": 0.905,
    "QQP": 0.712,
    "RTE": 0.664,
    "SST-2": 0.935,
    "STS-B": 0.858,
    "WNLI": 0.651,  # temp
}

CoLA mcc 1 ( 5 ) [0.5599, 0.55983, 0.55731, 0.55241, 0.55206]
	 0.5563013962481641 0.007836097511614182
CoLA mcc 2 ( 5 ) [0.58065, 0.57645, 0.5727, 0.55741, 0.55738]
	 0.5689188131688764 0.023266208114516962
CoLA mcc 3 ( 5 ) [0.60593, 0.58669, 0.58649, 0.56278, 0.55982]
	 0.5803439334714199 0.04610497633691435
CoLA mcc 4 ( 5 ) [0.60616, 0.59871, 0.586, 0.5782, 0.56894]
	 0.5876033745283135 0.03721607535321125
CoLA mcc 5 ( 5 ) [0.5887, 0.57902, 0.56683, 0.55581, 0.53273]
	 0.5646156934852402 0.05596778554160087
4 - 0.5876033745283135
SST-2 acc 1 ( 5 ) [0.9289, 0.92775, 0.92661, 0.92202, 0.91743]
	 0.9245412844036698 0.011467889908256867
SST-2 acc 2 ( 5 ) [0.93119, 0.9289, 0.9289, 0.9289, 0.91514]
	 0.9266055045871561 0.016055045871559592
SST-2 acc 3 ( 5 ) [0.92775, 0.92775, 0.92317, 0.92202, 0.91972]
	 0.9240825688073395 0.008027522935779796
SST-2 acc 4 ( 5 ) [0.92661, 0.92546, 0.91514, 0.91399, 0.91399]
	 0.9190366972477065 0.012614678899082521
SST-2 acc 5 ( 5 ) [0.91743, 0.91743, 0.91

In [8]:
# Experiment whose result files are summarised in this cell.
exp = 'bert-large'

# find_best prints its own progress lines; pprint then shows the mapping it returns.
pprint(find_best(exp))

# Per-task reference scores (named as paper-reported values; the paper is not
# identified in this notebook). No code visible here reads this dict.
# NOTE(review): the same numbers are assigned in every model cell of this
# notebook — confirm they are the right reference for this model.
reported_in_paper = {
    "CoLA": 0.521,
    "MNLI": 0.846,
    "MNLI-MM": 0.834,
    "MRPC": 0.889,
    "QNLI": 0.905,
    "QQP": 0.712,
    "RTE": 0.664,
    "SST-2": 0.935,
    "STS-B": 0.858,
    "WNLI": 0.651,  # temp
}

CoLA mcc 1 ( 5 ) [0.60341, 0.59835, 0.5935, 0.59338, 0.58587]
	 0.5949004424313725 0.017539238903192533
CoLA mcc 2 ( 5 ) [0.63143, 0.61965, 0.61599, 0.61325, 0.60823]
	 0.617710363778107 0.023198016923263953
CoLA mcc 3 ( 4 ) [0.63064, 0.62341, 0.61097, 0.60393]
	 0.6172382026452863 0.026711823049750283
CoLA mcc 4 ( 4 ) [0.62408, 0.59434, 0.59361, 0.59305]
	 0.6012722019280655 0.031029248978073243
CoLA mcc 5 ( 1 ) [0.54869]
	 0.5486883812731442 0.0
2 - 0.617710363778107
SST-2 acc 1 ( 5 ) [0.93807, 0.93807, 0.93349, 0.93234, 0.93005]
	 0.9344036697247706 0.008027522935779796
SST-2 acc 2 ( 5 ) [0.94495, 0.93463, 0.93119, 0.92661, 0.92431]
	 0.9323394495412843 0.020642201834862428
SST-2 acc 3 ( 5 ) [0.93463, 0.93349, 0.9289, 0.92546, 0.92317]
	 0.9291284403669724 0.011467889908256867
SST-2 acc 4 ( 2 ) [0.92775, 0.91628]
	 0.9220183486238531 0.011467889908256867
SST-2 acc 5 ( 5 ) [0.50917, 0.50917, 0.50917, 0.50917, 0.50917]
	 0.5091743119266054 0.0
1 - 0.9344036697247706
MRPC acc 1 ( 5 ) [

# XLNet

In [9]:
# Experiment whose result files are summarised in this cell.
exp = 'xlnet-base'

# find_best prints its own progress lines; pprint then shows the mapping it returns.
pprint(find_best(exp))

# Per-task reference scores (named as paper-reported values; the paper is not
# identified in this notebook). No code visible here reads this dict.
# NOTE(review): the same numbers are assigned in every model cell of this
# notebook — confirm they are the right reference for this model.
reported_in_paper = {
    "CoLA": 0.521,
    "MNLI": 0.846,
    "MNLI-MM": 0.834,
    "MRPC": 0.889,
    "QNLI": 0.905,
    "QQP": 0.712,
    "RTE": 0.664,
    "SST-2": 0.935,
    "STS-B": 0.858,
    "WNLI": 0.651,  # temp
}

CoLA mcc 1 ( 4 ) [0.54016, 0.5175, 0.49857, 0.49409]
	 0.5125804485422553 0.04606546427881364
CoLA mcc 2 ( 4 ) [0.56126, 0.53248, 0.53109, 0.51082]
	 0.5339126606130538 0.05043562890809916
CoLA mcc 3 ( 2 ) [0.46996, 0.45258]
	 0.46127014184819704 0.017377751215785164
CoLA mcc 4 ( 1 ) [0.46917]
	 0.46916929434481164 0.0
2 - 0.5339126606130538
SST-2 acc 1 ( 5 ) [0.94266, 0.94151, 0.94037, 0.93922, 0.93922]
	 0.9405963302752293 0.0034403669724770714
SST-2 acc 2 ( 5 ) [0.9461, 0.93922, 0.93807, 0.93349, 0.93119]
	 0.9376146788990827 0.014908256880733939
SST-2 acc 3 ( 5 ) [0.94495, 0.94151, 0.93693, 0.93578, 0.93463]
	 0.9387614678899083 0.010321100917431214
SST-2 acc 4 ( 4 ) [0.93463, 0.93119, 0.92775, 0.92775]
	 0.9303325688073394 0.006880733944954143
SST-2 acc 5 ( 5 ) [0.91858, 0.91743, 0.91399, 0.9117, 0.9094]
	 0.9142201834862386 0.00917431192660556
1 - 0.9405963302752293
MRPC acc 1 ( 5 ) [0.8799, 0.8799, 0.87745, 0.8701, 0.85784]
	 0.8730392156862745 0.022058823529411797
MRPC acc 2 ( 

In [10]:
# Experiment whose result files are summarised in this cell.
exp = 'xlnet-large'

# find_best prints its own progress lines; pprint then shows the mapping it returns.
pprint(find_best(exp))

# Per-task reference scores (named as paper-reported values; the paper is not
# identified in this notebook). No code visible here reads this dict.
# NOTE(review): the same numbers are assigned in every model cell of this
# notebook — confirm they are the right reference for this model.
reported_in_paper = {
    "CoLA": 0.521,
    "MNLI": 0.846,
    "MNLI-MM": 0.834,
    "MRPC": 0.889,
    "QNLI": 0.905,
    "QQP": 0.712,
    "RTE": 0.664,
    "SST-2": 0.935,
    "STS-B": 0.858,
    "WNLI": 0.651,  # temp
}

CoLA mcc 1 ( 1 ) [0.59932]
	 0.5993174841218436 0.0
1 - 0.5993174841218436
SST-2 acc 1 ( 5 ) [0.95183, 0.95183, 0.95069, 0.94839, 0.91284]
	 0.9431192660550458 0.03899082568807344
SST-2 acc 2 ( 3 ) [0.94266, 0.93463, 0.86124]
	 0.9128440366972477 0.08142201834862384
SST-2 acc 3 ( 5 ) [0.50917, 0.50917, 0.50917, 0.50917, 0.50917]
	 0.5091743119266054 0.0
SST-2 acc 4 ( 1 ) [0.75688]
	 0.7568807339449541 0.0
SST-2 acc 5 ( 5 ) [0.50917, 0.50917, 0.50917, 0.50917, 0.50917]
	 0.5091743119266054 0.0
1 - 0.9431192660550458
MRPC acc 1 ( 5 ) [0.90196, 0.89461, 0.88725, 0.8848, 0.8701]
	 0.8877450980392156 0.031862745098039214
MRPC acc 2 ( 4 ) [0.89461, 0.87745, 0.875, 0.8652]
	 0.8780637254901961 0.02941176470588236
MRPC acc 3 ( 1 ) [0.875]
	 0.875 0.0
MRPC acc 4 ( 5 ) [0.68382, 0.68382, 0.68382, 0.68382, 0.68382]
	 0.6838235294117647 0.0
MRPC acc 5 ( 5 ) [0.68382, 0.68382, 0.68382, 0.68382, 0.68382]
	 0.6838235294117647 0.0
1 - 0.8877450980392156
MRPC f1 1 ( 5 ) [0.93031, 0.92599, 0.9215, 0.919