In [201]:
import sys
sys.path.insert(0,'..')
import json
from abc import abstractmethod, ABC
from collections import OrderedDict
from logging import Logger
from typing import List
from tqdm import tqdm
from transformers import BertTokenizer

from spert import util, models, prediction,  sampling
from spert.entities import Dataset, EntityType, RelationType, Entity, Relation, Document
from spert.opt import spacy
from spert.evaluator import Evaluator
from spert.input_reader import JsonInputReader, BaseInputReader
from spert.loss import SpERTLoss, Loss
from spert.trainer import BaseTrainer
from spert.models import SpERT,SpROB, SpLONG

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from logging import raiseExceptions
import math
import os
from pathlib import Path
from typing import Type
from unittest import case

import torch
from torch.nn import DataParallel
from torch.optim import Optimizer
import transformers
from torch.utils.data import DataLoader
from transformers import AdamW, BertConfig,RobertaConfig,LongformerConfig  
from transformers import BertTokenizer,RobertaTokenizer,LongformerTokenizer
from transformers import BertModel, RobertaModel,LongformerModel
from transformers import BertPreTrainedModel, RobertaPreTrainedModel,LongformerPreTrainedModel


# Log Extraction

We find all run directories under data/log and pair each one with the run directory of the same name under data/save. For every run we then extract its arguments (`args.json`) and the macro/micro scores of the last evaluated epoch (the last row of `eval_valid.csv`). If `eval_valid.csv` is missing, we discard the run altogether. 

The code below assembles a pandas data frame indexed by run ID, holding each run's label, its log directory and its matching save directory.

In [202]:
# Discover training runs: data/log/<label>/<runID>/ and data/save/<label>/<runID>/,
# and build runDF (index = run ID; columns run, label, logPath, savePath).
LOGS = Path('../data/log')
SAVES = Path('../data/save')

# Label directories (one per dataset/model configuration, e.g. scierc_bert_train).
logLabels = {x.name: x for x in LOGS.iterdir() if x.is_dir()}
saveLabels = {x.name: x for x in SAVES.iterdir() if x.is_dir()}

# Run directories inside each label directory, keyed by run ID (the directory name).
logRunLabels = {x.name: L for L, D in logLabels.items() for x in D.iterdir() if x.is_dir()}
logRunPaths = {x.name: x for L, D in logLabels.items() for x in D.iterdir() if x.is_dir()}
# Only keep save directories whose run ID also has a log directory.
saveRunLabels = {x.name: L for L, D in saveLabels.items() for x in D.iterdir()
                 if x.is_dir() and x.name in logRunLabels}
saveRunPaths = {x.name: x for L, D in saveLabels.items() for x in D.iterdir()
                if x.is_dir() and x.name in logRunLabels}

# Build the frame with an explicit run-ID index and join the save paths on that
# key. The log and save trees are listed independently, so their iteration
# order (and length) need not match; aligning by position would attach save
# directories to the wrong runs or fail outright.
runIDs = list(logRunLabels)
runDF = pd.DataFrame({'run': runIDs,
                      'label': [logRunLabels[r] for r in runIDs],
                      'logPath': [logRunPaths[r] for r in runIDs]},
                     index=runIDs)
# Left join: runs without a save directory get NaN in savePath.
runDF = runDF.join(pd.Series(saveRunPaths, name='savePath', dtype=object), how='left')
runDF

