In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
from model_scibert import Net
from data_load import NerDataset, pad, VOCAB, tokenizer, tag2idx, idx2tag
import os
import numpy as np
import argparse
import glob

class Arg():
    def __init__(self, check_path):
        self.checkpoint = check_path
        self.batch_size = 8

In [2]:
def eval(model, iterator):
    model.eval()

    Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            words, x, is_heads, tags, y, seqlens = batch

            _, _, y_hat = model(x, y)  # y_hat: (N, T)

            Words.extend(words)
            Is_heads.extend(is_heads)
            Tags.extend(tags)
            Y.extend(y.numpy().tolist())
            Y_hat.extend(y_hat.cpu().numpy().tolist())

    ## gets results and save
    with open("temp", 'w') as fout:
        for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):
            y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1]
            preds = [idx2tag[hat] for hat in y_hat]
            assert len(preds)==len(words.split())==len(tags.split())
            for w, t, p in zip(words.split()[1:-1], tags.split()[1:-1], preds[1:-1]):
                fout.write(f"{w} {t} {p}\n")
            fout.write("\n")

    ## calc metric
    y_true =  np.array([tag2idx[line.split()[1]] for line in open("temp", 'r').read().splitlines() if len(line) > 0])
    y_pred =  np.array([tag2idx[line.split()[2]] for line in open("temp", 'r').read().splitlines() if len(line) > 0])

    num_proposed = len(y_pred[y_pred>1])
    num_correct = (np.logical_and(y_true==y_pred, y_true>1)).astype(np.int).sum()
    num_gold = len(y_true[y_true>1])

    print(f"num_proposed:{num_proposed}")
    print(f"num_correct:{num_correct}")
    print(f"num_gold:{num_gold}")
    try:
        precision = num_correct / num_proposed
    except ZeroDivisionError:
        precision = 1.0

    try:
        recall = num_correct / num_gold
    except ZeroDivisionError:
        recall = 1.0

    try:
        f1 = 2*precision*recall / (precision + recall)
    except ZeroDivisionError:
        if precision*recall==0:
            f1=1.0
        else:
            f1=0

    os.remove("temp")

    print("precision=%.2f"%precision)
    print("recall=%.2f"%recall)
    print("f1=%.2f"%f1)
    return precision, recall, f1

In [7]:
import re
import glob

if __name__=="__main__":
    check_lists=[]
    dir_list = glob.glob("/mnt_data/sci_WWW/5e-5/*")
    dir_list.sort()
    dictionary = {}
    checkpoint = []
    print(dir_list)
    for directory in dir_list:
        if directory.endswith('.pt'):
            checkpoint.append(directory)

    for i in checkpoint:
        print(i)
                
    testset = "/home/cilab/LabMembers/YS/WWW/finetuning/test.txt"
    
    print("load check point of model...")

    model = Net(False, len(VOCAB), 'cpu', False)
    eval_dataset = NerDataset(testset)
    print(eval_dataset)
    eval_iter = data.DataLoader(dataset=eval_dataset,
                                     batch_size=8,
                                     shuffle=False,
                                     num_workers=4,
                                     collate_fn=pad)
    max_f1 = 0
    max_pt = ""
    f1_list = []
    for check in checkpoint:
        print("\nCheck Point : ",check)
        hp = Arg(check)
        checkpoint = torch.load(hp.checkpoint)
        model.load_state_dict(checkpoint['model_state_dict'],strict=False)

        precision, recall, f1 = eval(model, eval_iter)
        f1_list.append(format(f1, '.2f'))
        if max_f1<f1:
            max_f1 = f1
            max_pt = check
    print("\n\n{} : F1_Score : {}".format(max_pt.split('/')[-1], max_f1))
    print(f1_list)

    



