In [2]:
import sys
import torch
import argparse
import os
from os.path import dirname, join
from data_load import NerDataset, pad, VOCAB, tag2idx, idx2tag
from model import Net
from conlleval import evaluate_conll_file
import torch.nn as nn
import numpy as np

In [3]:
checkpoint_path='finetuning/4.pt'

In [4]:
model = torch.load(checkpoint_path)
model.eval()

DataParallel(
  (module): Net(
    (xlnet): XLNetModel(
      (word_embedding): Embedding(32000, 768)
      (layer): ModuleList(
        (0): XLNetLayer(
          (rel_attn): XLNetRelativeAttention(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ff): XLNetFeedForward(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (layer_1): Linear(in_features=768, out_features=3072, bias=True)
            (layer_2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): XLNetLayer(
          (rel_attn): XLNetRelativeAttention(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ff): XLNetFeedForward(
           

In [5]:
newmodel = torch.nn.DataParallel(*(list(model.module.children())[:-1]))

In [6]:
newmodel

DataParallel(
  (module): XLNetModel(
    (word_embedding): Embedding(32000, 768)
    (layer): ModuleList(
      (0): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (1): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine

In [25]:
from torch.utils import data
train_dataset = NerDataset('conll2003/train.txt')
eval_dataset = NerDataset('conll2003/valid.txt')
test_dataset = NerDataset('conll2003/test.txt')

train_iter = data.DataLoader(dataset=train_dataset,
                             batch_size=16,
                             num_workers=4,
                             collate_fn=pad)
eval_iter = data.DataLoader(dataset=eval_dataset,
                            batch_size=16,
                            num_workers=4,
                            collate_fn=pad)

test_iter = data.DataLoader(dataset=test_dataset,
                            batch_size=16,
                            num_workers=4,
                            collate_fn=pad)

In [8]:
for sent in train_dataset.sents[:8]:
    print(len(sent))

3
11
4
4
32
33
35
27


In [9]:
import pickle as pkl
gcn_train = pkl.load(open('../conll_gcn/pkl/train_predictions.pkl', 'rb'))
gcn_val = pkl.load(open('../conll_gcn/pkl/val_predictions.pkl', 'rb'))
gcn_test = pkl.load(open('../conll_gcn/pkl/test_predictions.pkl', 'rb'))

In [10]:
gcn_train = gcn_train[1:]

In [11]:
gcn_train.shape

(18447, 124, 256)

In [12]:
gcn_train[:8].shape

(8, 124, 256)

In [13]:
class EnsembleModel(nn.Module):
    def __init__(self, xln_model, gcn_pretrained, vocab_size, device = 'cuda'):
        super().__init__()
        self.xln_model = xln_model
        self.gcn_pretrained = gcn_pretrained

        self.fc = nn.Linear(768 + 256, vocab_size)

        self.device = device


    def forward(self, x, gcn_tensor):
        '''
        x: (N, T). int64
        y: (N, T). int64

        Returns
        '''
        x = x.to(self.device)

        self.xln_model.eval()
        with torch.no_grad():
            encoded_layers = self.xln_model(x)
            enc = encoded_layers[0]
        gcn_tensor = torch.from_numpy(gcn_tensor).float()
        gcn_tensor = gcn_tensor.to(self.device)
        ensemble = torch.cat((enc, gcn_tensor), dim=2)
        logits = self.fc(ensemble)
        
        return logits


In [26]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ensemble_model = EnsembleModel(newmodel, gcn_train, vocab_size = len(VOCAB), device=device).to(device)

In [27]:
optimizer = torch.optim.Adam(ensemble_model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [28]:
ensemble_model

EnsembleModel(
  (xln_model): DataParallel(
    (module): XLNetModel(
      (word_embedding): Embedding(32000, 768)
      (layer): ModuleList(
        (0): XLNetLayer(
          (rel_attn): XLNetRelativeAttention(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ff): XLNetFeedForward(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (layer_1): Linear(in_features=768, out_features=3072, bias=True)
            (layer_2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): XLNetLayer(
          (rel_attn): XLNetRelativeAttention(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ff): XLNetFeedForwar

In [29]:
for ep in range(4):
    ensemble_model.train()
    for i, batch in enumerate(train_iter):
        words, x, is_heads, _, y, seqlens = batch
        
        max_len = x.shape[1]
        batch_size = x.shape[0]
        idx = i * batch_size
        
        if idx+batch_size > gcn_train.shape[0]:
            break
        gcn_tensor = gcn_train[idx:idx+batch_size]
        padded = np.zeros((batch_size, max_len, 256))

        for ix in range(batch_size):
            is_heads[ix][0] = is_heads[ix][-1] = 0
            num_word = sum(is_heads[ix])
            indexs = [head == 1 for head in is_heads[ix]] + [False] * (max_len - len(is_heads[ix]))
            padded[ix][indexs]=gcn_tensor[ix][:num_word]
        
        optimizer.zero_grad()

        logits = ensemble_model(x, padded)

        logits = logits.view(-1, logits.shape[-1])
        y = y.to(device).view(-1)
        
        loss = criterion(logits, y)
        loss.backward()

        optimizer.step()

        if i % 10 == 0:  # monitoring
            print(f"step: {i}, loss: {loss.item()}")

step: 0, loss: 2.97223162651062
step: 10, loss: 2.2374839782714844
step: 20, loss: 1.3089395761489868
step: 30, loss: 2.462332010269165
step: 40, loss: 1.5569565296173096
step: 50, loss: 1.0468534231185913
step: 60, loss: 1.0251495838165283
step: 70, loss: 0.4931497275829315
step: 80, loss: 0.468456894159317
step: 90, loss: 0.41893771290779114
step: 100, loss: 0.5957819819450378
step: 110, loss: 0.3539648652076721
step: 120, loss: 0.32535508275032043
step: 130, loss: 0.5606241226196289
step: 140, loss: 0.48682472109794617
step: 150, loss: 0.2428196519613266
step: 160, loss: 0.6382691264152527
step: 170, loss: 0.16788998246192932
step: 180, loss: 0.1333020180463791
step: 190, loss: 0.1095779687166214
step: 200, loss: 0.16327139735221863
step: 210, loss: 0.18492990732192993
step: 220, loss: 0.09266629070043564
step: 230, loss: 0.21930654346942902
step: 240, loss: 0.27662137150764465
step: 250, loss: 0.21072933077812195
step: 260, loss: 0.2820131182670593
step: 270, loss: 0.16538234055042

step: 1030, loss: 0.0017528204480186105
step: 1040, loss: 0.008346643298864365
step: 1050, loss: 0.006653104908764362
step: 1060, loss: 0.008463237434625626
step: 1070, loss: 0.00028318105614744127
step: 1080, loss: 0.010494867339730263
step: 1090, loss: 0.008232113905251026
step: 1100, loss: 0.017727728933095932
step: 1110, loss: 0.02903505600988865
step: 1120, loss: 0.0026537671219557524
step: 1130, loss: 0.021265359595417976
step: 1140, loss: 0.01818530075252056
step: 1150, loss: 0.001997973769903183
step: 0, loss: 0.001829238492064178
step: 10, loss: 0.003328567836433649
step: 20, loss: 0.002644812688231468
step: 30, loss: 0.00316420616582036
step: 40, loss: 0.03986021503806114
step: 50, loss: 0.008281751535832882
step: 60, loss: 0.011312411166727543
step: 70, loss: 0.00783583428710699
step: 80, loss: 0.004169153049588203
step: 90, loss: 0.0034841829910874367
step: 100, loss: 0.012219632044434547
step: 110, loss: 0.008085445500910282
step: 120, loss: 0.00838881079107523
step: 130, 

step: 860, loss: 0.0006469834479503334
step: 870, loss: 0.03518490493297577
step: 880, loss: 0.0064475093968212605
step: 890, loss: 0.0005716904997825623
step: 900, loss: 0.007693753577768803
step: 910, loss: 0.0013304564636200666
step: 920, loss: 0.012865865603089333
step: 930, loss: 0.0019774052780121565
step: 940, loss: 0.0023338410537689924
step: 950, loss: 0.001336258021183312
step: 960, loss: 0.025589562952518463
step: 970, loss: 0.0009779100073501468
step: 980, loss: 0.007097853813320398
step: 990, loss: 0.0005402073729783297
step: 1000, loss: 0.0010787791106849909
step: 1010, loss: 0.0008144948515109718
step: 1020, loss: 0.001153280376456678
step: 1030, loss: 0.0006649584393016994
step: 1040, loss: 0.004265468567609787
step: 1050, loss: 0.002525780815631151
step: 1060, loss: 0.0034459487069398165
step: 1070, loss: 7.096657645888627e-05
step: 1080, loss: 0.007209317293018103
step: 1090, loss: 0.006534217856824398
step: 1100, loss: 0.007921242155134678
step: 1110, loss: 0.0282324

In [18]:
words[3]

'<cls> DHAKA 1996-08-31 <sep>'

In [19]:
y[0:199]

tensor([[0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1,
         1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 1, 1, 1, 5, 5, 5, 1, 0, 5, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 6, 1, 1, 1, 0, 1,
         1, 1, 0, 0],
        [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [20]:
def eval(model, iterator, gcn, f):
    ensemble_model.eval()

    Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            words, x, is_heads, tags, y, seqlens = batch
            max_len = x.shape[1]
            batch_size = x.shape[0]
            idx = i * batch_size

            if idx+batch_size > gcn.shape[0]:
                break
            gcn_tensor = gcn[idx:idx+batch_size]
            padded = np.zeros((batch_size, max_len, 256))

            for ix in range(batch_size):
                is_heads[ix][0] = is_heads[ix][-1] = 0
                num_word = sum(is_heads[ix])
                indexs = [head == 1 for head in is_heads[ix]] + [False] * (max_len - len(is_heads[ix]))
                padded[ix][indexs]=gcn_tensor[ix][:num_word]

#             optimizer.zero_grad()

            logits = ensemble_model(x, padded)
            y_hat = logits.argmax(-1)
#             logits = logits.view(-1, logits.shape[-1])
#             y = y.to(device).view(-1)

#             _, _,y_hat = ensemble_model(x, padded)  # y_hat: (N, T)

            Words.extend(words)
            Is_heads.extend(is_heads)
            Tags.extend(tags)
            Y.extend(y.numpy().tolist())
            Y_hat.extend(y_hat.cpu().numpy().tolist()) 

                 
    # gets results and save
    with open("temp", 'w') as fout:
        for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):
            is_heads[0] = is_heads[-1] = 1
            y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1]
            preds = [idx2tag[hat] for hat in y_hat]
#             print(is_heads)
#             print(preds)
#             print(words.split())
#             print(tags.split())
            assert len(preds) == len(words.split()) == len(tags.split())
            for w, t, p in zip(words.split()[1:-1], tags.split()[1:-1], preds[1:-1]):
                fout.write(f"{w} {t} {p}\n")
            fout.write("\n")

    with open("temp") as fout:
        evaluate_conll_file(fout)

    # calc metric
    y_true = np.array([tag2idx[line.split()[1]] for line in open(
        "temp", 'r').read().splitlines() if len(line) > 0])
    y_pred = np.array([tag2idx[line.split()[2]] for line in open(
        "temp", 'r').read().splitlines() if len(line) > 0])

    num_proposed = len(y_pred[y_pred > 1])
    num_correct = (np.logical_and(y_true == y_pred,
                                  y_true > 1)).astype(np.int).sum()
    num_gold = len(y_true[y_true > 1])

    print(f"num_proposed:{num_proposed}")
    print(f"num_correct:{num_correct}")
    print(f"num_gold:{num_gold}")
    try:
        precision = num_correct / num_proposed
    except ZeroDivisionError:
        precision = 1.0

    try:
        recall = num_correct / num_gold
    except ZeroDivisionError:
        recall = 1.0

    try:
        f1 = 2*precision*recall / (precision + recall)
    except ZeroDivisionError:
        if precision*recall == 0:
            f1 = 1.0
        else:
            f1 = 0

    final = f + ".P%.2f_R%.2f_F%.2f" % (precision, recall, f1)
    with open(final, 'w') as fout:
        result = open("temp", "r").read()
        fout.write(f"{result}\n")

        fout.write(f"precision={precision}\n")
        fout.write(f"recall={recall}\n")
        fout.write(f"f1={f1}\n")

    os.remove("temp")

    print("precision=%.4f" % precision)
    print("recall=%.4f" % recall)
    print("f1=%.4f" % f1)
    return precision, recall, f1

In [21]:
logdir = 'checkpoints/02'

In [30]:
print(f"=========eval=========")
if not os.path.exists(logdir):
    os.makedirs(logdir)
fname = os.path.join(logdir, "10")
precision, recall, f1 = eval(ensemble_model, eval_iter, gcn_val, fname)

torch.save(ensemble_model, f"{fname}.pt")
print(f"weights were saved to {fname}.pt")
print(f"=========test0=========")
if not os.path.exists(logdir):
    os.makedirs(logdir)
fname = os.path.join(logdir, "10" +'_test_')
precision, recall, f1 = eval(ensemble_model, test_iter, gcn_test, fname)

processed 55044 tokens with 5942 phrases; found: 5975 phrases; correct: 5859.
accuracy:  99.09%; (non-O)
accuracy:  99.78%; precision:  98.06%; recall:  98.60%; FB1:  98.33
              LOC: precision:  99.08%; recall:  99.62%; FB1:  99.35  1847
             MISC: precision:  95.35%; recall:  95.55%; FB1:  95.45  924
              ORG: precision:  97.48%; recall:  98.14%; FB1:  97.81  1350
              PER: precision:  98.81%; recall:  99.46%; FB1:  99.13  1854
num_proposed:8618
num_correct:8525
num_gold:8603
precision=0.9892
recall=0.9909
f1=0.9901
weights were saved to checkpoints/02/10.pt
processed 50350 tokens with 5648 phrases; found: 5787 phrases; correct: 5131.
accuracy:  92.67%; (non-O)
accuracy:  98.15%; precision:  88.66%; recall:  90.85%; FB1:  89.74
              LOC: precision:  91.61%; recall:  92.33%; FB1:  91.97  1681
             MISC: precision:  77.73%; recall:  81.05%; FB1:  79.36  732
              ORG: precision:  85.89%; recall:  88.32%; FB1:  87.09  1708
     