In [25]:
%%javascript
// Replace the output-area `_should_scroll` hook with one that always answers
// false, so long cell outputs are (presumably) never turned into a scroll box.
(function (outputAreaProto) {
    outputAreaProto._should_scroll = function (lines) {
        // `lines` is deliberately ignored: the answer is always "don't scroll".
        return false;
    };
})(IPython.OutputArea.prototype);


<IPython.core.display.Javascript object>

In [26]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams

In [27]:
%matplotlib inline

In [28]:
# ANSI escape sequences used to colour / emphasise text in the notebook output.
TEXT_COLOUR = {
    'PURPLE': '\033[95m',
    'CYAN': '\033[96m',
    'DARKCYAN': '\033[36m',
    'BLUE': '\033[94m',
    'GREEN': '\033[92m',
    'YELLOW': '\033[93m',
    'RED': '\033[91m',
    'BOLD': '\033[1m',
    'UNDERLINE': '\033[4m',
    'END': '\033[0m'
}


def _print_coloured(colour, *msgs):
    """Print ``msgs`` exactly as ``print(*msgs)`` would, wrapped in TEXT_COLOUR[colour].

    The start and reset codes are written on the same line as the message, so the
    result is a single output line instead of the message framed by two extra lines.
    """
    print(TEXT_COLOUR[colour], end='')
    print(*msgs, end='')
    print(TEXT_COLOUR['END'])


def print_bold(*msgs):
    """Print ``msgs`` (space-separated, like print) in bold."""
    _print_coloured('BOLD', *msgs)


def print_green(*msgs):
    """Print ``msgs`` (space-separated, like print) in green."""
    _print_coloured('GREEN', *msgs)


def print_error(*msgs):
    """Print ``msgs`` (space-separated, like print) in red."""
    _print_coloured('RED', *msgs)


def _wrap(colour, msg):
    """Return ``str(msg)`` surrounded by TEXT_COLOUR[colour] and the reset code."""
    return TEXT_COLOUR[colour] + str(msg) + TEXT_COLOUR['END']


def wrap_green(msg):
    """Return ``msg`` as a green string (non-str values are converted with str)."""
    return _wrap('GREEN', msg)


def wrap_red(msg):
    """Return ``msg`` as a red string (non-str values are converted with str)."""
    return _wrap('RED', msg)


def up_down_str(val):
    """Return ``str(val)``: green if val > 0, red if val < 0, uncoloured otherwise (0, NaN)."""
    msg = str(val)
    if val > 0:
        return wrap_green(msg)
    if val < 0:
        return wrap_red(msg)
    return msg

In [29]:
# Experiment being analysed: results are read from ../exp_results/<exp>/ and
# figures are written under images/<exp>/.
exp = "xlnet-base"
# Number of layers in the model; runs cover layers num_layers//2 .. num_layers-1.
num_layers = 12

In [30]:
# Tasks whose results are loaded and summarised below, in display order.
tasks = [
    "CoLA", "SST-2", "MRPC", "STS-B", "QQP",
    "MNLI", "MNLI-MM", "QNLI", "RTE",
]

# Metric(s) read per task; each name selects the result file
# ../exp_results/<exp>/<task>/<run>-<metric>.txt.
metrics = {
    "CoLA": ["mcc"],
    "MNLI": ["acc"],
    "MNLI-MM": ["acc"],
    "MRPC": ["f1"],
    "QNLI": ["acc"],
    "QQP": ["f1"],
    "RTE": ["acc"],
    "SST-2": ["acc"],
    "STS-B": ["spearmanr"],
    "WNLI": ["acc"],  # temp: WNLI is not listed in `tasks`, so this entry is unused
}

# Reference score per task; 0.0 is a placeholder meaning "no reference value"
# (draw_graph then omits the red reference line).
reported_in_paper = {
    "CoLA": 0.0,
    "MNLI": 0.0,
    "MNLI-MM": 0.0,
    "MRPC": 0.0,
    "QNLI": 0.0,
    "QQP": 0.0,
    "RTE": 0.0,
    "SST-2": 0.0,
    "STS-B": 0.0,
    "WNLI": 0.0,
}

