In [1]:
import pickle

# Load the ground-truth labels that every later cell evaluates against.
# As the printed samples below show, this is a dict: key -> list of expected strings.
# (A module-level `global` statement is a no-op, so none is needed here.)
# NOTE(review): pickle.load can execute arbitrary code -- only load a trusted file.
EVAL_LABEL_PATH = './test_label.pkl'
with open(EVAL_LABEL_PATH, 'rb') as f:
    eval_label = pickle.load(f)

In [2]:
# Run the RoBERTa extractor over the evaluation set and align its output to eval_label.
# `inference` is provided by this star import.
from inference_roberta_v5 import *

# NOTE(review): what these two values mean is defined by `inference` in
# inference_roberta_v5 (presumably decision thresholds) -- confirm there.
INFERENCE_PARAMS = [0.44, 0.44]
tmp_preds = inference(INFERENCE_PARAMS, eval_label)
print(len(tmp_preds), len(eval_label))

# Fraction of labelled keys that received a prediction.
# np.isin on two dict_keys views does not compute this: each view becomes a
# single 0-d object, so the result only says whether the key sets are identical.
coverage = sum(k in tmp_preds for k in eval_label) / max(len(eval_label), 1)
print(f'All predicted: {coverage}')

# Keep only labelled keys. Unpredicted keys are omitted instead of raising KeyError;
# run_eval looks keys up with .get and scores a missing key as (0, 0).
predictions = {k: tmp_preds[k] for k in eval_label if k in tmp_preds}

Some weights of the model checkpoint at roberta-large were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


  0%|          | 0/7652 [00:00<?, ?it/s]

7652 7652
All predicted: 1.0


In [6]:
import re

import numpy as np
from nltk.stem.porter import *
from rouge import Rouge
from tqdm.auto import tqdm

# One shared ROUGE scorer and one Porter stemmer, reused by the scoring functions.
rouge, stemmer = Rouge(), PorterStemmer()

def eval_score(reference, hypothesis):
    """Score one example's extracted strings against its ground truth with ROUGE-L.

    Each string first has every character that is neither a word character nor
    whitespace replaced by a space. It is then passed through ``stemmer.stem``
    and compared with ``rouge.get_scores``.

    Args:
        reference: list of ground-truth strings for the example.
        hypothesis: list of predicted strings; ``None`` is treated as no predictions.

    Returns:
        A ``(recall, precision)`` tuple:
        - recall: for each reference string, the best ROUGE-L recall against any
          normalized hypothesis. These are summed and divided by ``len(reference)``.
        - precision: for each *unique* normalized hypothesis, the best ROUGE-L
          precision against any reference. These are summed and divided by the raw
          ``len(hypothesis)``. Hypotheses that normalize to the same string
          therefore contribute only once but still count in the denominator.
        Both values are 0 when either list is empty.
    """
    if hypothesis is None:
        hypothesis = []

    def _normalize(text):
        # Strip punctuation to spaces, then stem.
        return stemmer.stem(re.sub(r'[^\w\s]', ' ', text))

    # dict.fromkeys de-duplicates while keeping first-seen order, so the float
    # sums below are accumulated in a deterministic order (set order is not).
    stem_h = list(dict.fromkeys(_normalize(x) for x in hypothesis))
    stem_r = [_normalize(x) for x in reference]

    # One ROUGE-L call per (hypothesis, reference) pair; P and R come from the same result.
    pair_scores = [[rouge.get_scores(h, r)[0]['rouge-l'] for r in stem_r] for h in stem_h]

    # default=0 covers an empty reference list (no pairs to take a max over).
    precision = [max((s['p'] for s in row), default=0) for row in pair_scores]
    recall = [
        max((row[j]['r'] for row in pair_scores), default=0)
        for j in range(len(stem_r))
    ]

    recall = sum(recall) / (len(reference) if len(reference) > 0 else 1e-8)
    precision = sum(precision) / (len(hypothesis) if len(hypothesis) > 0 else 1e-8)

    return recall, precision

