In [1]:
#python3 train.py --finetuning --logdir finetuned_by_WWW --trainset /home/cilab/LabMembers/YS/WWW/finetuning/train.txt --validset /home/cilab/LabMembers/YS/WWW/finetuning/valid.txt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
from data_load import NerDataset, pad, VOCAB, tag2idx, idx2tag
import os
import numpy as np
import argparse
from tensorboardX import SummaryWriter
from pytorch_pretrained_bert import BertTokenizer
from tqdm import tqdm

In [2]:
def eval(model, iterator):
    model.eval()

    Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(tqdm(iterator)):
            words, x, is_heads, tags, y, seqlens = batch

            _, _, y_hat = model(x, y)  # y_hat: (N, T)

            Words.extend(words)
            Is_heads.extend(is_heads)
            Tags.extend(tags)
            Y.extend(y.numpy().tolist())
            Y_hat.extend(y_hat.cpu().numpy().tolist())

    ## gets results and save
    with open("temp", 'w') as fout:
        for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):
            y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1]
            preds = [idx2tag[hat] for hat in y_hat]
            assert len(preds)==len(words.split())==len(tags.split())
            for w, t, p in zip(words.split()[1:-1], tags.split()[1:-1], preds[1:-1]):
                fout.write(f"{w} {t} {p}\n")
            fout.write("\n")

    ## calc metric
    y_true =  np.array([tag2idx[line.split()[1]] for line in open("temp", 'r').read().splitlines() if len(line) > 0])
    y_pred =  np.array([tag2idx[line.split()[2]] for line in open("temp", 'r').read().splitlines() if len(line) > 0])

    num_proposed = len(y_pred[y_pred>1])
    num_correct = (np.logical_and(y_true==y_pred, y_true>1)).astype(np.int).sum()
    num_gold = len(y_true[y_true>1])

    print(f"num_proposed:{num_proposed}")
    print(f"num_correct:{num_correct}")
    print(f"num_gold:{num_gold}")
    try:
        precision = num_correct / num_proposed
    except ZeroDivisionError:
        precision = 1.0

    try:
        recall = num_correct / num_gold
    except ZeroDivisionError:
        recall = 1.0

    try:
        f1 = 2*precision*recall / (precision + recall)
    except ZeroDivisionError:
        if precision*recall==0:
            f1=1.0
        else:
            f1=0

    os.remove("temp")

    print("precision=%.2f"%precision)
    print("recall=%.2f"%recall)
    print("f1=%.2f"%f1)
    return precision, recall, f1

In [3]:
import torch
import torch.nn as nn
from pytorch_pretrained_bert import BertModel

class Net(nn.Module):
    def __init__(self, top_rnns=False, vocab_size=None, device='cpu', finetuning=False, model='bert-base-cased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(model)

        self.top_rnns=top_rnns
        if top_rnns:
            self.rnn = nn.LSTM(bidirectional=True, num_layers=2, input_size=768, hidden_size=768//2, batch_first=True)
        self.fc = nn.Linear(768, vocab_size)

        self.device = device
        self.finetuning = finetuning

    def forward(self, x, y, ):
        '''
        x: (N, T). int64
        y: (N, T). int64

        Returns
        enc: (N, T, VOCAB)
        '''
        x = x.to(self.device)
        y = y.to(self.device)

        if self.training and self.finetuning:
            # print("->bert.train()")
            self.bert.train()
            encoded_layers, _ = self.bert(x)
            enc = encoded_layers[-1]
        else:
            self.bert.eval()
            with torch.no_grad():
                encoded_layers, _ = self.bert(x)
                enc = encoded_layers[-1]

        if self.top_rnns:
            enc, _ = self.rnn(enc)
        logits = self.fc(enc)
        y_hat = logits.argmax(-1)

        return logits, y, y_hat

In [8]:
import glob

tokenizer = BertTokenizer.from_pretrained("./bert_model/scibert_scivocab_uncased/vocab.txt", do_lower_case=False)

test_path = "/home/cilab/LabMembers/YS/WWW/finetuning/test.txt"

eval_dataset = NerDataset(test_path, tokenizer)

eval_iter = data.DataLoader(dataset=eval_dataset,
                             batch_size=4,
                             shuffle=False,
                             num_workers=4,
                             collate_fn=pad)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoints_path = "/mnt_data/models/scibert_WWW/*"
checkpoints = [ck for ck in glob.glob(checkpoints_path) if '.pt' in ck]
max_f1 = 0
max_pt = ""
print(checkpoints)
for ck in checkpoints:
    print(ck)
    model = Net(False, len(VOCAB), device, False, './bert_model/scibert_scivocab_uncased/').cuda()
    #model = nn.DataParallel(model)
    checkpoint = torch.load(ck)
    model.load_state_dict(checkpoint['model_state_dict'],strict=False)
    #model = model.module.cuda()
    precision, recall, f1 = eval(model, eval_iter)
    if max_f1 < f1:
        max_f1 = f1
        max_pt = ck
print("best model f1: {}\n model name: {}".format(max_f1, max_pt))

['/mnt_data/models/scibert_Inspec_p/1.pt', '/mnt_data/models/scibert_Inspec_p/4.pt', '/mnt_data/models/scibert_Inspec_p/3.pt', '/mnt_data/models/scibert_Inspec_p/2.pt', '/mnt_data/models/scibert_Inspec_p/5.pt']
/mnt_data/models/scibert_Inspec_p/1.pt


100%|██████████| 790/790 [00:16<00:00, 48.67it/s]


num_proposed:1718
num_correct:934
num_gold:2430
precision=0.54
recall=0.38
f1=0.45
/mnt_data/models/scibert_Inspec_p/4.pt


100%|██████████| 790/790 [00:16<00:00, 48.52it/s]


num_proposed:2526
num_correct:1211
num_gold:2430
precision=0.48
recall=0.50
f1=0.49
/mnt_data/models/scibert_Inspec_p/3.pt


100%|██████████| 790/790 [00:16<00:00, 48.85it/s]


num_proposed:2351
num_correct:1205
num_gold:2430
precision=0.51
recall=0.50
f1=0.50
/mnt_data/models/scibert_Inspec_p/2.pt


100%|██████████| 790/790 [00:16<00:00, 48.95it/s]


num_proposed:2389
num_correct:1207
num_gold:2430
precision=0.51
recall=0.50
f1=0.50
/mnt_data/models/scibert_Inspec_p/5.pt


100%|██████████| 790/790 [00:15<00:00, 49.71it/s]


num_proposed:2225
num_correct:1114
num_gold:2430
precision=0.50
recall=0.46
f1=0.48
best model f1: 0.5040786446350136
 model name: /mnt_data/models/scibert_Inspec_p/3.pt