Unnamed: 0,run,label,logPath,savePath
2022-03-22_10.17.46.145076,2022-03-22_10.17.46.145076,scierc_bert_train,../data/log/scierc_bert_train/2022-03-22_10.17...,../data/save/scierc_bert_train/2022-03-22_10.1...
2022-03-22_10.46.00.854943,2022-03-22_10.46.00.854943,scierc_bert_train,../data/log/scierc_bert_train/2022-03-22_10.46...,../data/save/scierc_bert_train/2022-03-22_10.4...
2022-03-22_09.51.47.931037,2022-03-22_09.51.47.931037,scierc_bert_train,../data/log/scierc_bert_train/2022-03-22_09.51...,../data/save/scierc_bert_train/2022-03-22_09.5...
2022-03-22_17.10.24.480132,2022-03-22_17.10.24.480132,scierc_rob_train,../data/log/scierc_rob_train/2022-03-22_17.10....,../data/save/scierc_rob_train/2022-03-22_17.10...
2022-03-22_17.58.14.707227,2022-03-22_17.58.14.707227,scierc_rob_train,../data/log/scierc_rob_train/2022-03-22_17.58....,../data/save/scierc_rob_train/2022-03-22_17.58...
...,...,...,...,...
2022-03-22_18.41.01.219059,2022-03-22_18.41.01.219059,scierc_elec_train,../data/log/scierc_elec_train/2022-03-22_18.41...,../data/save/scierc_elec_train/2022-03-22_18.4...
2022-03-22_20.15.48.486074,2022-03-22_20.15.48.486074,scierc_elec_train,../data/log/scierc_elec_train/2022-03-22_20.15...,../data/save/scierc_elec_train/2022-03-22_20.1...
2022-03-23_11.07.38.225160,2022-03-23_11.07.38.225160,scierc_elec_train,../data/log/scierc_elec_train/2022-03-23_11.07...,../data/save/scierc_elec_train/2022-03-23_11.0...
2022-03-22_20.32.00.719556,2022-03-22_20.32.00.719556,scierc_elec_train,../data/log/scierc_elec_train/2022-03-22_20.32...,../data/save/scierc_elec_train/2022-03-22_20.3...


For each of the log directories selected above, we extract the arguments (`args.json`) and `eval_valid.csv`. We discard directories that don't contain `eval_valid.csv`, since its absence indicates an incomplete run. We then create a data frame with all arguments, the run info from the previous data frame and the scores of the last evaluated epoch. There is one row per run ID. 

Some of the columns are shown below.

In [203]:
# Build argDF: one row per completed run, holding the run's arguments
# (args.json), its run info from runDF and the scores of its last evaluated
# epoch (last row of eval_valid.csv).
argList = []
for run in runDF.itertuples():
    evPath = run.logPath.joinpath('eval_valid.csv')
    # No eval_valid.csv means the run never finished an evaluation: discard it
    # before touching any other file of the run.
    if not evPath.exists():
        continue
    try:
        evalDF = pd.read_csv(evPath, sep=';')
    except pd.errors.EmptyDataError:
        # Zero-byte file: incomplete run.
        continue
    if evalDF.empty:
        # Header only, no epoch scores yet (iloc[-1] would raise): incomplete run.
        continue

    with open(run.logPath.joinpath('args.json')) as argsFile:
        D = json.load(argsFile)
    D['label'] = run.label
    D['logPath'] = run.logPath
    D['savePath'] = run.savePath
    D['runID'] = run.run
    # Scores of the last evaluated epoch.
    D.update(evalDF.iloc[-1].to_dict())
    argList.append(D)

# reset_index() keeps the old index as an 'index' column; later cells drop it
# explicitly, so it is preserved here. errors='ignore' tolerates runs whose
# args.json lacks one of these bookkeeping options.
argDF = (pd.DataFrame(argList)
         .drop(columns=['store_predictions', 'store_examples', 'tokenizer_path'], errors='ignore')
         .reset_index())
argDF[['label', 'runID', 'ner_f1_macro', 'rel_f1_macro', 'rel_nec_f1_macro',
       'train_batch_size', 'epochs', 'neg_entity_count', 'neg_relation_count',
       'lr', 'weight_decay', 'lowercase', 'model_path', 'rel_filter_threshold', 'prop_drop']]