def run_eval(predictions, labels=None):
    """Average per-example ROUGE-L recall and precision over a labelled set.

    Args:
        predictions: mapping from example key to a list of predicted strings.
            A key that is absent (or maps to ``None``) scores 0 recall and 0 precision.
        labels: mapping from example key to its list of ground-truth strings.
            Defaults to the notebook-level ``eval_label`` loaded in the first cell.

    Returns:
        ``{'Recall': float, 'Precision': float}``: unweighted means over ``labels``.

    Raises:
        ValueError: if ``labels`` is empty, because the means are undefined.
    """
    if labels is None:
        labels = eval_label

    scores = []
    for k, v in tqdm(labels.items()):
        p = predictions.get(k)
        scores.append(eval_score(v, p) if p is not None else (0, 0))

    if not scores:
        raise ValueError('run_eval needs at least one labelled example')

    recall, precision = np.mean(np.array(scores), axis=0).tolist()
    return {'Recall': recall, 'Precision': precision}

In [7]:
results = run_eval(predictions)
# F1 is the harmonic mean of recall and precision.
# It is defined as 0 when both are 0, which avoids a ZeroDivisionError.
f1_denominator = results['Recall'] + results['Precision']
results['F1'] = (
    2 * results['Recall'] * results['Precision'] / f1_denominator
    if f1_denominator else 0.0
)
print(results)

  0%|          | 0/7652 [00:00<?, ?it/s]

{'Recall': 0.8695231415823662, 'Precision': 0.8126418481996653, 'F1': 0.8401207932872664}


In [8]:
# Eyeball a random subset of examples: ground truth vs. extraction with per-example scores.
# The fixed seed makes the printed sample reproducible across re-runs.
# Sampling without replacement means no example is shown twice.
DISPLAY_SEED = 42
N_DISPLAY = 30
display_rng = np.random.default_rng(DISPLAY_SEED)
display_sample = display_rng.choice(
    list(eval_label.keys()), size=min(N_DISPLAY, len(eval_label)), replace=False
)

for k in display_sample:
    pred = predictions.get(k)
    truth = eval_label[k]
    # A key with no prediction is scored as an empty extraction instead of crashing.
    R, P = eval_score(truth, pred if pred is not None else [])
    print(k)
    print(f'Ground Truth: {truth}')
    print(f'Extraction: {pred}')
    print(f'Recall: {R}; Precision: {P}')
    print('-----------------------------------------')

3577164
Ground Truth: ['Agamon', 'Nuance', 'Radloop', 'Royal', 'Within Health']
Extraction: ['3M*M Modal', 'Agamon', 'Nuance', 'Radloop', 'Royal']
Recall: 0.8; Precision: 0.8
-----------------------------------------
3163788
Ground Truth: ['Oracle Cloud', 'Oracle EBS', 'Oracle PeopleSoft']
Extraction: ['Oracle Cloud', 'Oracle EBS', 'Oracle PeopleSoft']
Recall: 1.0; Precision: 1.0
-----------------------------------------
3764125
Ground Truth: ['Rescale']
Extraction: ['Rescale']
Recall: 1.0; Precision: 1.0
-----------------------------------------
3519355
Ground Truth: ['Paack']
Extraction: ['Paack']
Recall: 1.0; Precision: 1.0
-----------------------------------------
3475426
Ground Truth: ['ClickSoftware', 'Clicksoftware', 'Kronos', 'ServiceMax', 'Servicemax', 'Skedulo']
Extraction: ['Clicksoftware', 'Kronos', 'Servicemax', 'Skedulo']
Recall: 1.0; Precision: 1.0
-----------------------------------------
3349267
Ground Truth: ['Acquiring', 'Digital Wallets', 'Issuing']
Extraction: ['Eu