In [1]:
%%javascript
// Override OutputArea's _should_scroll hook so that it always answers "no".
// NOTE(review): presumably the notebook front end consults this hook to decide
// whether long cell output is collapsed into a scroll box — confirm.
IPython.OutputArea.prototype._should_scroll = (_lines) => false;


<IPython.core.display.Javascript object>

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams

In [3]:
%matplotlib inline

In [4]:
# Escape sequences used to colour / emphasise text printed to the notebook
# output. 'END' is the sequence the helpers below print after a styled message.
TEXT_COLOUR = dict(
    PURPLE='\033[95m',
    CYAN='\033[96m',
    DARKCYAN='\033[36m',
    BLUE='\033[94m',
    GREEN='\033[92m',
    YELLOW='\033[93m',
    RED='\033[91m',
    BOLD='\033[1m',
    UNDERLINE='\033[4m',
    END='\033[0m',
)

def print_bold(*msgs):
    """Print ``msgs`` between the BOLD and END escape codes.

    Each escape code is emitted by its own ``print`` call, so it sits on a
    separate output line before and after the message line.
    """
    for chunk in ((TEXT_COLOUR['BOLD'],), msgs, (TEXT_COLOUR['END'],)):
        print(*chunk)

def print_green(*msgs):
    """Print ``msgs`` between the GREEN and END escape codes.

    Each escape code is emitted by its own ``print`` call, so it sits on a
    separate output line before and after the message line.
    """
    for chunk in ((TEXT_COLOUR['GREEN'],), msgs, (TEXT_COLOUR['END'],)):
        print(*chunk)

def print_error(*msgs):
    """Print ``msgs`` between the RED and END escape codes.

    Each escape code is emitted by its own ``print`` call, so it sits on a
    separate output line before and after the message line.
    """
    for chunk in ((TEXT_COLOUR['RED'],), msgs, (TEXT_COLOUR['END'],)):
        print(*chunk)

def wrap_green(msg):
    """Return the string ``msg`` enclosed in the GREEN and END escape codes."""
    start = TEXT_COLOUR['GREEN']
    stop = TEXT_COLOUR['END']
    return start + msg + stop

def wrap_red(msg):
    """Return the string ``msg`` enclosed in the RED and END escape codes."""
    start = TEXT_COLOUR['RED']
    stop = TEXT_COLOUR['END']
    return start + msg + stop

def up_down_str(val):
    """Render ``val`` as text coloured according to its sign.

    Positive values are passed through ``wrap_green`` and negative values
    through ``wrap_red``; anything else (e.g. zero) comes back as plain
    ``str(val)``.
    """
    text = str(val)
    if val > 0:
        return wrap_green(text)
    if val < 0:
        return wrap_red(text)
    return text

In [5]:
# Experiment name: selects the ../exp_results/<exp>/ input directory and the
# images/<exp>/ output directory, and is shown in plot titles.
exp = 'bert-large'
# Upper bound of the layer range: result files and x-axis labels cover
# indices num_layers/2 .. num_layers-1.
# NOTE(review): presumably the number of encoder layers of the model — confirm.
num_layers = 24

In [6]:
# Tasks processed by the loops below, in processing order.
tasks = [
    "CoLA",
    "SST-2",
    "MRPC",
    "STS-B",
    "QQP",
    "MNLI",
    "MNLI-MM",
    "QNLI",
    "RTE",
]

# Metric names evaluated per task; each name is part of the result file name
# ("base-<metric>.txt", "<layers>-<metric>.txt").
metrics = {
    "CoLA": ["mcc"],
    "MNLI": ["acc"],
    "MNLI-MM": ["acc"],
    "MRPC": ["f1"],
    "QNLI": ["acc"],
    "QQP": ["f1"],
    "RTE": ["acc"],
    "SST-2": ["acc"],
    "STS-B": ["spearmanr"],
    "WNLI": ["acc"],  # temp
}

# Reference score per task. Every entry is currently the placeholder 0.0,
# which draw_graph treats as "no reference line to draw".
reported_in_paper = {task_name: 0.0 for task_name in metrics}

In [7]:

def get_average_val(lines):
    """Average the best scores found in the lines of a results file.

    Every non-blank line is echoed to stdout and must contain at least two
    tab-separated fields; the second field is parsed as the score.

    Selection rule:
      * scores equal to 0 are discarded
        (NOTE(review): presumably failed or unfinished runs — confirm);
      * the highest remaining score is always kept;
      * every other score is kept only if it is strictly greater than
        0.9 times the highest score.

    Args:
        lines: iterable of strings, e.g. ``f.read().splitlines()``.

    Returns:
        The mean of the kept scores (as returned by ``np.mean``), or the
        integer ``0`` when there is no non-zero score.

    Raises:
        IndexError: a non-blank line has no tab-separated second field.
        ValueError: the second field is not a valid float.
    """
    reported = []
    for line in lines:
        # Blank lines (e.g. a trailing empty line in the file) carry no
        # score; skip them instead of failing on the split below.
        if not line.strip():
            continue
        print('\t', line)
        val = float(line.split('\t')[1])
        if val != 0:
            reported.append(val)

    if not reported:
        return 0

    # Descending order puts the best score first; it anchors the 90% cut-off.
    reported.sort(reverse=True)
    best = reported[0]
    candidates = [best] + [score for score in reported[1:] if score > 0.9 * best]
    return np.mean(candidates)


In [8]:
# Build results[task] for every task. For each metric of the task it holds:
#   'base-<metric>' : average from base-<metric>.txt
#   '<metric>'      : list of averages, ordered from the file with the longest
#                     layer prefix ("23_22_..._<num_layers/2>") down to "23",
#                     with the no_layer file last.
# Files are opened with `with` so every handle is closed as soon as it is read.
results = {}

for task in tasks:
    task_results = {}
    task_metrics = metrics[task]
    for metric in task_metrics:

        # base metrics
        print(f"../exp_results/{exp}/{task}/base-{metric}.txt")
        with open(f"../exp_results/{exp}/{task}/base-{metric}.txt", "r") as f:
            lines = f.read().splitlines()
        task_results[f'base-{metric}'] = get_average_val(lines)

        # no layer metrics

        fine_tuning_metrics = []
        with open(f"../exp_results/{exp}/{task}/no_layer-{metric}.txt", "r") as f:
            lines = f.read().splitlines()
        fine_tuning_metrics.append(get_average_val(lines))

        # fine-tuned metrics

        # The file prefix grows by one layer index per iteration:
        # "23", "23_22", "23_22_21", ...
        log_file_prefix = ''
        for i in reversed(range(int(num_layers/2), num_layers)):
            print(i)
            log_file_prefix += str(i)
            with open(f"../exp_results/{exp}/{task}/{log_file_prefix}-{metric}.txt", "r") as f:
                lines = f.read().splitlines()
            fine_tuning_metrics.append(get_average_val(lines))

            log_file_prefix += '_'

        # Collected as [no_layer, "23", "23_22", ...]; stored reversed.
        task_results[f'{metric}'] = list(reversed(fine_tuning_metrics))

    results[task] = task_results

../exp_results/bert-large/CoLA/base-mcc.txt
	 1	0.6132520976952389
	 2	0.6082287537654698
	 3	0.6196504222850292
	 4	0.6159937744560633
	 5	0.6314267706887338
	 6	0.6307435718195655
	 7	0.6063017263508896
	 8	0.6132520976952389
	 9	0.6236307572313473
	 10	0.6239164540479953
	 1	0.21560140866635025
	 2	0.2439588933160818
	 3	0.21725141763919706
	 4	0.23501301708307232
	 5	0.2439588933160818
	 6	0.24316892062216255
	 7	0.23951602463962363
	 8	0.22581591101652074
	 9	0.23415050823672504
	 10	0.2643569071590391
23
	 1	0.3853198145814999
	 2	0.38910318028075963
	 3	0.3729828208660848
	 4	0.3825675441068272
	 5	0.39141351655952766
	 6	0.40756536011513605
	 7	0.38853046110587136
	 8	0.3943819600110244
	 9	0.4004653712114985
	 10	0.39733191583974115
22
	 1	0.4296458950606047
	 2	0.4100506577314648
	 3	0.40923059850264826
	 4	0.42097986053121045
	 5	0.42393201453251056
	 6	0.4297412045484628
	 7	0.4325679591343375
	 8	0.41855144278510453
	 9	0.41217715910488273
	 10	0.42696977338948733