Unnamed: 0,label,runID,ner_f1_macro,rel_f1_macro,rel_nec_f1_macro,train_batch_size,epochs,neg_entity_count,neg_relation_count,lr,weight_decay,lowercase,model_path,rel_filter_threshold,prop_drop
0,scierc_bert_train,2022-03-22_10.17.46.145076,69.027011,47.519995,39.019091,4,20,100,100,0.00006,0.02,True,allenai/scibert_scivocab_uncased,0.40,0.15
1,scierc_bert_train,2022-03-22_10.46.00.854943,67.227290,42.934329,30.853080,4,20,100,100,0.00006,0.02,False,bert-base-cased,0.40,0.15
2,scierc_bert_train,2022-03-22_09.51.47.931037,70.308508,49.589145,38.588136,4,20,100,100,0.00006,0.02,False,allenai/scibert_scivocab_cased,0.40,0.15
3,scierc_rob_train,2022-03-22_17.10.24.480132,67.316155,42.998732,30.314092,4,20,100,100,0.00004,0.15,False,allenai/biomed_roberta_base,0.40,0.25
4,scierc_rob_train,2022-03-22_17.58.14.707227,64.967055,42.488559,30.390785,8,40,125,100,0.00005,0.15,False,allenai/biomed_roberta_base,0.40,0.25
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
76,scierc_elec_train,2022-03-22_18.41.01.219059,66.484670,42.780381,32.376277,4,20,100,100,0.00006,0.01,False,kamalkraj/bioelectra-base-discriminator-pubmed,0.50,0.10
77,scierc_elec_train,2022-03-22_20.15.48.486074,67.948684,42.290068,30.977630,4,20,125,100,0.00006,0.02,False,kamalkraj/bioelectra-base-discriminator-pubmed,0.50,0.10
78,scierc_elec_train,2022-03-23_11.07.38.225160,66.066154,42.043701,33.120071,4,20,100,100,0.00006,0.01,False,google/electra-base-discriminator,0.50,0.10
79,scierc_elec_train,2022-03-22_20.32.00.719556,66.274939,40.568351,28.273242,4,20,125,100,0.00006,0.02,False,google/electra-base-discriminator,0.35,0.10


In [204]:
def best_run(df=None, groupingLabel='label',
             maxMetrics=['ner_f1_macro', 'rel_f1_macro', 'rel_nec_f1_macro']):
    '''
    Extract the best runs from a pandas data frame of all runs found in the
    standard directories.

    For every metric in ``maxMetrics`` and every value of the ``groupingLabel``
    column, the row with the largest value of that metric is selected.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per run; must contain ``groupingLabel``, ``'runID'`` and every
        column in ``maxMetrics``. Required despite the ``None`` default.
    groupingLabel : str
        Column used to group runs (default ``'label'``).
    maxMetrics : sequence of str
        Metric columns to maximize; one selection pass per metric.

    Returns
    -------
    pandas.DataFrame
        Columns ``[groupingLabel, 'runID', *maxMetrics, 'maximize']``, where
        ``maximize`` names the metric that selected the row. Rows keep the
        index labels of ``df``, so a run that is best for several metrics
        appears several times with the same index. An empty frame is returned
        when ``maxMetrics`` is empty.
    '''
    metrics = list(maxMetrics)  # also accepts tuples; default is never mutated
    columns = [groupingLabel, 'runID'] + metrics
    picks = []
    for metric in metrics:
        # Index label of the row with the largest ``metric`` within each group.
        idx = df.groupby(by=groupingLabel)[metric].idxmax()
        picks.append(df.loc[idx, columns].assign(maximize=metric))
    # Concatenate once rather than growing a frame inside the loop.
    return pd.concat(picks, axis=0) if picks else pd.DataFrame()

The best runs are identified below. Each run's label identifies its dataset and model type. The metrics shown are the ones used to choose the best run; the column "maximize" indicates which of these metrics was maximized to select the row. 

In [205]:
# For every label and every metric: the run with the highest score.
best = best_run(df=argDF)
best

