In [25]:
%%javascript
// Override the classic-notebook OutputArea hook so it always answers "no"
// when asked whether output should scroll (presumably keeps long outputs
// fully expanded instead of collapsing them into a scroll box — confirm).
// The argument is accepted for signature compatibility but ignored.
IPython.OutputArea.prototype._should_scroll = function (_lineCount) {
    const shouldScroll = false;
    return shouldScroll;
};


<IPython.core.display.Javascript object>

In [26]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams
from pprint import pprint

In [27]:
%matplotlib inline

In [28]:
# ANSI escape sequences used to style text printed to the notebook output.
TEXT_COLOUR = {
    'PURPLE': '\033[95m',
    'CYAN': '\033[96m',
    'DARKCYAN': '\033[36m',
    'BLUE': '\033[94m',
    'GREEN': '\033[92m',
    'YELLOW': '\033[93m',
    'RED': '\033[91m',
    'BOLD': '\033[1m',
    'UNDERLINE': '\033[4m',
    'END': '\033[0m',
}


def _wrap_coloured(colour, msg):
    """Return ``str(msg)`` surrounded by the TEXT_COLOUR[colour] code and the reset code."""
    return TEXT_COLOUR[colour] + str(msg) + TEXT_COLOUR['END']


def _print_coloured(colour, *msgs):
    """Print ``msgs`` space-separated (like ``print``) in the given TEXT_COLOUR style.

    The colour code, message and reset code go out on ONE line; printing the codes
    with separate ``print`` calls would emit an extra blank line before and after
    every message.
    """
    print(_wrap_coloured(colour, ' '.join(str(m) for m in msgs)))


def print_bold(*msgs):
    """Print ``msgs`` in bold."""
    _print_coloured('BOLD', *msgs)


def print_green(*msgs):
    """Print ``msgs`` in green."""
    _print_coloured('GREEN', *msgs)


def print_error(*msgs):
    """Print ``msgs`` in red (used for error messages)."""
    _print_coloured('RED', *msgs)


def wrap_green(msg):
    """Return ``msg`` (converted with ``str``) wrapped in green ANSI codes."""
    return _wrap_coloured('GREEN', msg)


def wrap_red(msg):
    """Return ``msg`` (converted with ``str``) wrapped in red ANSI codes."""
    return _wrap_coloured('RED', msg)


def up_down_str(val):
    """Render a signed number: green if positive, red if negative, plain otherwise.

    Zero (and anything that compares neither > 0 nor < 0, e.g. NaN) is returned
    as plain ``str(val)``.
    """
    msg = str(val)
    if val > 0:
        return wrap_green(msg)
    if val < 0:
        return wrap_red(msg)
    return msg

In [104]:
# GLUE tasks evaluated in this notebook, in reporting order.
tasks = ["CoLA", "MNLI", "MNLI-MM", "MRPC", "QNLI", "QQP", "RTE", "SST-2", "STS-B", "WNLI"]

# Metric names logged per task; each one selects a `<lr>e-5_<metric>.txt`
# result file read by find_best.
metrics = {
    "CoLA": ["mcc"],
    "MNLI": ["acc"],
    "MNLI-MM": ["acc"],
    "MRPC": ["acc", "f1", "acc_and_f1"],
    "QNLI": ["acc"],
    "QQP": ["acc", "f1", "acc_and_f1"],
    "RTE": ["acc"],
    "SST-2": ["acc"],
    "STS-B": ["pearson", "spearmanr", "corr"],
    "WNLI": ["acc"],
}

# find_best indexes metrics[task] for every task, so fail fast on a typo or a
# missing/extra entry instead of deep inside the results scan.
assert set(tasks) == set(metrics), (
    f"tasks/metrics mismatch: {sorted(set(tasks) ^ set(metrics))}"
)

for task in tasks:
    print(f"{task}: {metrics[task]}")

['mcc']
['acc']
['acc']
['acc', 'f1', 'acc_and_f1']
['acc']
['acc', 'f1', 'acc_and_f1']
['acc']
['acc']
['pearson', 'spearmanr', 'corr']
['acc']


In [105]:
def _read_nonzero_scores(path):
    """Return the floats in the 2nd tab-separated column of ``path``, skipping blank lines and zeros."""
    with open(path, "r") as fh:  # `with` guarantees the handle is closed
        rows = [line for line in fh.read().splitlines() if line.strip()]
    scores = (float(row.split('\t')[1]) for row in rows)
    # Zero scores are discarded (presumably failed/unfinished runs — confirm).
    return [s for s in scores if s != 0]


def _near_best(scores, threshold):
    """Sort ``scores`` descending; keep the best plus every score > ``threshold * best``."""
    ranked = sorted(scores, reverse=True)
    best = ranked[0]
    return [best] + [s for s in ranked[1:] if s > threshold * best]