21
	 1	0

	 1	0.6931407942238267
	 2	0.7292418772563177
	 3	0.7256317689530686
	 5	0.6859205776173285
	 4	0.7220216606498195
	 6	0.740072202166065
	 7	0.7256317689530686
	 8	0.7328519855595668
	 9	0.7364620938628159
	 10	0.7364620938628159
	 1	0.5523465703971119
	 2	0.5667870036101083
	 3	0.555956678700361
	 4	0.5595667870036101
	 5	0.5740072202166066
	 6	0.5740072202166066
	 7	0.5667870036101083
	 8	0.5595667870036101
	 9	0.5667870036101083
	 10	0.5667870036101083
23
	 1	0.5812274368231047
	 2	0.5740072202166066
	 3	0.5703971119133574
	 4	0.5848375451263538
	 5	0.5667870036101083
	 6	0.5848375451263538
	 7	0.5956678700361011
	 8	0.5776173285198556
	 9	0.5848375451263538
	 10	0.5848375451263538
22
	 1	0.6101083032490975
	 2	0.6137184115523465
	 3	0.6209386281588448
	 4	0.5992779783393501
	 5	0.5956678700361011
	 6	0.628158844765343
	 7	0.5992779783393501
	 8	0.5956678700361011
	 9	0.6173285198555957
	 10	0.6173285198555957
21
	 1	0.6209386281588448
	 2	0.628158844765343
	 3	0.6137184115523465
	 

In [9]:
# X-axis labels: the layer indices num_layers/2 .. num_layers-1 as strings,
# followed by the label "none".
x_axis = [str(layer) for layer in range(int(num_layers/2), num_layers)]
x_axis.append("none")

In [10]:
def draw_graph(task, y_label, paper, base, reported):
    """Plot ``reported`` against the module-level ``x_axis`` and save the figure.

    Args:
        task: task name; used in the plot title and the output file name.
        y_label: label for the y axis.
        paper: reference value drawn as a red dashed horizontal line. A value
            of ``0.0`` means "no reference": the line is omitted and the value
            is ignored when computing the y limits.
        base: baseline value drawn as a green dashed horizontal line and
            shown (rounded to 4 places) in the title.
        reported: y values to plot against ``x_axis``.

    The y limits span ``reported`` and all drawn reference values, padded by
    20% of the range of ``reported``. The figure is written in PNG format to
    ``images/{exp}/{task}`` and then shown.
    """
    plt.figure(figsize=(10,6))
    plt.plot(x_axis, reported)

    plt.xlabel("layers")
    plt.ylabel(y_label)

    # Horizontal reference lines as (value, colour), in drawing order.
    guides = [(base, 'green')]
    if paper != 0.0:
        guides.append((paper, 'red'))
    levels = [level for level, _ in guides]

    gap = max(reported) - min(reported)
    margin = gap*0.2
    top = max(max(reported), *levels) + margin
    bottom = min(min(reported), *levels) - margin

    plt.ylim(bottom, top)

    for level, colour in guides:
        plt.axhline(y=level, linestyle='--', c=colour)

    plt.title(f'{exp}-{task} ({round(base,4)})')
    plt.savefig(f'images/{exp}/{task}', format='png', bbox_inches='tight')
    plt.show()

In [11]:
# One summary line per task/metric, all values as percentages rounded to two
# places: the base score, the first entry of the score list (longest layer
# prefix) and its last entry (the no_layer file).
for task in tasks:
    task_results = results[task]
    task_metrics = metrics[task]
    for metric in task_metrics:
        reported = task_results[metric]
        base = task_results[f'base-{metric}']
        percentages = [round(score * 100, 2) for score in (base, reported[0], reported[-1])]
        print_bold(task, metric, *percentages)