Unnamed: 0,label,runID,ner_f1_macro,rel_f1_macro,rel_nec_f1_macro,maximize
36,conll04_bert_train,2022-03-24_15.33.05.939120,83.609981,67.751994,67.751994,ner_f1_macro
2,scierc_bert_train,2022-03-22_09.51.47.931037,70.308508,49.589145,38.588136,ner_f1_macro
77,scierc_elec_train,2022-03-22_20.15.48.486074,67.948684,42.290068,30.97763,ner_f1_macro
8,scierc_rob_train,2022-03-23_17.42.00.075533,68.228957,43.401293,31.771205,ner_f1_macro
49,conll04_bert_train,2022-03-24_15.12.57.875259,82.941415,69.115147,68.814395,rel_f1_macro
2,scierc_bert_train,2022-03-22_09.51.47.931037,70.308508,49.589145,38.588136,rel_f1_macro
72,scierc_elec_train,2022-03-22_18.25.20.499416,66.454233,43.435618,35.455048,rel_f1_macro
12,scierc_rob_train,2022-03-23_17.24.48.920799,67.421825,45.794171,33.401826,rel_f1_macro
49,conll04_bert_train,2022-03-24_15.12.57.875259,82.941415,69.115147,68.814395,rel_nec_f1_macro
0,scierc_bert_train,2022-03-22_10.17.46.145076,69.027011,47.519995,39.019091,rel_nec_f1_macro


We now look at the columns of the full run data frame (`argDF`) and identify the columns that take more than one value. If all rows have the same value in a column, that column is irrelevant to hyperparameter selection. We then select the most significant of these columns from the original data frame for the runs that appear in the best list.

In [206]:
# Columns whose value differs between at least two runs; constant columns say
# nothing about hyperparameter selection. nunique(dropna=False) counts all
# missing values as one value (a plain set() can count each NaN separately,
# making an all-NaN column look "varying").
varCols = [col for col in argDF.columns if argDF[col].nunique(dropna=False) > 1]
# 'label' is the first column added after the run arguments when argDF was
# built, so everything to its left is a run argument (everything to its right
# is run info or an additional metric).
lmt = varCols.index('label')
print(varCols)
# best.index holds argDF index labels, repeated when a run is best for several
# metrics: keep each run once, in order of first appearance, and select by
# label (.loc) rather than by position.
argDF.loc[best.index.unique(),
          ['label', 'runID', 'ner_f1_macro', 'rel_f1_macro', 'rel_nec_f1_macro'] + varCols[:lmt]
          ].drop(columns=['config', 'model_type', 'index'], errors='ignore')

['index', 'train_path', 'valid_path', 'final_eval', 'train_batch_size', 'epochs', 'neg_entity_count', 'neg_relation_count', 'lr', 'weight_decay', 'config', 'types_path', 'lowercase', 'model_path', 'model_type', 'rel_filter_threshold', 'prop_drop', 'label', 'logPath', 'savePath', 'runID', 'ner_prec_micro', 'ner_rec_micro', 'ner_f1_micro', 'ner_prec_macro', 'ner_rec_macro', 'ner_f1_macro', 'rel_prec_micro', 'rel_rec_micro', 'rel_f1_micro', 'rel_prec_macro', 'rel_rec_macro', 'rel_f1_macro', 'rel_nec_prec_micro', 'rel_nec_rec_micro', 'rel_nec_f1_micro', 'rel_nec_prec_macro', 'rel_nec_rec_macro', 'rel_nec_f1_macro', 'epoch', 'global_iteration']


