In [1]:
from edge_probing_utils import (
    JiantDatasetSingleSpan,
    JiantDatasetTwoSpan
    )
import edge_probing as ep
import torch
import torch.nn as nn
import torch.utils.data as data
from transformers import AutoModel, AutoTokenizer


**Setup:**

In [2]:
tasks = [
    #"ner", 
    #"semeval",
    "coref",
    #"sup-squad",
    #"ques",
    #"sup-babi",
    #"sup-hotpot",
    ]

task_types = {
    "ner": "single_span", 
    "semeval": "single_span",
    "coref": "two_span",
    "sup-squad": "two_span",
    "ques": "single_span",
    "sup-babi": "two_span",
    "sup-hotpot": "two_span",
    }

models = [
    "bert-base-uncased", 
    #"csarron/bert-base-uncased-squad-v1",
    ]

task_labels_to_ids = {
    "ner": {'ORDINAL': 0, 'DATE': 1, 'PERSON': 2, 'LOC': 3, 'GPE': 4, 'QUANTITY': 5, 'ORG': 6, 'WORK_OF_ART': 7, 'CARDINAL': 8, 'TIME': 9, 'MONEY': 10, 'LANGUAGE': 11, 'NORP': 12, 'PERCENT': 13, 'EVENT': 14, 'LAW': 15, 'FAC': 16, 'PRODUCT': 17},
    "coref": {"0": 0, "1": 1},
    "semeval": {'Component-Whole(e2,e1)': 0, 'Other': 1, 'Instrument-Agency(e2,e1)': 2, 'Member-Collection(e1,e2)': 3, 'Entity-Destination(e1,e2)': 4, 'Content-Container(e1,e2)': 5, 'Message-Topic(e1,e2)': 6, 'Cause-Effect(e2,e1)': 7, 'Product-Producer(e2,e1)': 8, 'Member-Collection(e2,e1)': 9, 'Entity-Origin(e1,e2)': 10, 'Cause-Effect(e1,e2)': 11, 'Component-Whole(e1,e2)': 12, 'Message-Topic(e2,e1)': 13, 'Product-Producer(e1,e2)': 14, 'Entity-Origin(e2,e1)': 15, 'Content-Container(e2,e1)': 16, 'Instrument-Agency(e1,e2)': 17, 'Entity-Destination(e2,e1)': 18},
    "sup-squad": {"0": 0, "1": 1},
    "ques": {'LOC:other': 0, 'DESC:desc': 1, 'DESC:def': 2, 'DESC:manner': 3, 'ENTY:sport': 4, 'ENTY:termeq': 5, 'HUM:ind': 6, 'NUM:count': 7, 'DESC:reason': 8, 'LOC:country': 9, 'HUM:desc': 10, 'ENTY:animal': 11, 'ENTY:other': 12, 'LOC:city': 13, 'ENTY:cremat': 14, 'NUM:perc': 15, 'NUM:money': 16, 'NUM:date': 17, 'ENTY:dismed': 18, 'LOC:state': 19, 'NUM:speed': 20, 'HUM:gr': 21, 'NUM:dist': 22, 'ENTY:food': 23, 'ABBR:abb': 24, 'ENTY:product': 25, 'HUM:title': 26, 'NUM:weight': 27, 'ABBR:exp': 28, 'ENTY:veh': 29, 'NUM:period': 30, 'ENTY:religion': 31, 'ENTY:letter': 32, 'ENTY:color': 33, 'ENTY:body': 34, 'ENTY:event': 35, 'ENTY:substance': 36, 'ENTY:instru': 37, 'ENTY:plant': 38, 'ENTY:symbol': 39, 'NUM:other': 40, 'LOC:mount': 41, 'NUM:temp': 42, 'ENTY:techmeth': 43, 'NUM:code': 44, 'ENTY:word': 45, 'ENTY:lang': 46, 'NUM:volsize': 47, 'NUM:ord': 48, 'ENTY:currency': 49},
    "sup-babi": {"0": 0, "1": 1},
    }
    
model_to_linestyle = {
    "bert-base-uncased": ":g", 
    "csarron/bert-base-uncased-squad-v1": "-y",
}

# Pick a dataset size: small, medium, big.
# With the way training works, a smaller size will only improve the speed of tokenizing,
# training speed will be the same for all sizes.
size = "medium"

In [3]:
import os
import json
import matplotlib.pyplot as plt

loss_function = nn.BCELoss()
batch_size = 32
num_layers = range(1,13,2)
num_workers = 0

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")

# Disable warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

os.makedirs(f"../results/", exist_ok=True)
if os.path.isfile(f'../results/{size}.json'):
  with open(f'../results/{size}.json', 'r') as f:
    results = json.load(f)
else:
  results = {}

