In [1]:
#python3 train.py --finetuning --logdir finetuned_by_WWW --trainset /home/cilab/LabMembers/YS/WWW/finetuning/train.txt --validset /home/cilab/LabMembers/YS/WWW/finetuning/valid.txt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
from data_load import NerDataset, pad, VOCAB, tag2idx, idx2tag
import os
import numpy as np
import argparse
from tensorboardX import SummaryWriter
from pytorch_pretrained_bert import BertTokenizer

In [2]:
# SciBERT's scivocab_uncased checkpoint ships an *uncased* WordPiece vocabulary,
# so text must be lower-cased before lookup. With do_lower_case=False any
# capitalised word would be split against the wrong pieces (or become [UNK]).
tokenizer = BertTokenizer.from_pretrained("./bert_model/scibert_scivocab_uncased/vocab.txt", do_lower_case=True)

from tqdm import tqdm

def train(model, iterator, optimizer, criterion):
    """Run one training epoch over `iterator` and return the mean batch loss.

    Args:
        model: module called as ``model(x, y)`` that returns
            ``(logits, y, y_hat)`` with logits of shape (N, T, VOCAB) and
            y of shape (N, T).
        iterator: sized iterable (e.g. a DataLoader) of batches
            ``(words, x, is_heads, tags, y, seqlens)``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        criterion: loss taking flattened ``(N*T, VOCAB)`` logits and
            ``(N*T,)`` targets.

    Returns:
        The mean of the per-batch losses for this epoch, or ``nan`` when
        `iterator` yields no batches. Returning only the last batch's loss
        would make the per-epoch value depend on whichever batch came last,
        and an empty iterator previously raised UnboundLocalError.

    Side effects:
        Prints a sanity-check dump of the first batch, decoded with the
        module-level ``tokenizer``, and prints the loss every 100 steps.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for i, batch in enumerate(tqdm(iterator)):
        words, x, is_heads, tags, y, seqlens = batch
        _y = y  # keep the label tensor as loaded, for the sanity-check print
        optimizer.zero_grad()
        logits, y, _ = model(x, y)  # logits: (N, T, VOCAB), y: (N, T)

        logits = logits.view(-1, logits.shape[-1])  # (N*T, VOCAB)
        y = y.view(-1)  # (N*T,)

        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        if i == 0:
            print("=====sanity check======")
            print("words:", words[0])
            print("x:", x.cpu().numpy()[0][:seqlens[0]])
            print("tokens:", tokenizer.convert_ids_to_tokens(x.cpu().numpy()[0])[:seqlens[0]])
            print("is_heads:", is_heads[0])
            print("y:", _y.cpu().numpy()[0][:seqlens[0]])
            print("tags:", tags[0])
            print("seqlen:", seqlens[0])
            print("=======================")

        if i % 100 == 0:  # monitoring
            print(f"step: {i} /{len(iterator)}, loss: {batch_loss}")

    return total_loss / num_batches if num_batches else float("nan")

In [3]:
def _precision_recall_f1(num_correct, num_proposed, num_gold):
    """Compute token-level precision, recall and F1 from raw counts.

    Zero-denominator conventions: precision is 1.0 when nothing was proposed,
    recall is 1.0 when there are no gold entity tokens, and F1 is 0.0 when
    precision + recall == 0. The explicit checks replace try/except blocks
    that never fired for NumPy scalars, which divide to nan/inf instead of
    raising.
    """
    precision = num_correct / num_proposed if num_proposed else 1.0
    recall = num_correct / num_gold if num_gold else 1.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


def eval(model, iterator, f):
    """Evaluate `model` on `iterator`, write per-word predictions, return P/R/F1.

    Args:
        model: module called as ``model(x, y)`` whose third output is the
            predicted tag ids ``y_hat`` of shape (N, T).
        iterator: iterable of batches ``(words, x, is_heads, tags, y, seqlens)``.
            ``words`` and ``tags`` are space-joined strings whose first and
            last entries are special tokens (e.g. ``[CLS] ... [SEP]``); these
            are dropped from scoring.
        f: path prefix for the result file. The file actually written is
            ``f + ".P%.2f_R%.2f_F%.2f" % (precision, recall, f1)``. It holds
            one ``word gold pred`` line per word, a blank line after each
            sentence, and then the precision/recall/f1 values.

    Returns:
        ``(precision, recall, f1)`` as floats.

    Only tag ids greater than 1 count as entity tokens. This assumes tag2idx
    maps ``<PAD>`` -> 0 and ``O`` -> 1, as the logged sanity checks show.
    Scores are token-level, not span-level.

    Note: the name shadows the builtin ``eval``; it is kept for existing callers.
    """
    model.eval()
    Words, Is_heads, Tags, Y_hat = [], [], [], []
    with torch.no_grad():
        for batch in tqdm(iterator):
            words, x, is_heads, tags, y, _seqlens = batch
            _, _, y_hat = model(x, y)  # y_hat: (N, T)
            Words.extend(words)
            Is_heads.extend(is_heads)
            Tags.extend(tags)
            Y_hat.extend(y_hat.cpu().numpy().tolist())

    # Keep only the prediction on the first sub-token of each word and pair it
    # with the word and its gold tag. Everything is built in memory instead of
    # going through a shared "temp" file in the working directory.
    lines, gold_ids, pred_ids = [], [], []
    for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):
        head_hats = [hat for head, hat in zip(is_heads, y_hat) if head == 1]
        preds = [idx2tag[hat] for hat in head_hats]
        word_list, tag_list = words.split(), tags.split()
        assert len(preds) == len(word_list) == len(tag_list)
        # [1:-1] drops the leading/trailing special tokens.
        for w, t, p in zip(word_list[1:-1], tag_list[1:-1], preds[1:-1]):
            lines.append(f"{w} {t} {p}\n")
            gold_ids.append(tag2idx[t])
            pred_ids.append(tag2idx[p])
        lines.append("\n")

    ## calc metric
    y_true = np.array(gold_ids, dtype=np.int64)
    y_pred = np.array(pred_ids, dtype=np.int64)
    # Plain Python ints (np.int was removed in NumPy 1.24).
    num_proposed = int((y_pred > 1).sum())
    num_correct = int(np.logical_and(y_true == y_pred, y_true > 1).sum())
    num_gold = int((y_true > 1).sum())

    print(f"num_proposed:{num_proposed}")
    print(f"num_correct:{num_correct}")
    print(f"num_gold:{num_gold}")

    precision, recall, f1 = _precision_recall_f1(num_correct, num_proposed, num_gold)

    result = "".join(lines)
    final = f + ".P%.2f_R%.2f_F%.2f" % (precision, recall, f1)
    with open(final, 'w') as fout:
        fout.write(f"{result}\n")
        fout.write(f"precision={precision}\n")
        fout.write(f"recall={recall}\n")
        fout.write(f"f1={f1}\n")

    print("precision=%.2f" % precision)
    print("recall=%.2f" % recall)
    print("f1=%.2f" % f1)
    return precision, recall, f1

In [4]:
# Hyper-parameters written as the argparse calls of the command-line train.py
# script (see the example command commented at the top of the imports cell).
# Instead of running argparse, the following cell reads this string line by line
# and sets each option as an attribute of `hp`, so the notebook runs without
# command-line arguments. Edit the defaults here, or override `hp.*` afterwards.
args = """
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--n_epochs", type=int, default=5)
parser.add_argument("--finetuning", dest="finetuning", action="store_true")
parser.add_argument("--top_rnns", dest="top_rnns", action="store_true")
parser.add_argument("--logdir", type=str, default="./kp_pretrained_www/")
parser.add_argument("--trainset", type=str, default="/home/cilab/LabMembers/YS/KDD/finetuning/train.txt")
parser.add_argument("--validset", type=str, default="/home/cilab/LabMembers/YS/KDD/finetuning/valid.txt")
"""

In [5]:
import ast


class Args:
    """Plain attribute bag standing in for the argparse Namespace train.py builds."""


def _parse_add_argument(line):
    """Turn one ``parser.add_argument(...)`` source line into ``(attribute, value)``.

    The call is parsed with :mod:`ast` instead of by string slicing. As a
    result, values that merely *contain* "true"/"False" (such as a path) are
    no longer coerced to booleans, and no bare ``except`` can hide errors.

    * ``action="store_true"`` / ``"store_false"`` -> True / False, i.e. every
      declared flag is treated as if it were passed on the command line.
    * otherwise -> the ``default=`` literal (int, float, str, ...), or None
      when the call has no default.

    The attribute name is ``dest=`` when given. Otherwise it is the option
    name with leading dashes removed and inner dashes turned into underscores.
    """
    call = ast.parse(line.strip(), mode="eval").body
    options = {kw.arg: kw.value for kw in call.keywords}
    if "dest" in options:
        name = ast.literal_eval(options["dest"])
    else:
        name = ast.literal_eval(call.args[0]).lstrip("-").replace("-", "_")
    if "action" in options:
        action = ast.literal_eval(options["action"])
        if action in ("store_true", "store_false"):
            return name, action == "store_true"
    if "default" in options:
        return name, ast.literal_eval(options["default"])
    return name, None


hp = Args()
for line in args.split('\n'):
    if "parser.add_argument" not in line:
        continue
    key, value = _parse_add_argument(line)
    setattr(hp, key, value)

# Notebook-specific overrides of the parsed defaults.
hp.lr = 5e-5
hp.top_rnns = False
hp.logdir = "/mnt_data/models/scibert_KDD"

In [8]:
import torch
import torch.nn as nn
from pytorch_pretrained_bert import BertModel

class Net(nn.Module):
    """BERT encoder, optional BiLSTM, and linear tag classifier for token tagging.

    Args:
        top_rnns: if True, run a 2-layer bidirectional LSTM over BERT's last
            encoder layer before the classifier.
        vocab_size: number of output tags (classifier output size). Required.
        device: device that ``forward`` moves ``x`` and ``y`` to.
        finetuning: if True, BERT is trained together with the head. Otherwise
            BERT runs in eval mode under ``torch.no_grad()`` and only the head
            receives gradients.
        model: name or local path passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, top_rnns=False, vocab_size=None, device='cpu', finetuning=False, model='bert-base-cased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(model)

        # Size the head from the encoder's config rather than hard-coding 768,
        # so wider checkpoints (e.g. 1024-dim "large" models) also work. Fall
        # back to 768 if the encoder exposes no config.
        hidden_size = getattr(getattr(self.bert, "config", None), "hidden_size", 768)

        self.top_rnns = top_rnns
        if top_rnns:
            # Bidirectional with hidden_size // 2 per direction keeps the
            # output width equal to hidden_size for the classifier below.
            self.rnn = nn.LSTM(bidirectional=True, num_layers=2, input_size=hidden_size,
                               hidden_size=hidden_size // 2, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

        self.device = device
        self.finetuning = finetuning

    def forward(self, x, y):
        '''
        x: (N, T). int64 token ids
        y: (N, T). int64 tag ids

        Returns
        logits: (N, T, VOCAB)
        y: the input tags, moved to self.device
        y_hat: (N, T) argmax tag ids
        '''
        x = x.to(self.device)
        y = y.to(self.device)

        if self.training and self.finetuning:
            self.bert.train()
            encoded_layers, _ = self.bert(x)
            enc = encoded_layers[-1]
        else:
            # Frozen encoder (or evaluation): no dropout, no autograd graph.
            self.bert.eval()
            with torch.no_grad():
                encoded_layers, _ = self.bert(x)
                enc = encoded_layers[-1]

        if self.top_rnns:
            enc, _ = self.rnn(enc)
        logits = self.fc(enc)
        y_hat = logits.argmax(-1)

        return logits, y, y_hat

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
summary = SummaryWriter()

# Tagger on top of the local SciBERT weights. Move it to `device` instead of
# calling .cuda() unconditionally, which crashed on CPU-only machines.
model = Net(hp.top_rnns, len(VOCAB), device, hp.finetuning,
            './bert_model/scibert_scivocab_uncased/').to(device)

# Optional warm start from a checkpoint written by the loop below. This only
# runs when `hp.checkpoint` is set. strict=False tolerates missing/unexpected keys.
checkpoint_path = getattr(hp, "checkpoint", None)
if checkpoint_path:
    print("load check point of model...")
    # torch.load unpickles the file: only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(checkpoint.keys())
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)