Unnamed: 0,label,runID,ner_f1_macro,rel_f1_macro,rel_nec_f1_macro,train_path,valid_path,final_eval,train_batch_size,epochs,neg_entity_count,neg_relation_count,lr,weight_decay,types_path,lowercase,model_path,rel_filter_threshold,prop_drop
0,scierc_bert_train,2022-03-22_10.17.46.145076,69.027011,47.519995,39.019091,data/datasets/scierc/scierc_train.json,data/datasets/scierc/scierc_dev.json,False,4,20,100,100,6e-05,0.02,data/datasets/scierc/scierc_types.json,True,allenai/scibert_scivocab_uncased,0.4,0.15
2,scierc_bert_train,2022-03-22_09.51.47.931037,70.308508,49.589145,38.588136,data/datasets/scierc/scierc_train.json,data/datasets/scierc/scierc_dev.json,False,4,20,100,100,6e-05,0.02,data/datasets/scierc/scierc_types.json,False,allenai/scibert_scivocab_cased,0.4,0.15
36,conll04_bert_train,2022-03-24_15.33.05.939120,83.609981,67.751994,67.751994,data/datasets/conll04/conll04_train.json,data/datasets/conll04/conll04_dev.json,False,4,20,100,100,1e-05,0.005,data/datasets/conll04/conll04_types.json,False,bert-base-cased,0.4,0.1
8,scierc_rob_train,2022-03-23_17.42.00.075533,68.228957,43.401293,31.771205,data/datasets/scierc/scierc_train.json,data/datasets/scierc/scierc_dev.json,False,4,20,100,100,5e-05,0.15,data/datasets/scierc/scierc_types.json,False,allenai/biomed_roberta_base,0.4,0.2
72,scierc_elec_train,2022-03-22_18.25.20.499416,66.454233,43.435618,35.455048,data/datasets/scierc/scierc_train.json,data/datasets/scierc/scierc_dev.json,False,4,20,100,100,6e-05,0.01,data/datasets/scierc/scierc_types.json,False,google/electra-base-discriminator,0.5,0.1
12,scierc_rob_train,2022-03-23_17.24.48.920799,67.421825,45.794171,33.401826,data/datasets/scierc/scierc_train.json,data/datasets/scierc/scierc_dev.json,False,4,20,100,100,5e-05,0.15,data/datasets/scierc/scierc_types.json,False,allenai/biomed_roberta_base,0.4,0.2
77,scierc_elec_train,2022-03-22_20.15.48.486074,67.948684,42.290068,30.97763,data/datasets/scierc/scierc_train.json,data/datasets/scierc/scierc_dev.json,False,4,20,125,100,6e-05,0.02,data/datasets/scierc/scierc_types.json,False,kamalkraj/bioelectra-base-discriminator-pubmed,0.5,0.1
14,scierc_rob_train,2022-03-22_13.01.05.757922,67.571642,44.340069,35.433664,data/datasets/scierc/scierc_train.json,data/datasets/scierc/scierc_dev.json,False,4,20,100,100,5e-05,0.2,data/datasets/scierc/scierc_types.json,False,allenai/biomed_roberta_base,0.4,0.2
49,conll04_bert_train,2022-03-24_15.12.57.875259,82.941415,69.115147,68.814395,data/datasets/conll04/conll04_train.json,data/datasets/conll04/conll04_dev.json,False,4,20,100,100,1e-05,0.01,data/datasets/conll04/conll04_types.json,False,bert-base-cased,0.5,0.1


## CoNLL04 - BERT Hyperparameter Search - Chronological Sequence

In [207]:
# CoNLL04/BERT runs only: keep the scores and the varying run arguments, then
# drop the bookkeeping, dataset-path and model columns so the table shows the
# tuned hyperparameters.
df = (
    argDF.loc[argDF['label'] == 'conll04_bert_train',
              ['label', 'runID', 'ner_f1_macro', 'rel_f1_macro', 'rel_nec_f1_macro'] + varCols[:lmt]]
    .drop(columns=['config', 'model_type', 'index',
                   'train_path', 'valid_path', 'types_path',
                   'final_eval', 'lowercase', 'model_path'])
)
# runID is a timestamp, so sorting by it lists the runs chronologically.
df.sort_values(by='runID')


