In [9]:
def eval(model, iterator):
    model.eval()

    Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            words, x, is_heads, tags, y, seqlens = batch
            _, _, y_hat = model(x, y)  # y_hat: (N, T)

            Words.extend(words)
            Is_heads.extend(is_heads)
            Tags.extend(tags)
            Y.extend(y.numpy().tolist())
            Y_hat.extend(y_hat.cpu().numpy().tolist())

    ## gets results and save
    with open("temp", 'w') as fout:
        for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):
            y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1]
            preds = [idx2tag[hat] for hat in y_hat]
            assert len(preds)==len(words.split())==len(tags.split())
            for w, t, p in zip(words.split()[1:-1], tags.split()[1:-1], preds[1:-1]):
                fout.write(f"{w} {t} {p}\n")
            fout.write("\n")

    ## calc metric
    y_true =  np.array([tag2idx[line.split()[1]] for line in open("temp", 'r').read().splitlines() if len(line) > 0])
    y_pred =  np.array([tag2idx[line.split()[2]] for line in open("temp", 'r').read().splitlines() if len(line) > 0])
    print(y_pred)
    num_proposed = len(y_pred[y_pred>1])
    num_correct = (np.logical_and(y_true==y_pred, y_true>1)).astype(np.int).sum()
    num_gold = len(y_true[y_true>1])

    print(f"num_proposed:{num_proposed}")
    print(f"num_correct:{num_correct}")
    print(f"num_gold:{num_gold}")
    try:
        precision = num_correct / num_proposed
    except ZeroDivisionError:
        print("="*80)
        precision = 1.0

    try:
        recall = num_correct / num_gold
    except ZeroDivisionError:
        print("="*80)
        recall = 1.0

    try:
        f1 = 2*precision*recall / (precision + recall)
    except ZeroDivisionError:
        print("="*80)
        if precision*recall==0:
            f1=1.0
        else:
            f1=0

#     final = f + ".P%.2f_R%.2f_F%.2f" %(precision, recall, f1)
#     with open(final, 'w') as fout:
#         result = open("temp", "r").read()
#         fout.write(f"{result}\n")

#         fout.write(f"precision={precision}\n")
#         fout.write(f"recall={recall}\n")
#         fout.write(f"f1={f1}\n")

    os.remove("temp")

    print("precision=%.2f"%precision)
    print("recall=%.2f"%recall)
    print("f1=%.2f"%f1)
    return precision, recall, f1

In [6]:
import torch
import torch.nn as nn
from pytorch_pretrained_bert import BertModel
from data_load import NerDataset, pad, VOCAB, tokenizer, tag2idx, idx2tag
from torch.utils import data
from data_load import NerDataset, pad, VOCAB, tokenizer, tag2idx, idx2tag
import numpy as np
import os



class Net(nn.Module):
    def __init__(self, top_rnns=False, vocab_size=None, device='cpu', finetuning=False):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.top_rnns=top_rnns
        if top_rnns:
            self.rnn = nn.LSTM(bidirectional=True, num_layers=2, input_size=768, hidden_size=768//2, batch_first=True)
        self.fc = nn.Linear(768, vocab_size)

        self.device = device
        self.finetuning = finetuning

    def forward(self, x, y, ):
        '''
        x: (N, T). int64
        y: (N, T). int64

        Returns
        enc: (N, T, VOCAB)
        '''
        x = x.to(self.device)
        y = y.to(self.device)
        
        if self.training and self.finetuning:
            # print("->bert.train()")
            self.bert.train()
            encoded_layers, _ = self.bert(x)
            enc = encoded_layers[-1]
            print("training\n\n");
        else:
            self.bert.eval()
            with torch.no_grad():
                encoded_layers, _ = self.bert(x)
                enc = encoded_layers[-1]

        if self.top_rnns:
            enc, _ = self.rnn(enc)
        logits = self.fc(enc)
        y_hat = logits.argmax(-1)
        return logits, y, y_hat

In [13]:

testset_url = "/home/cilab/LabMembers/YS/WWW/finetuning/valid.txt"
eval_dataset = NerDataset(testset_url)

import glob

f1_list = []

check = "./24.pt"
model = Net(False, len(VOCAB), device='cuda', finetuning=False).cuda()
checkpoint = torch.load(check)
#print("load check")
#print(checkpoint)
model.load_state_dict(checkpoint['model_state_dict'],strict=False)
model.eval()
print("eval")
eval_iter = data.DataLoader(dataset=eval_dataset, batch_size=16, shuffle=True, num_workers=4,collate_fn=pad)


precision, recall, f1 = eval(model, eval_iter)
f1_list.append(f1)
print(precision, recall, f1)
print("\n")
print(f1_list)


eval
[1 1 2 ... 1 2 1]
num_proposed:14090
num_correct:244
num_gold:3643
precision=0.02
recall=0.07
f1=0.03
0.017317246273953157 0.06697776557782048 0.027519314272824676


[0.027519314272824676]