In [31]:

def get_average_val(lines, threshold=0.9, verbose=True):
    """Summarise one result file as the mean of its best runs.

    Each non-blank line is expected to be tab-separated with the metric value in the
    second column (e.g. "1\\t0.54"). Values exactly equal to 0 are skipped. Of the
    remaining values the best one is always kept, and every other value strictly
    greater than ``threshold * best`` is averaged together with it.

    Args:
        lines: lines of a result file, without trailing newlines.
        threshold: fraction of the best score another run must exceed to be averaged in.
        verbose: echo every non-blank line (tab-indented) before parsing it.

    Returns:
        The mean of the selected values as a float, or 0.0 when no non-zero value exists.

    Raises:
        IndexError: a non-blank line has no second tab-separated column.
        ValueError: the second column is not a number.
    """
    reported = []
    for line in lines:
        # Blank lines carry no run; previously they raised IndexError.
        if not line.strip():
            continue
        if verbose:
            print('\t', line)
        val = float(line.split('\t')[1])
        # A score of exactly 0 is treated as a missing/failed run (assumption) and skipped.
        if val != 0:
            reported.append(val)

    if not reported:
        return 0.0

    reported.sort(reverse=True)
    best = reported[0]
    # NOTE(review): when the best score is negative (possible for mcc), threshold * best
    # is larger than best, so only the best run survives - confirm this is intended.
    candidates = [best] + [v for v in reported[1:] if v > threshold * best]
    return float(np.mean(candidates))


In [32]:
def _load_metric(task, run_name, metric):
    """Read ../exp_results/<exp>/<task>/<run_name>-<metric>.txt and reduce it with get_average_val."""
    path = f"../exp_results/{exp}/{task}/{run_name}-{metric}.txt"
    print(path)
    # `with` guarantees the handle is closed; the previous code never closed any file.
    with open(path, "r") as fh:
        lines = fh.read().splitlines()
    return get_average_val(lines)


results = {}