Unnamed: 0,label,runID,ner_f1_macro,rel_f1_macro,rel_nec_f1_macro,train_batch_size,epochs,neg_entity_count,neg_relation_count,lr,weight_decay,rel_filter_threshold,prop_drop
29,conll04_bert_train,2022-03-24_10.11.21.546339,74.78162,55.913825,55.660661,4,20,100,100,5e-06,0.1,0.4,0.0
43,conll04_bert_train,2022-03-24_10.17.07.397146,66.974673,43.934888,43.934888,4,20,100,100,5e-06,0.1,0.4,0.1
59,conll04_bert_train,2022-03-24_10.22.57.368149,66.277413,44.126256,44.126256,4,20,100,100,5e-06,0.1,0.4,0.2
55,conll04_bert_train,2022-03-24_10.29.25.335308,69.179068,45.10944,45.10944,4,20,100,100,5e-06,0.1,0.5,0.0
32,conll04_bert_train,2022-03-24_10.36.00.136194,68.78528,49.93047,49.93047,4,20,100,100,5e-06,0.1,0.5,0.1
33,conll04_bert_train,2022-03-24_10.42.34.301900,68.654684,46.802987,46.521296,4,20,100,100,5e-06,0.1,0.5,0.2
46,conll04_bert_train,2022-03-24_10.48.56.001071,72.806643,51.894814,51.894814,4,20,100,100,5e-06,0.05,0.4,0.0
35,conll04_bert_train,2022-03-24_10.55.08.682239,74.892555,56.301172,56.301172,4,20,100,100,5e-06,0.05,0.4,0.1
62,conll04_bert_train,2022-03-24_11.01.42.482450,67.870696,43.257922,43.257922,4,20,100,100,5e-06,0.05,0.4,0.2
25,conll04_bert_train,2022-03-24_11.08.27.152037,76.396685,59.023859,59.023859,4,20,100,100,5e-06,0.05,0.5,0.0


## CoNLL04 - BERT Hyperparameter Search - Order by NER-F1-Macro

In [208]:
# Ascending order: the runs with the highest NER macro F1 are at the bottom.
df.sort_values(by='ner_f1_macro', ascending=True)

Unnamed: 0,label,runID,ner_f1_macro,rel_f1_macro,rel_nec_f1_macro,train_batch_size,epochs,neg_entity_count,neg_relation_count,lr,weight_decay,rel_filter_threshold,prop_drop
59,conll04_bert_train,2022-03-24_10.22.57.368149,66.277413,44.126256,44.126256,4,20,100,100,5e-06,0.1,0.4,0.2
34,conll04_bert_train,2022-03-24_13.19.11.692779,66.600443,46.26013,46.26013,4,20,100,100,5e-06,0.0,0.5,0.2
65,conll04_bert_train,2022-03-24_13.13.03.499623,66.731984,43.861935,43.861935,4,20,100,100,5e-06,0.0,0.5,0.1
43,conll04_bert_train,2022-03-24_10.17.07.397146,66.974673,43.934888,43.934888,4,20,100,100,5e-06,0.1,0.4,0.1
51,conll04_bert_train,2022-03-24_12.48.18.157291,67.588311,45.52721,45.52721,4,20,100,100,5e-06,0.0,0.4,0.0
18,conll04_bert_train,2022-03-24_12.21.26.863312,67.833949,45.234442,45.234442,4,20,100,100,5e-06,0.005,0.4,0.2
62,conll04_bert_train,2022-03-24_11.01.42.482450,67.870696,43.257922,43.257922,4,20,100,100,5e-06,0.05,0.4,0.2
57,conll04_bert_train,2022-03-24_11.41.44.084588,68.230394,43.612457,43.612457,4,20,100,100,5e-06,0.01,0.4,0.2
21,conll04_bert_train,2022-03-24_11.35.04.621984,68.332713,44.123794,44.123794,4,20,100,100,5e-06,0.01,0.4,0.1
33,conll04_bert_train,2022-03-24_10.42.34.301900,68.654684,46.802987,46.521296,4,20,100,100,5e-06,0.1,0.5,0.2


## CoNLL04 - BERT Hyperparameter Search - Order by Rel-F1-Macro

In [209]:
# Ascending order: the runs with the highest relation macro F1 are at the bottom.
df.sort_values(by='rel_f1_macro', ascending=True)