model = nn.DataParallel(model)

train_dataset = NerDataset(hp.trainset, tokenizer)
eval_dataset = NerDataset(hp.validset, tokenizer)

train_iter = data.DataLoader(dataset=train_dataset,
                             batch_size=hp.batch_size,
                             shuffle=True,
                             num_workers=4,
                             collate_fn=pad)

eval_iter = data.DataLoader(dataset=eval_dataset,
                            batch_size=hp.batch_size,
                            shuffle=False,
                            num_workers=4,
                            collate_fn=pad)

optimizer = optim.Adam(model.parameters(), lr=hp.lr)
# Label id 0 is ignored by the loss. In the logged sanity checks it marks
# <PAD> positions and non-head sub-word pieces.
criterion = nn.CrossEntropyLoss(ignore_index=0)

os.makedirs(hp.logdir, exist_ok=True)
try:
    for epoch in range(1, hp.n_epochs + 1):
        loss = train(model, train_iter, optimizer, criterion)

        print(f"=========eval at epoch={epoch}=========")
        fname = os.path.join(hp.logdir, str(epoch))
        precision, recall, f1 = eval(model, eval_iter, fname)

        # Save the unwrapped module's weights so the checkpoint also loads
        # into a plain (non-DataParallel) Net.
        state_dict = getattr(model, "module", model).state_dict()
        torch.save({
            'epoch': epoch,  # the current epoch (previously always hp.n_epochs)
            'model_state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict()
        }, f"{fname}.pt")
        # Pickles the whole wrapped model; loading it requires Net to be importable.
        torch.save(model, "latest_model.pt")
        print(f"weights were saved to {fname}.pt")
        summary.add_scalar('f1_score', f1, epoch)
        summary.add_scalar('loss', loss, epoch)
