1. See accuracy in top 2, 3, and 5 selected entities
2. Manual check correct and incorrect samples

In [1]:
import os
import torch
import json

from tqdm import tqdm, trange
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

from longformer_encoder import LongEncoderRanker
from data_process import read_dataset, process_mention_data
from params import EvalParser
import utils

In [44]:
parser = EvalParser()
args = parser.parse_args([])

params = args.__dict__
params['use_longformer'] = not params['use_bert']

#model_path = 'experiments/no_global_attn/bi_golden'
model_path = 'experiments/no_global_attn/bi_pred'
params['model_path'] = model_path

params['is_biencoder'] = True

In [45]:
# init model
ranker = LongEncoderRanker(params)
tokenizer = ranker.tokenizer
device = ranker.device

model_name = params['model_name']
checkpoint = torch.load(os.path.join(model_path, model_name), map_location=device)
# load model
ranker.model.load_state_dict(checkpoint['model_state_dict'])
model = ranker.model
# load optimizer
optim = torch.optim.Adam(model.parameters(), lr = params['learning_rate'])
optim.load_state_dict(checkpoint['optimizer_state_dict'])

Some weights of LongformerModel were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['longformer.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [46]:
params['split'] = 'test'

params['eval_batch_size'] = 231
eval_batch_size = params['eval_batch_size']
valid_samples = read_dataset(params['data_path'], params['split'])
# check just the first document
#valid_samples = valid_samples[:1]

cand_enc_path = os.path.join(params['data_path'], f'{params["split"]}_enc.json')
valid_tensor_data = process_mention_data(
    valid_samples,
    tokenizer,
    max_context_length=params['max_context_length'],
    silent=params['silent'],
    end_tag=params['end_tag'],
    is_biencoder=params['is_biencoder'],
    cand_enc_path=cand_enc_path,
    use_longformer=params['use_longformer']
)

valid_tensor_data = TensorDataset(*valid_tensor_data)
valid_sampler = SequentialSampler(valid_tensor_data)
valid_dataloader = DataLoader(
    valid_tensor_data, sampler=valid_sampler, batch_size=eval_batch_size
)

100%|██████████| 231/231 [00:01<00:00, 137.75it/s]


In [47]:
# load candidate entities
cand_set_enc = torch.load(params['selected_set_path'], map_location=device)
id2label = torch.load(params['id_to_label_path'], map_location=device)

In [48]:
batch = next(iter(valid_dataloader))

In [49]:
batch = tuple(t.to(device) for t in batch)
token_ids, tags, cand_enc, cand_enc_mask, label_ids, label_mask, attn_mask, global_attn_mask = batch
tuple(t.size() for t in batch)

(torch.Size([230, 512]),
 torch.Size([230, 512]),
 torch.Size([230, 80, 1024]),
 torch.Size([230, 80]),
 torch.Size([230, 80]),
 torch.Size([230, 80]),
 torch.Size([230, 512]),
 torch.Size([230, 512]))

In [51]:
#ranker.model.eval()

In [52]:
global_attn_mask = None
with torch.no_grad():
    raw_ctxt_encoding = ranker.model.get_raw_ctxt_encoding(token_ids, attn_mask, global_attn_mask)
    ctxt_embeds = ranker.model.get_ctxt_embeds(raw_ctxt_encoding, tags)

In [53]:
ctxt_embeds.size()

torch.Size([4053, 1024])

In [54]:
#torch.save(ctxt_embeds, 'bi_golden_ctxt_embeds.t7')
torch.save(ctxt_embeds, 'bi_pred_ctxt_embeds.t7')

In [55]:
def get_top_acc(ctxt_embeds, cand_set_enc, k=5):
    scores = ctxt_embeds.mm(cand_set_enc.t())
    print(f'Scores.size(): {scores.size()}')
    
    true_labels = label_ids[label_mask].cpu().tolist()
    assert len(true_labels)== scores.size(0)
    
    top_k = torch.topk(scores, k, dim=1)
    top_indices = top_k[1].cpu().tolist()
    top_labels = [[id2label[i].item() for i in l] for l in top_indices]
    
    top1 = top2 = top3 = top5 = 0

    for i in range(scores.size(0)):
        true_label = true_labels[i]
        top_label = top_labels[i]

        if true_label == top_label[0]:
            top1 += 1
        elif true_label == top_label[1]:
            top2 += 1
        elif true_label == top_label[2]:
            top3 += 1
        elif true_label in top_label[3:]:
            top5 += 1
        else:
            continue

    top5 += top1+top2+top3
    top3 += top1+top2
    top2 += top1
    
    return (scores.size(0), top1, top2, top3, top5)

In [42]:
print('Bi golden model')
total, top1, top2, top3, top5 = get_top_acc(ctxt_embeds, cand_set_enc)
acc = [t/total for t in [top1, top2, top3, top5]]
print(f'Acc are: {acc[0]:.4f}, {acc[1]:.4f}, {acc[2]:.4f}, {acc[3]:.4f}')

Bi golden model
Scores.size(): torch.Size([4053, 5329])
Acc are: 0.6876, 0.7777, 0.8154, 0.8562


In [56]:
print('Bi pred model')
total, top1, top2, top3, top5 = get_top_acc(ctxt_embeds, cand_set_enc)
acc = [t/total for t in [top1, top2, top3, top5]]
print(f'Acc are: {acc[0]:.4f}, {acc[1]:.4f}, {acc[2]:.4f}, {acc[3]:.4f}')

Bi pred model
Scores.size(): torch.Size([4053, 5329])
Acc are: 0.6625, 0.7518, 0.7866, 0.8344


### Manual check first 1

In [23]:
top_5 = torch.topk(scores, 5, dim=1)
top_5[0].size()

torch.Size([38, 5])

In [24]:
top_indices = top_5[1]
top_indices.size()

torch.Size([38, 5])

In [57]:
top_indices = top_indices.cpu()

In [31]:
top_ind_list = top_indices.tolist()

In [37]:
top_labels = [[id2label[i].item() for i in l] for l in top_ind_list]

In [38]:
len(top_labels), len(top_labels[0])

(38, 5)

In [48]:
true_labels = label_ids[label_mask].cpu().tolist()

In [49]:
len(true_labels)

38

In [52]:
top1 = top2 = top3 = top5 = 0

for i, lab in enumerate(true_labels):
    if lab == top_labels[i][0]:
        top1 += 1
        top2 += 1
        top3 += 1
        top5 += 1
    elif lab in top_labels[i][:2]:
        top2 += 1
        top3 += 1
        top5 += 1
    elif lab in top_labels[i][:3]:
        top3 += 1
        top5 += 1
    elif lab in top_labels[i]:
        top5 += 1
    else:
        continue

In [53]:
top1, top2, top3, top5

(28, 32, 35, 36)

In [55]:
[t/38 for t in [top1, top2, top3, top5]]

[0.7368421052631579,
 0.8421052631578947,
 0.9210526315789473,
 0.9473684210526315]

In [None]:
entity_token_ids = torch.load('../models/entity_token_ids_128.t7')

In [58]:
# def print_true_pred(i):
#     true_labels

In [None]:
def cand_set_eval(ranker, valid_dataloader, params, device, cand_set_enc, id2label):
    ranker.model.eval()
#     y_true = []
#     y_pred = []
    top1 = top2 = top3 = top5 = 0

    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        assert params['is_biencoder']

        token_ids, tags, cand_enc, cand_enc_mask, label_ids, label_mask, attn_mask, global_attn_mask = batch

        # evaluate: not leak information about tags
        global_attn_mask = None
        with torch.no_grad():
            raw_ctxt_encoding = ranker.model.get_raw_ctxt_encoding(token_ids, attn_mask, global_attn_mask)
            ctxt_embeds = ranker.model.get_ctxt_embeds(raw_ctxt_encoding, tags)

        scores = ctxt_embeds.mm(cand_set_enc.t())

        true_labels = label_ids[label_mask].cpu().tolist()
        top_5 = torch.topk(scores, 5, dim=1)
        top_indices = top_5[1].cpu()
#         y_true.extend(true_labels)
#         pred_inds = torch.argmax(scores, dim=1).cpu().tolist()
#         pred_labels = [id2label[i].item() for i in pred_inds]
#         y_pred.extend(pred_labels)
#         assert len(y_true)==len(y_pred)
    
#     acc, f1_macro, f1_micro = utils.get_metrics_result(y_true, y_pred)
#     print(f'Accuracy: {acc:.4f}, F1 macro: {f1_macro:.4f}, F1 micro: {f1_micro:.4f}')