Unnamed: 0,label,runID,ner_f1_macro,rel_f1_macro,rel_nec_f1_macro,train_batch_size,epochs,neg_entity_count,neg_relation_count,lr,weight_decay,rel_filter_threshold,prop_drop
62,conll04_bert_train,2022-03-24_11.01.42.482450,67.870696,43.257922,43.257922,4,20,100,100,5e-06,0.05,0.4,0.2
57,conll04_bert_train,2022-03-24_11.41.44.084588,68.230394,43.612457,43.612457,4,20,100,100,5e-06,0.01,0.4,0.2
65,conll04_bert_train,2022-03-24_13.13.03.499623,66.731984,43.861935,43.861935,4,20,100,100,5e-06,0.0,0.5,0.1
43,conll04_bert_train,2022-03-24_10.17.07.397146,66.974673,43.934888,43.934888,4,20,100,100,5e-06,0.1,0.4,0.1
21,conll04_bert_train,2022-03-24_11.35.04.621984,68.332713,44.123794,44.123794,4,20,100,100,5e-06,0.01,0.4,0.1
59,conll04_bert_train,2022-03-24_10.22.57.368149,66.277413,44.126256,44.126256,4,20,100,100,5e-06,0.1,0.4,0.2
55,conll04_bert_train,2022-03-24_10.29.25.335308,69.179068,45.10944,45.10944,4,20,100,100,5e-06,0.1,0.5,0.0
18,conll04_bert_train,2022-03-24_12.21.26.863312,67.833949,45.234442,45.234442,4,20,100,100,5e-06,0.005,0.4,0.2
19,conll04_bert_train,2022-03-24_12.54.43.839527,68.683081,45.281489,45.281489,4,20,100,100,5e-06,0.0,0.4,0.1
51,conll04_bert_train,2022-03-24_12.48.18.157291,67.588311,45.52721,45.52721,4,20,100,100,5e-06,0.0,0.4,0.0


## CoNLL04 - BERT Hyperparameter Search - Order by Rel-Nec-F1-Macro (Joint Rel)

In [210]:
# Ascending order: the runs with the highest rel_nec macro F1 are at the bottom.
df.sort_values(by='rel_nec_f1_macro', ascending=True)


Unnamed: 0,label,runID,ner_f1_macro,rel_f1_macro,rel_nec_f1_macro,train_batch_size,epochs,neg_entity_count,neg_relation_count,lr,weight_decay,rel_filter_threshold,prop_drop
62,conll04_bert_train,2022-03-24_11.01.42.482450,67.870696,43.257922,43.257922,4,20,100,100,5e-06,0.05,0.4,0.2
57,conll04_bert_train,2022-03-24_11.41.44.084588,68.230394,43.612457,43.612457,4,20,100,100,5e-06,0.01,0.4,0.2
65,conll04_bert_train,2022-03-24_13.13.03.499623,66.731984,43.861935,43.861935,4,20,100,100,5e-06,0.0,0.5,0.1
43,conll04_bert_train,2022-03-24_10.17.07.397146,66.974673,43.934888,43.934888,4,20,100,100,5e-06,0.1,0.4,0.1
21,conll04_bert_train,2022-03-24_11.35.04.621984,68.332713,44.123794,44.123794,4,20,100,100,5e-06,0.01,0.4,0.1
59,conll04_bert_train,2022-03-24_10.22.57.368149,66.277413,44.126256,44.126256,4,20,100,100,5e-06,0.1,0.4,0.2
55,conll04_bert_train,2022-03-24_10.29.25.335308,69.179068,45.10944,45.10944,4,20,100,100,5e-06,0.1,0.5,0.0
18,conll04_bert_train,2022-03-24_12.21.26.863312,67.833949,45.234442,45.234442,4,20,100,100,5e-06,0.005,0.4,0.2
19,conll04_bert_train,2022-03-24_12.54.43.839527,68.683081,45.281489,45.281489,4,20,100,100,5e-06,0.0,0.4,0.1
51,conll04_bert_train,2022-03-24_12.48.18.157291,67.588311,45.52721,45.52721,4,20,100,100,5e-06,0.0,0.4,0.0