['/mnt_data/sci_WWW/5e-5/1.P0.47_R0.33_F0.39', '/mnt_data/sci_WWW/5e-5/1.pt', '/mnt_data/sci_WWW/5e-5/2.P0.57_R0.18_F0.27', '/mnt_data/sci_WWW/5e-5/2.pt', '/mnt_data/sci_WWW/5e-5/3.P0.45_R0.38_F0.41', '/mnt_data/sci_WWW/5e-5/3.pt', '/mnt_data/sci_WWW/5e-5/4.P0.42_R0.35_F0.38', '/mnt_data/sci_WWW/5e-5/4.pt', '/mnt_data/sci_WWW/5e-5/5.P0.44_R0.37_F0.40', '/mnt_data/sci_WWW/5e-5/5.pt']
/mnt_data/sci_WWW/5e-5/1.pt
/mnt_data/sci_WWW/5e-5/2.pt
/mnt_data/sci_WWW/5e-5/3.pt
/mnt_data/sci_WWW/5e-5/4.pt
/mnt_data/sci_WWW/5e-5/5.pt
load check point of model...
<data_load.NerDataset object at 0x7f40b5572550>

Check Point :  /mnt_data/sci_WWW/5e-5/1.pt
num_proposed:1260
num_correct:644
num_gold:1738
precision=0.51
recall=0.37
f1=0.43

Check Point :  /mnt_data/sci_WWW/5e-5/2.pt
num_proposed:572
num_correct:362
num_gold:1738
precision=0.63
recall=0.21
f1=0.31

Check Point :  /mnt_data/sci_WWW/5e-5/3.pt
num_proposed:1500
num_correct:678
num_gold:1738
precision=0.45
recall=0.39
f1=0.42

Check Point :  /

In [8]:
import re
import glob

if __name__=="__main__":
    check_lists=[]
    dir_list = glob.glob("/mnt_data/sci_WWW/5e-6/*")
    dir_list.sort()
    dictionary = {}
    checkpoint = []
    for directory in dir_list:
        if directory.endswith('.pt'):
            checkpoint.append(directory)

    for i in checkpoint:
        print(i)
                
    testset = "/home/cilab/LabMembers/YS/WWW/finetuning/test.txt"
    
    print("load check point of model...")

    model = Net(False, len(VOCAB), 'cpu', False)
    eval_dataset = NerDataset(testset)
    print(eval_dataset)
    eval_iter = data.DataLoader(dataset=eval_dataset,
                                     batch_size=8,
                                     shuffle=False,
                                     num_workers=4,
                                     collate_fn=pad)
    max_f1 = 0
    max_pt = ""
    f1_list = []
    for check in checkpoint:
        print("\nCheck Point : ",check)
        hp = Arg(check)
        checkpoint = torch.load(hp.checkpoint)
        model.load_state_dict(checkpoint['model_state_dict'],strict=False)

        precision, recall, f1 = eval(model, eval_iter)
        f1_list.append(format(f1, '.2f'))
        if max_f1<f1:
            max_f1 = f1
            max_pt = check
    print("\n\n{} : F1_Score : {}".format(max_pt.split('/')[-1], max_f1))
    print(f1_list)

    



/mnt_data/sci_WWW/5e-6/1.pt
/mnt_data/sci_WWW/5e-6/2.pt
/mnt_data/sci_WWW/5e-6/3.pt
/mnt_data/sci_WWW/5e-6/4.pt
/mnt_data/sci_WWW/5e-6/5.pt
load check point of model...
<data_load.NerDataset object at 0x7f4150573d30>

Check Point :  /mnt_data/sci_WWW/5e-6/1.pt
num_proposed:623
num_correct:341
num_gold:1738
precision=0.55
recall=0.20
f1=0.29

Check Point :  /mnt_data/sci_WWW/5e-6/2.pt
num_proposed:778
num_correct:427
num_gold:1738
precision=0.55
recall=0.25
f1=0.34

Check Point :  /mnt_data/sci_WWW/5e-6/3.pt
num_proposed:983
num_correct:530
num_gold:1738
precision=0.54
recall=0.30
f1=0.39

Check Point :  /mnt_data/sci_WWW/5e-6/4.pt
num_proposed:1265
num_correct:636
num_gold:1738
precision=0.50
recall=0.37
f1=0.42

Check Point :  /mnt_data/sci_WWW/5e-6/5.pt
num_proposed:1248
num_correct:621
num_gold:1738
precision=0.50
recall=0.36
f1=0.42


4.pt : F1_Score : 0.4235764235764236
['0.29', '0.34', '0.39', '0.42', '0.42']


In [9]:
import re
import glob