for model in models:
    tokenizer = AutoTokenizer.from_pretrained(model)
    for task in tasks:
        task_results = results.setdefault(task, {})
        labels_to_ids = task_labels_to_ids[task]
        train_data = ep.tokenize_jiant_dataset(
            tokenizer,
            *(ep.read_jiant_dataset(f"../data/{task}/{size}/train.jsonl")),
            labels_to_ids,
            max_seq_length=384
            )
        val_data = ep.tokenize_jiant_dataset(
            tokenizer,
            *(ep.read_jiant_dataset(f"../data/{task}/{size}/val.jsonl")),
            labels_to_ids,
            max_seq_length=384
            )
        if task_types[task] == "single_span":
            train_data = JiantDatasetSingleSpan(train_data)
            val_data = JiantDatasetSingleSpan(val_data)
        elif task_types[task] == "two_span":
            train_data = JiantDatasetTwoSpan(train_data)
            val_data = JiantDatasetTwoSpan(val_data)
        train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
        val_loader = data.DataLoader(val_data, batch_size=batch_size, pin_memory=True, num_workers=num_workers)
        results[task][model] = ep.probing(ep.ProbeConfig(
                train_loader,
                val_loader,
                val_loader,
                model,
                num_layers,
                loss_function,
                labels_to_ids,
                task_types[task],
                lr=0.0001,
                max_evals=30,
                eval_interval=100,
                dev=device,
                ))

with open(f'../results/{size}.json', 'w') as f:
    json.dump(results, f)

Reading ../data/coref/medium/train.jsonl


HBox(children=(FloatProgress(value=0.0, max=2647.0), HTML(value='')))


Tokenizing


HBox(children=(FloatProgress(value=0.0, max=5007.0), HTML(value='')))


Reading ../data/coref/medium/val.jsonl


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


Tokenizing


HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))


Probing model bert-base-uncased
Probing layer 1 of 11
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertEdgeProbingTwoSpan: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'bert.encoder.layer.1.attention.self.query.weight', 'bert.encoder.layer.1.attention.self.query.bias', 'bert.encoder.layer.1.attention.self.key.weight', 'bert.encoder.layer.1.attention.self.key.bias', 'bert.encoder.layer.1.attention.self.value.weight', 'bert.encoder.layer.1.attention.self.value.bias', 'bert.encoder.layer.1.attention.output.dense.weight', 'bert.encoder.layer.1.attention.output.dense.bias', 'bert.encoder.layer.1.intermediate.dense.weight', 'bert.encoder.layer.1.interm

HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5586018237200651



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5574835159561851



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5554650344631888



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5538204203952443



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5527683144265955



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5497362369840796



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.548951734196056



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5466697920452465



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5472771769220178



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5584562447938052



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5720803873105482



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.58539494601163



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6000145240263506
No improvement for 5 epochs, halving the learning rate to 5e-05



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6064392545006492



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6167203919454054



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6233113516460765



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6257196231321855



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6312588182362643
No improvement for 5 epochs, halving the learning rate to 2.5e-05



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6345860307866876



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.637100569226525



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6366344473578713



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6402349174022675



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6437926184047352
No improvement for 5 epochs, halving the learning rate to 1.25e-05



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6446441411972046



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6441342261704531



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6464507850733671



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6469613394953988



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6473534459417517
No improvement for 5 epochs, halving the learning rate to 6.25e-06



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.647428732026707



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.6486047045751051

Training is finished
Testing the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Test loss: 0.6486047045751051, accuracy: 0.6770833351395347, f1_score: 0.5640123397281402
Probing layer 3 of 11
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertEdgeProbingTwoSpan: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'bert.encoder.layer.3.attention.self.query.weight', 'bert.encoder.layer.3.attention.self.query.bias', 'bert.encoder.layer.3.attention.self.key.weight', 'bert.encoder.layer.3.attention.self.key.bias', 'bert.encoder.layer.3.attention.self.value.weight', 'bert.encoder.layer.3.attention.self.value.bias', 'bert.encoder.layer.3.attention.output.dense.weight', 'bert.encoder.layer.3.attention.output.dense.bias', 'bert.encoder.layer.

HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5558960221030496



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5525514754382047



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5506249395283785



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5494681190360676



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5455421209335327



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5390696661038832



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5299753492528742



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5295200700109656



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5160476782105186



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5164862220937555



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5126265287399292



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5302747704766013



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5047715252096002



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5207321779294447



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5237474983388727



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5166623754934832



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5214354802261699



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5306058309294961
No improvement for 5 epochs, halving the learning rate to 5e-05



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5294611074707725



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5362413986162706



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5389026430520144



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5408958440477197



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5419205616820942
No improvement for 5 epochs, halving the learning rate to 2.5e-05



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5441968684846704



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5434718050739982



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5371848263523795



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5434406508098949



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5384325222535566
No improvement for 5 epochs, halving the learning rate to 1.25e-05



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5395391475070607



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5397242470221086

Training is finished
Testing the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Test loss: 0.5397242470221086, accuracy: 0.7428030317479913, f1_score: 0.6583902485296909
Probing layer 5 of 11
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertEdgeProbingTwoSpan: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.

HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


Loss: 0.5569149066101421



HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

Evaluating the model


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))




KeyboardInterrupt: 

In [6]:
with open(f'../results/{size}.json', 'r') as f:
    results = json.load(f)
for task, task_results in results.items():
    for model, _ in task_results.items():
        ep.plot_task(results, task, model, model_to_linestyle[model], num_layers)
    plt.show()

KeyError: 1