finally:
    # Flush TensorBoard events even if training is interrupted.
    summary.close()


  0%|          | 1/634 [00:01<14:02,  1.33s/it]

words: [CLS] existing graph generation models do not exhibit these types of behavior , even at a qualitative level . [SEP]
x: [ 102 3302 1845 3014 1262  572  302 5537  407 1910  131 1689  422 1390
  235  106 6526  615  205  103]
tokens: ['[CLS]', 'existing', 'graph', 'generation', 'models', 'do', 'not', 'exhibit', 'these', 'types', 'of', 'behavior', ',', 'even', 'at', 'a', 'qualitative', 'level', '.', '[SEP]']
is_heads: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
y: [0 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0]
tags: <PAD> O B-KP O O O O O O O O O O O O O O O O <PAD>
seqlen: 20
step: 0 /634, loss: 1.5029336214065552


 16%|█▌        | 101/634 [01:21<08:31,  1.04it/s]

step: 100 /634, loss: 0.14210844039916992


 32%|███▏      | 201/634 [03:06<07:43,  1.07s/it]

step: 200 /634, loss: 0.12192035466432571


 47%|████▋     | 301/634 [05:05<05:42,  1.03s/it]

step: 300 /634, loss: 0.347880095243454


 63%|██████▎   | 401/634 [06:51<03:32,  1.10it/s]