def find_best(exp, lrs=range(1, 6), results_dir="../exp_results/baseline",
              threshold=0.9, task_list=None, task_metrics=None, verbose=True):
    """Pick, for each task and metric, the learning rate with the best aggregated score.

    For each learning-rate multiplier ``lr`` in ``lrs`` this reads
    ``{results_dir}/{exp}/{task}/{lr}e-5_{metric}.txt`` (tab-separated, score in
    the second column). It drops zero scores, keeps the best score plus every
    other score strictly greater than ``threshold * best``, and averages them.
    A learning rate with no non-zero scores contributes 0.

    Args:
        exp: experiment sub-directory name (e.g. 'bert-base').
        lrs: learning-rate multipliers to scan; the file name uses ``{lr}e-5``.
        results_dir: root directory holding the per-experiment results.
        threshold: fraction of the best score a run must exceed to be averaged in.
            Assumes positive scores; with a negative best fewer runs pass.
        task_list: tasks to scan; defaults to the notebook-level ``tasks``.
        task_metrics: task -> list of metric names; defaults to ``metrics``.
        verbose: print the kept candidates and their spread per learning rate.

    Returns:
        dict mapping task -> {metric: "<best lr> - <mean score>"}. Ties go to
        the first lr in ``lrs`` (``np.argmax`` semantics).

    Raises:
        FileNotFoundError: if an expected result file is missing.
        ValueError: if ``lrs`` is empty.
    """
    if task_list is None:
        task_list = tasks
    if task_metrics is None:
        task_metrics = metrics
    lrs = list(lrs)  # indexable, so argmax maps back to the actual lr

    results = {}
    for task in task_list:
        task_results = {}
        for metric in task_metrics[task]:
            per_lr = []
            for lr in lrs:
                path = f"{results_dir}/{exp}/{task}/{lr}e-5_{metric}.txt"
                reported = _read_nonzero_scores(path)
                if reported:
                    candidates = _near_best(reported, threshold)
                    per_lr.append(np.mean(candidates))
                    if verbose:
                        print(task, metric, lr, candidates)
                        # spread between best and worst kept run
                        print_green(candidates[0] - candidates[-1])
                else:
                    per_lr.append(0)

            best_idx = int(np.argmax(per_lr))
            task_results[metric] = f"{lrs[best_idx]} - {np.max(per_lr)}"
        results[task] = task_results

    return results

In [106]:
exp = 'bert-base'

pprint(find_best(exp))

# Published scores for this model, as fractions (not percentages).
# NOTE(review): these appear to be the BERT-base GLUE test-set numbers from
# Devlin et al. (2019), Table 1 — confirm. The WNLI value is marked temporary.
reported_in_paper = dict([
    ("CoLA", 0.521),
    ("MNLI", 0.846),
    ("MNLI-MM", 0.834),
    ("MRPC", 0.889),
    ("QNLI", 0.905),
    ("QQP", 0.712),
    ("RTE", 0.664),
    ("SST-2", 0.935),
    ("STS-B", 0.858),
    ("WNLI", 0.651),  # temp
])

CoLA mcc 1 [0.5598962230809679, 0.557314902649449]
[92m
0.0025813204315188187
[0m
CoLA mcc 2 [0.572700939299986, 0.5574138665741355]
[92m
0.015287072725850503
[0m
CoLA mcc 3 [0.5864941797290588, 0.5627810283916928]
[92m
0.023713151337366067
[0m
CoLA mcc 4 [0.5782027446370238, 0.5689445008605172]
[92m
0.009258243776506636
[0m
CoLA mcc 5 [0.5558086597524818, 0.5327292010480984]
[92m
0.023079458704383438
[0m
MNLI acc 1 [0.8428935303107489, 0.8399388690779419]
[92m
0.0029546612328069655
[0m
MNLI acc 2 [0.8429954151808456, 0.8422822210901681]
[92m
0.0007131940906774936
[0m
MNLI acc 3 [0.8456444218033622, 0.838920020376974]
[92m
0.0067244014263881935
[0m
MNLI acc 4 [0.8402445236882323, 0.8358634742740703]
[92m
0.004381049414162064
[0m
MNLI acc 5 [0.8332144676515537, 0.83015792154865]
[92m
0.0030565461029037344
[0m
MNLI-MM acc 1 [0.8475386493083807, 0.842860048820179]
[92m
0.0046786004882017895
[0m
MNLI-MM acc 2 [0.8513018714401953, 0.8477420667209113]
[92m
0.0035598047

In [107]:
exp = 'bert-large'

pprint(find_best(exp))

# Published GLUE test-set scores for BERT-large (fractions, not percentages), from
# Devlin et al. (2019), "BERT: Pre-training of Deep Bidirectional Transformers for
# Language Understanding", Table 1. These must NOT be the BERT-base numbers used
# for the 'bert-base' experiment.
reported_in_paper = {
    "CoLA": 0.605,
    "MNLI": 0.867,
    "MNLI-MM": 0.859,
    "MRPC": 0.893,
    "QNLI": 0.927,
    "QQP": 0.721,
    "RTE": 0.701,
    "SST-2": 0.949,
    "STS-B": 0.865,
    "WNLI": 0.651,  # temp placeholder, same value as the BERT-base cell — confirm
}

CoLA mcc 1 [0.6034053904179437, 0.598349211554662, 0.5934998758283692, 0.5933815828411364, 0.5858661515147512]
[92m
0.017539238903192533
[0m
CoLA mcc 2 [0.6314267706887338, 0.6196504222850292, 0.6159937744560633, 0.6132520976952389, 0.6082287537654698]
[92m
0.023198016923263953
[0m
CoLA mcc 3 [0.6306425398187112, 0.6234124534066398, 0.6109671005868332, 0.6039307167689609]
[92m
0.026711823049750283
[0m
CoLA mcc 4 [0.6240838974095094, 0.5943397045380183, 0.5936105573332983, 0.5930546484314362]
[92m
0.031029248978073243
[0m
CoLA mcc 5 [0.5486883812731442]
[92m
0.0
[0m
MNLI acc 1 [0.8653082017320428, 0.8639836984207845, 0.8635761589403973, 0.8633723892002038, 0.8615384615384616]
[92m
0.003769740193581228
[0m
MNLI acc 2 [0.8624554253693326, 0.8621497707590423, 0.8621497707590423, 0.8603158430973, 0.8488028527763627]
[92m
0.013652572592969925
[0m
MNLI acc 3 [0.8606214977075904, 0.8577687213448802]
[92m
0.0028527763627101965
[0m
MNLI acc 4 [0.8494141619969434, 0.83993886907794