if __name__=="__main__":
    check_lists=[]
    dir_list = glob.glob("/mnt_data/sci_KDD/5e-5/*")
    dir_list.sort()
    dictionary = {}
    checkpoint = []
    for directory in dir_list:
        if directory.endswith('.pt'):
            checkpoint.append(directory)

    for i in checkpoint:
        print(i)
                
    testset = "/home/cilab/LabMembers/YS/KDD/finetuning/test.txt"
    
    print("load check point of model...")

    model = Net(False, len(VOCAB), 'cpu', False)
    eval_dataset = NerDataset(testset)
    print(eval_dataset)
    eval_iter = data.DataLoader(dataset=eval_dataset,
                                     batch_size=8,
                                     shuffle=False,
                                     num_workers=4,
                                     collate_fn=pad)
    max_f1 = 0
    max_pt = ""
    f1_list = []
    for check in checkpoint:
        print("\nCheck Point : ",check)
        hp = Arg(check)
        checkpoint = torch.load(hp.checkpoint)
        model.load_state_dict(checkpoint['model_state_dict'],strict=False)

        precision, recall, f1 = eval(model, eval_iter)
        f1_list.append(format(f1, '.2f'))
        if max_f1<f1:
            max_f1 = f1
            max_pt = check
    print("\n\n{} : F1_Score : {}".format(max_pt.split('/')[-1], max_f1))
    print(f1_list)

    



/mnt_data/sci_KDD/5e-5/1.pt
/mnt_data/sci_KDD/5e-5/2.pt
/mnt_data/sci_KDD/5e-5/3.pt
/mnt_data/sci_KDD/5e-5/4.pt
/mnt_data/sci_KDD/5e-5/5.pt
load check point of model...
<data_load.NerDataset object at 0x7f40afd4ff98>

Check Point :  /mnt_data/sci_KDD/5e-5/1.pt
num_proposed:457
num_correct:218
num_gold:1121
precision=0.48
recall=0.19
f1=0.28

Check Point :  /mnt_data/sci_KDD/5e-5/2.pt
num_proposed:587
num_correct:254
num_gold:1121
precision=0.43
recall=0.23
f1=0.30

Check Point :  /mnt_data/sci_KDD/5e-5/3.pt
num_proposed:820
num_correct:368
num_gold:1121
precision=0.45
recall=0.33
f1=0.38

Check Point :  /mnt_data/sci_KDD/5e-5/4.pt
num_proposed:930
num_correct:388
num_gold:1121
precision=0.42
recall=0.35
f1=0.38

Check Point :  /mnt_data/sci_KDD/5e-5/5.pt
num_proposed:675
num_correct:316
num_gold:1121
precision=0.47
recall=0.28
f1=0.35


3.pt : F1_Score : 0.3791859866048429
['0.28', '0.30', '0.38', '0.38', '0.35']


In [10]:
import re
import glob

if __name__=="__main__":
    check_lists=[]
    dir_list = glob.glob("/mnt_data/sci_KDD/5e-6/*")
    dir_list.sort()
    dictionary = {}
    checkpoint = []
    for directory in dir_list:
        if directory.endswith('.pt'):
            checkpoint.append(directory)

    for i in checkpoint:
        print(i)
                
    testset = "/home/cilab/LabMembers/YS/KDD/finetuning/test.txt"
    
    print("load check point of model...")

    model = Net(False, len(VOCAB), 'cpu', False)
    eval_dataset = NerDataset(testset)
    print(eval_dataset)
    eval_iter = data.DataLoader(dataset=eval_dataset,
                                     batch_size=8,
                                     shuffle=False,
                                     num_workers=4,
                                     collate_fn=pad)
    max_f1 = 0
    max_pt = ""
    f1_list = []
    for check in checkpoint:
        print("\nCheck Point : ",check)
        hp = Arg(check)
        checkpoint = torch.load(hp.checkpoint)
        model.load_state_dict(checkpoint['model_state_dict'],strict=False)

        precision, recall, f1 = eval(model, eval_iter)
        f1_list.append(format(f1, '.2f'))
        if max_f1<f1:
            max_f1 = f1
            max_pt = check
    print("\n\n{} : F1_Score : {}".format(max_pt.split('/')[-1], max_f1))
    print(f1_list)

    