step: 400 /634, loss: 0.11123881489038467


 79%|███████▉  | 501/634 [08:41<02:10,  1.02it/s]

step: 500 /634, loss: 0.16059814393520355


 95%|█████████▍| 601/634 [10:24<00:34,  1.04s/it]

step: 600 /634, loss: 0.4009535312652588


100%|██████████| 634/634 [11:00<00:00,  1.04s/it]
  0%|          | 0/82 [00:00<?, ?it/s]



100%|██████████| 82/82 [00:53<00:00,  1.53it/s]


num_proposed:529
num_correct:321
num_gold:1033
precision=0.61
recall=0.31
f1=0.41


  0%|          | 0/634 [00:00<?, ?it/s]

weights were saved to /mnt_data/models/scibert_KDD/1.pt


  0%|          | 1/634 [00:01<19:54,  1.89s/it]

words: [CLS] we gathered 1971 benign and 1651 malicious executables and encoded each as a training example using n-grams of byte codes as features . [SEP]
x: [  102   185 15024  2693 30130 10346   137 16912 30130 15473 26483 30113
   137  9195   535   188   106  2208  1143   487   146   579 29809   131
 23392  6095   188  1882   205   103]
tokens: ['[CLS]', 'we', 'gathered', '197', '##1', 'benign', 'and', '165', '##1', 'malicious', 'executable', '##s', 'and', 'encoded', 'each', 'as', 'a', 'training', 'example', 'using', 'n', '-', 'grams', 'of', 'byte', 'codes', 'as', 'features', '.', '[SEP]']
is_heads: [1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1]
y: [0 1 1 1 0 1 1 1 0 2 1 0 1 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 0]
tags: <PAD> O O O O O O B-KP O O O O O O O O O O O O O O O O <PAD>
seqlen: 30
step: 0 /634, loss: 0.17166893184185028


 16%|█▌        | 101/634 [01:49<09:33,  1.08s/it]