for task in tasks:
    task_results = {}
    for metric in metrics[task]:
        # base metrics
        task_results[f'base-{metric}'] = _load_metric(task, 'base', metric)

        # no layer metrics: collected first, so after the final reverse it is the last entry.
        fine_tuning_metrics = [_load_metric(task, 'no_layer', metric)]

        # fine-tuned metrics: run names list the layers top-down, joined by '_':
        # "11", "11_10", ..., down to layer num_layers // 2 (for num_layers == 12).
        fine_tuned_layers = []
        for layer in reversed(range(num_layers // 2, num_layers)):
            fine_tuned_layers.append(str(layer))
            run_name = '_'.join(fine_tuned_layers)
            fine_tuning_metrics.append(_load_metric(task, run_name, metric))

        # Index 0 = most layers (num_layers//2 .. num_layers-1), last = no layer;
        # this matches the order of the x_axis labels.
        task_results[metric] = list(reversed(fine_tuning_metrics))

    results[task] = task_results

../exp_results/xlnet-base/CoLA/base-mcc.txt
	 1	0.5416659867634687
	 2	0.4731954585829745
	 3	0.5164853025028764
	 4	0.5209209891115821
../exp_results/xlnet-base/CoLA/no_layer-mcc.txt
	 1	0.0
	 2	0.0
	 3	0.0
	 4	0.0
11
	 1	0.0592680243795702
	 2	0.058936921430734465
	 3	0.0863794254719202
	 4	0.08397945416660721
10
	 1	0.1258429813539947
	 2	0.1335596923323897
	 3	0.1258429813539947
	 4	0.05873054109498616
9
	 1	0.16376842992883034
	 2	0.1400546339527412
	 3	0.15781395596714062
	 4	0.14697190112289202
8
	 1	0.20455476991235036
	 2	0.1344399207394031
	 3	0.22337588766250718
	 4	0.1701806891694245
7
	 1	0.23143820141120688
	 2	0.2013844279302642
	 3	0.22275989469950308
	 4	0.21888224820199895
6
	 1	0.2976643739099074
	 2	0.30520979611347604
	 3	0.2636672353215384
	 4	0.2680137267269912
../exp_results/xlnet-base/SST-2/base-acc.txt
	 1	0.9438073394495413
	 2	0.948394495412844
	 3	0.9461009174311926
	 4	0.9438073394495413
../exp_results/xlnet-base/SST-2/no_layer-acc.txt
	 1	0.80389908256880

In [33]:
# Tick labels, in the same order as each task's `reported` list: the lowest layer
# index included in the run (num_layers//2 .. num_layers-1), then "none" for no_layer.
x_axis = [str(layer) for layer in range(num_layers // 2, num_layers)] + ["none"]

In [34]:
def draw_graph(task, y_label, paper, base, reported):
    """Plot a task's scores against the module-level ``x_axis`` labels, save and show it.

    Args:
        task: task name, used in the title and in the output file name.
        y_label: y-axis label (the metric name).
        paper: reference score drawn as a red dashed line; 0.0 or None means "none".
        base: baseline score, drawn as a green dashed line and shown in the title.
        reported: one score per entry of ``x_axis``.

    The figure is saved as PNG to ``images/<exp>/<task>`` (directory created if missing).
    """
    import os

    has_paper = paper is not None and paper != 0.0
    references = [base, paper] if has_paper else [base]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(x_axis, reported)
    ax.set_xlabel("layers")
    ax.set_ylabel(y_label)

    # Pad the y-range by 20% of the spread of `reported` so reference lines stay visible.
    pad = (max(reported) - min(reported)) * 0.2
    top = max(max(reported), *references) + pad
    bottom = min(min(reported), *references) - pad
    if top == bottom:
        # All values identical: widen the range so the y-limits are not a single point.
        margin = abs(top) * 0.05 or 0.05
        top, bottom = top + margin, bottom - margin
    ax.set_ylim(bottom, top)

    ax.axhline(y=base, linestyle='--', c='green')
    if has_paper:
        ax.axhline(y=paper, linestyle='--', c='red')

    ax.set_title(f'{exp}-{task} ({round(base,4)})')
    out_path = f'images/{exp}/{task}'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, format='png', bbox_inches='tight')
    plt.show()
    # Release the figure explicitly; this is called once per task/metric in a loop.
    plt.close(fig)

In [35]:
# Set to True to also plot and save one figure per task/metric with draw_graph.
DRAW_GRAPHS = False

for task in tasks:
    task_results = results[task]
    for metric in metrics[task]:
        reported = task_results[metric]
        base = task_results[f'base-{metric}']
        print_bold(task, metric)
        # Scores are shown multiplied by 100 and rounded to 2 decimals.
        print(f"\tbase : {round(base * 100, 2)}")
        # reported[0]: layers num_layers//2 .. num_layers-1 (half); reported[-1]: no layer.
        print(f"\t50% : {round(reported[0] * 100, 2)}")
        print(f"\tnone : {round(reported[-1] * 100, 2)}")
        if DRAW_GRAPHS:
            draw_graph(task, metric, reported_in_paper[task], base, reported)

[1m
CoLA mcc
[0m
	base : 52.64
	50% : 30.14
	none : 0
[1m
SST-2 acc
[0m
	base : 94.55
	50% : 91.74
	none : 80.22
[1m
MRPC f1
[0m
	base : 90.51
	50% : 91.02
	none : 82.75
[1m
STS-B spearmanr
[0m
	base : 88.99
	50% : 86.07
	none : 66.12
[1m
QQP f1
[0m
	base : 83.94
	50% : 79.78
	none : 65.66
[1m
MNLI acc
[0m
	base : 84.41
	50% : 80.79
	none : 50.34
[1m
MNLI-MM acc
[0m
	base : 84.43
	50% : 81.04
	none : 51.12
[1m
QNLI acc
[0m
	base : 91.29
	50% : 89.8
	none : 67.79
[1m
RTE acc
[0m
	base : 69.95
	50% : 71.12
	none : 58.48