/mnt_data/sci_KDD/5e-6/1.pt
/mnt_data/sci_KDD/5e-6/2.pt
/mnt_data/sci_KDD/5e-6/3.pt
/mnt_data/sci_KDD/5e-6/4.pt
/mnt_data/sci_KDD/5e-6/5.pt
load check point of model...
<data_load.NerDataset object at 0x7f40afd14cf8>

Check Point :  /mnt_data/sci_KDD/5e-6/1.pt
num_proposed:270
num_correct:111
num_gold:1121
precision=0.41
recall=0.10
f1=0.16

Check Point :  /mnt_data/sci_KDD/5e-6/2.pt
num_proposed:366
num_correct:156
num_gold:1121
precision=0.43
recall=0.14
f1=0.21

Check Point :  /mnt_data/sci_KDD/5e-6/3.pt
num_proposed:614
num_correct:252
num_gold:1121
precision=0.41
recall=0.22
f1=0.29

Check Point :  /mnt_data/sci_KDD/5e-6/4.pt
num_proposed:912
num_correct:361
num_gold:1121
precision=0.40
recall=0.32
f1=0.36

Check Point :  /mnt_data/sci_KDD/5e-6/5.pt
num_proposed:784
num_correct:339
num_gold:1121
precision=0.43
recall=0.30
f1=0.36


5.pt : F1_Score : 0.3559055118110236
['0.16', '0.21', '0.29', '0.36', '0.36']


In [4]:
import re
import glob

if __name__=="__main__":
    check_lists=[]
    dir_list = glob.glob("/mnt_data/sci_bert_Inspec/*")
    dir_list.sort()
    dictionary = {}
    checkpoint = []
    for directory in dir_list:
        if directory.endswith('.pt'):
            checkpoint.append(directory)

    for i in checkpoint:
        print(i)
                
    testset = "/home/cilab/LabMembers/YS/Inspec/data/finetuning/test.txt"
    
    print("load check point of model...")

    model = Net(False, len(VOCAB), 'cpu', False)
    eval_dataset = NerDataset(testset)
    print(eval_dataset)
    eval_iter = data.DataLoader(dataset=eval_dataset,
                                     batch_size=8,
                                     shuffle=False,
                                     num_workers=4,
                                     collate_fn=pad)
    max_f1 = 0
    max_pt = ""
    f1_list = []
    for check in checkpoint:
        print("\nCheck Point : ",check)
        hp = Arg(check)
        checkpoint = torch.load(hp.checkpoint)
        model.load_state_dict(checkpoint['model_state_dict'],strict=False)

        precision, recall, f1 = eval(model, eval_iter)
        f1_list.append(format(f1, '.2f'))
        if max_f1<f1:
            max_f1 = f1
            max_pt = check
    print("\n\n{} : F1_Score : {}".format(max_pt.split('/')[-1], max_f1))
    print(f1_list)

    



/mnt_data/sci_bert_Inspec/1.pt
/mnt_data/sci_bert_Inspec/2.pt
/mnt_data/sci_bert_Inspec/3.pt
/mnt_data/sci_bert_Inspec/4.pt
/mnt_data/sci_bert_Inspec/5.pt
load check point of model...
<data_load.NerDataset object at 0x7fd7f007e1d0>

Check Point :  /mnt_data/sci_bert_Inspec/1.pt
num_proposed:807
num_correct:468
num_gold:2430
precision=0.58
recall=0.19
f1=0.29

Check Point :  /mnt_data/sci_bert_Inspec/2.pt
num_proposed:1952
num_correct:915
num_gold:2430
precision=0.47
recall=0.38
f1=0.42

Check Point :  /mnt_data/sci_bert_Inspec/3.pt
num_proposed:1525
num_correct:760
num_gold:2430
precision=0.50
recall=0.31
f1=0.38

Check Point :  /mnt_data/sci_bert_Inspec/4.pt
num_proposed:1604
num_correct:789
num_gold:2430
precision=0.49
recall=0.32
f1=0.39

Check Point :  /mnt_data/sci_bert_Inspec/5.pt
num_proposed:2182
num_correct:910
num_gold:2430
precision=0.42
recall=0.37
f1=0.39


2.pt : F1_Score : 0.4176175262437243
['0.29', '0.42', '0.38', '0.39', '0.39']