step: 100 /634, loss: 0.14588750898838043


 32%|███▏      | 201/634 [03:19<05:53,  1.23it/s]

step: 200 /634, loss: 0.21736468374729156


 47%|████▋     | 301/634 [04:40<04:30,  1.23it/s]

step: 300 /634, loss: 0.13726988434791565


 63%|██████▎   | 401/634 [06:03<03:43,  1.04it/s]

step: 400 /634, loss: 0.13486182689666748


 79%|███████▉  | 501/634 [07:43<01:55,  1.15it/s]

step: 500 /634, loss: 0.14838054776191711


 95%|█████████▍| 601/634 [09:24<00:28,  1.15it/s]

step: 600 /634, loss: 0.11923671513795853


100%|██████████| 634/634 [09:59<00:00,  1.06it/s]
  0%|          | 0/82 [00:00<?, ?it/s]



100%|██████████| 82/82 [00:53<00:00,  1.54it/s]


num_proposed:1044
num_correct:500
num_gold:1033
precision=0.48
recall=0.48
f1=0.48


  0%|          | 0/634 [00:00<?, ?it/s]

weights were saved to /mnt_data/models/scibert_KDD/2.pt


  0%|          | 1/634 [00:01<16:37,  1.58s/it]

words: [CLS] discovering frequent patterns in sensitive data discovering frequent patterns from data is a popular exploratory technique in datamining . [SEP]
x: [  102 24367  5808  2465   121  4232   453 24367  5808  2465   263   453
   165   106  6237 14933  2358   121   453  1254   140   205   103]
tokens: ['[CLS]', 'discovering', 'frequent', 'patterns', 'in', 'sensitive', 'data', 'discovering', 'frequent', 'patterns', 'from', 'data', 'is', 'a', 'popular', 'exploratory', 'technique', 'in', 'data', '##min', '##ing', '.', '[SEP]']
is_heads: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1]
y: [0 1 2 3 1 1 1 1 2 3 1 1 1 1 1 1 1 1 1 0 0 1 0]
tags: <PAD> O B-KP I-KP O O O O B-KP I-KP O O O O O O O O O O <PAD>
seqlen: 23
step: 0 /634, loss: 0.17428305745124817


 15%|█▍        | 92/634 [01:34<08:51,  1.02it/s]

KeyboardInterrupt: 