In [1]:
from mg_lb.eval.prep import prep_for_eval, eval_data_set, network_sentences
from mg_lb.eval.visualize import sort_results
from mg_lb.eval.losses import f1_multi_layer, f1_ace

### Evaluating a dataset on a pretrained model from the paper


In [2]:
problem = 'ACE05'  # one of ['ontonotes', 'ACE05']
dset = 'test.txt'  # one of ['train.txt', 'val.txt', 'test.txt']
save_name = 'ACE05_83.10'  # one of ['ACE05_83.10', 'ACE05_74.87', 'ontonotes_87.59', 'ontonotes_89.25']
gpu = True  # set to False to load and run the model without a GPU

# Validate the configuration *before* the expensive model load / forward pass.
# Previously an unknown `problem` only surfaced as a NameError on `f1` at the very end.
if problem not in ('ACE05', 'ontonotes'):
    raise ValueError("Unknown problem {!r}; expected 'ACE05' or 'ontonotes'".format(problem))
# Every paper model name is prefixed by the problem it was trained on; a mismatch would
# silently evaluate a model against the wrong task's data.
if not save_name.startswith(problem + '_'):
    raise ValueError("save_name {!r} does not belong to problem {!r}".format(save_name, problem))

# Load the model
model, vocab, args, prob_dict = prep_for_eval(save_name, paper_model=True, gpu=gpu)

# Run forward pass on dataset
res_all, iter_all = eval_data_set(problem, dset, args, vocab, model, prob_dict, batch_size=2000)

# Process the returns from the network
res_out = sort_results(res_all, iter_all, vocab, problem, cutoff=0.25)

# Calculate F1 score. The two metric functions return different auxiliary outputs.
if problem == 'ACE05':
    prec, rec, f1, wrong, acc_wrong, cl_wrong = f1_ace(res_out['cluster_preds'], res_out['labels'],
                                                       remove_np=True)
else:  # problem == 'ontonotes' (validated above)
    prec, rec, f1, wrong = f1_multi_layer(res_out['cluster_preds'], res_out['labels'], add_bs=True,
                                          remove_np=True)

print('F1:', "{:.4f}".format(f1))

Building model...
643 out of 15862 words unknown
Restoring model from checkpoint...
F1: 0.8302


### Running a trained model on a list of sentences

In [None]:
problem = 'ACE05'  # one of ['ontonotes', 'ACE05']
save_name = 'ACE05_83.10'  # one of ['ACE05_83.10', 'ACE05_74.87', 'ontonotes_87.59', 'ontonotes_89.25']
gpu = True  # used for both model loading and the forward pass, so the two always agree

# Validate the configuration before loading the model.
if problem not in ('ACE05', 'ontonotes'):
    raise ValueError("Unknown problem {!r}; expected 'ACE05' or 'ontonotes'".format(problem))
# Every paper model name is prefixed by the problem it was trained on.
if not save_name.startswith(problem + '_'):
    raise ValueError("save_name {!r} does not belong to problem {!r}".format(save_name, problem))

# Load the model
model, vocab, args, prob_dict = prep_for_eval(save_name, paper_model=True, gpu=gpu)

# Input sentences
sentences = ['The Prime Minister of the UK met with Donald Trump on Thursday.']

# Run forward pass
res_sents, iter_sents = network_sentences(sentences, problem, args, vocab, model, prob_dict=prob_dict, gpu=gpu)

# Process returns
res_out = sort_results(res_sents, iter_sents, vocab, problem, cutoff=0.25)

Building model...


In [None]:
# Nested entity structure found in the input sentences; each dict key is one layer
nested_entities = res_out['cluster_words']
nested_entities

In [None]:
# Predicted label for every entity cluster
cluster_labels = res_out['cluster_preds']
cluster_labels

### Evaluating a newly trained model or a new dataset