[1m
CoLA mcc 61.86 59.12 24.7
[0m
[1m
SST-2 acc 93.41 93.33 87.79
[0m
[1m
MRPC f1 90.34 88.9 81.34
[0m
[1m
STS-B spearmanr 89.77 89.03 71.57
[0m
[1m
QQP f1 88.33 87.48 70.81
[0m
[1m
MNLI acc 86.35 86.25 58.87
[0m
[1m
MNLI-MM acc 86.16 85.93 60.23
[0m
[1m
QNLI acc 92.23 92.25 74.84
[0m
[1m
RTE acc 72.27 71.6 56.43
[0m


In [12]:
import copy 

# Filled by the loop below with one index per task/metric: the first position
# in the reversed score list whose score exceeds 90% / 95% of the base score.
layer_90 = []
layer_95 = []

# Thresholds expressed as a fraction of the base score.
threshold_90 = 0.9
threshold_95 = 0.95
# NOTE(review): this reverses the shared x_axis list in place. Re-running this
# cell flips it back, and any later use of x_axis (e.g. draw_graph) sees the
# reversed order — confirm this is intended.
x_axis.reverse()

# For every task/metric, walk the scores in reversed order and record the
# first index whose score exceeds 90% / 95% of the base score. When a
# threshold is never exceeded, say so and fall back to the last index.
for task in tasks:
    task_results = results[task]
    task_metrics = metrics[task]
    for metric in task_metrics:
        base = task_results[f'base-{metric}']
        reported = copy.deepcopy(task_results[metric])[::-1]

        # True while the corresponding threshold has not been exceeded yet.
        flag_90 = True
        flag_95 = True

        for ind, val in enumerate(reported):
            ratio = val/base

            if flag_90 and ratio > threshold_90:
                layer_90.append(ind)
                flag_90 = False

            if flag_95 and ratio > threshold_95:
                layer_95.append(ind)
                flag_95 = False

        if flag_90:
            print(task, "Fails to achieve 90% threshold", reported[-1]/base)
            layer_90.append(len(reported)-1)

        if flag_95:
            print(task, "Fails to achieve 95% threshold", reported[-1]/base)
            layer_95.append(len(reported)-1)


            
print(x_axis)

# For each threshold: show the per-task indices, take the largest one, and
# print it with its x_axis label and the value (1 - index/num_layers) as a
# percentage.
print(layer_90)
min_layer_ind_90 = max(layer_90)
label_90 = x_axis[min_layer_ind_90]
share_90 = round((1-(min_layer_ind_90/num_layers)) * 100, 2)
print("layer_90 ", min_layer_ind_90, 'layer:', label_90, share_90, '%')

print(layer_95)
min_layer_ind_95 = max(layer_95)
label_95 = x_axis[min_layer_ind_95]
share_95 = round((1-(min_layer_ind_95/num_layers)) * 100, 2)
print("layer_95 ", min_layer_ind_95, 'layer:', label_95, share_95, '%')
            
    
# Per task/metric: print the base score, then the score found at the selected
# 90% / 95% indices of the reversed score list, together with that score as a
# percentage of the base score.
for task in tasks:
    task_results = results[task]
    task_metrics = metrics[task]
    for metric in task_metrics:
        base = task_results[f'base-{metric}']
        reported = copy.deepcopy(task_results[metric])[::-1]
        print_bold(task, base)
        for tag, picked in (('\t90', min_layer_ind_90), ('\t95', min_layer_ind_95)):
            print(tag, reported[picked], round(reported[picked]/base * 100, 2))

['none', '23', '22', '21', '20', '19', '18', '17', '16', '15', '14', '13', '12']
[8, 0, 0, 1, 1, 2, 2, 1, 6]
layer_90  8 layer: 16 66.67 %
[11, 1, 8, 2, 3, 5, 5, 4, 10]
layer_95  11 layer: 13 54.17 %
[1m
CoLA 0.6186396426035572
[0m
	90 0.560072055512793 90.53
	95 0.589130068619958 95.23
[1m
SST-2 0.9340596330275229
[0m
	90 0.9315366972477065 99.73
	95 0.9332568807339451 99.91
[1m
MRPC 0.9034378599451618
[0m
	90 0.8596944237203612 95.16
	95 0.8831302353091584 97.75
[1m
STS-B 0.8977256044874211
[0m
	90 0.8858842674792656 98.68
	95 0.8906436027892216 99.21
[1m
QQP 0.8832536038016956
[0m
	90 0.8665966930987468 98.11
	95 0.8730054673451821 98.84
[1m
MNLI 0.8635082356936662
[0m
	90 0.8545491594498216 98.96
	95 0.8613007301749024 99.74
[1m
MNLI-MM 0.8615574993219419
[0m
	90 0.8506712774613507 98.74
	95 0.8579976946026581 99.59
[1m
QNLI 0.9222954420647994
[0m
	90 0.918488010250778 99.59
	95 0.9232106900970163 100.1
[1m
RTE 0.7227436823104694
[0m
	90 0.683754512635379 94.61
	9