In [1]:
import sys
import torch
import argparse
import os
from os.path import dirname, join
from data_load import NerDataset, pad, VOCAB, tag2idx, idx2tag
from model import Net
from conlleval import evaluate_conll_file
import torch.nn as nn
import numpy as np

In [2]:
import random
manualSeed = 1

np.random.seed(manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
# if you are suing GPU
torch.cuda.manual_seed(manualSeed)
torch.cuda.manual_seed_all(manualSeed)


torch.backends.cudnn.enabled = False 
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
checkpoint_path='finetuning/4.pt'
logdir = 'checkpoints/02'

In [4]:
model = torch.load(checkpoint_path)
model.eval()

DataParallel(
  (module): Net(
    (xlnet): XLNetModel(
      (word_embedding): Embedding(32000, 768)
      (layer): ModuleList(
        (0): XLNetLayer(
          (rel_attn): XLNetRelativeAttention(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ff): XLNetFeedForward(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (layer_1): Linear(in_features=768, out_features=3072, bias=True)
            (layer_2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): XLNetLayer(
          (rel_attn): XLNetRelativeAttention(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ff): XLNetFeedForward(
           

In [5]:
newmodel = torch.nn.DataParallel(*(list(model.module.children())[:-1]))

In [6]:
newmodel

DataParallel(
  (module): XLNetModel(
    (word_embedding): Embedding(32000, 768)
    (layer): ModuleList(
      (0): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (1): XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine

In [7]:
from torch.utils import data
train_dataset = NerDataset('conll2003/train.txt')
eval_dataset = NerDataset('conll2003/valid.txt')
test_dataset = NerDataset('conll2003/test.txt')

train_iter = data.DataLoader(dataset=train_dataset,
                             batch_size=16,
                             num_workers=4,
                             collate_fn=pad)
eval_iter = data.DataLoader(dataset=eval_dataset,
                            batch_size=16,
                            num_workers=4,
                            collate_fn=pad)

test_iter = data.DataLoader(dataset=test_dataset,
                            batch_size=16,
                            num_workers=4,
                            collate_fn=pad)

In [8]:
# for sent in train_dataset.sents[:8]:
#     print(len(sent))

In [9]:
import pickle as pkl
gcn_train = pkl.load(open('../conll_gcn/pkl/train_predictions.pkl', 'rb'))
gcn_val = pkl.load(open('../conll_gcn/pkl/val_predictions.pkl', 'rb'))
gcn_test = pkl.load(open('../conll_gcn/pkl/test_predictions.pkl', 'rb'))

In [10]:
# gcn_train = gcn_train[1:]

In [11]:
# gcn_train.shape

In [12]:
# gcn_train[:8].shape

In [13]:
class EnsembleModel(nn.Module):
    def __init__(self, xln_model, gcn_pretrained, vocab_size, device = 'cuda'):
        super().__init__()
        self.xln_model = xln_model
        self.gcn_pretrained = gcn_pretrained

        self.fc = nn.Linear(768 + 256, vocab_size)

        self.device = device


    def forward(self, x, gcn_tensor):
        '''
        x: (N, T). int64
        y: (N, T). int64

        Returns
        '''
        x = x.to(self.device)

        self.xln_model.eval()
        with torch.no_grad():
            encoded_layers = self.xln_model(x)
            enc = encoded_layers[0]
        gcn_tensor = torch.from_numpy(gcn_tensor).float()
        gcn_tensor = gcn_tensor.to(self.device)
        ensemble = torch.cat((enc, gcn_tensor), dim=2)
        logits = self.fc(ensemble)
        
        return logits


In [14]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ensemble_model = EnsembleModel(newmodel, gcn_train, vocab_size = len(VOCAB), device=device).to(device)

In [15]:
optimizer = torch.optim.Adam(ensemble_model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [16]:
ensemble_model

EnsembleModel(
  (xln_model): DataParallel(
    (module): XLNetModel(
      (word_embedding): Embedding(32000, 768)
      (layer): ModuleList(
        (0): XLNetLayer(
          (rel_attn): XLNetRelativeAttention(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ff): XLNetFeedForward(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (layer_1): Linear(in_features=768, out_features=3072, bias=True)
            (layer_2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): XLNetLayer(
          (rel_attn): XLNetRelativeAttention(
            (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ff): XLNetFeedForwar

In [17]:
# words[3]

In [18]:
# y[0:199]

In [19]:
def eval(model, iterator, gcn, f):
    ensemble_model.eval()

    Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            words, x, is_heads, tags, y, seqlens = batch
            max_len = x.shape[1]
            batch_size = x.shape[0]
            idx = i * batch_size

            if idx+batch_size > gcn.shape[0]:
                break
            gcn_tensor = gcn[idx:idx+batch_size]
            padded = np.zeros((batch_size, max_len, 256))

            for ix in range(batch_size):
                is_heads[ix][0] = is_heads[ix][-1] = 0
                num_word = sum(is_heads[ix])
                indexs = [head == 1 for head in is_heads[ix]] + [False] * (max_len - len(is_heads[ix]))
                padded[ix][indexs]=gcn_tensor[ix][:num_word]

#             optimizer.zero_grad()

            logits = ensemble_model(x, padded)
            y_hat = logits.argmax(-1)
#             logits = logits.view(-1, logits.shape[-1])
#             y = y.to(device).view(-1)

#             _, _,y_hat = ensemble_model(x, padded)  # y_hat: (N, T)

            Words.extend(words)
            Is_heads.extend(is_heads)
            Tags.extend(tags)
            Y.extend(y.numpy().tolist())
            Y_hat.extend(y_hat.cpu().numpy().tolist()) 

                 
    # gets results and save
    with open("temp", 'w') as fout:
        for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):
            is_heads[0] = is_heads[-1] = 1
            y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1]
            preds = [idx2tag[hat] for hat in y_hat]
#             print(is_heads)
#             print(preds)
#             print(words.split())
#             print(tags.split())
            assert len(preds) == len(words.split()) == len(tags.split())
            for w, t, p in zip(words.split()[1:-1], tags.split()[1:-1], preds[1:-1]):
                fout.write(f"{w} {t} {p}\n")
            fout.write("\n")

    with open("temp") as fout:
        evaluate_conll_file(fout)

    # calc metric
    y_true = np.array([tag2idx[line.split()[1]] for line in open(
        "temp", 'r').read().splitlines() if len(line) > 0])
    y_pred = np.array([tag2idx[line.split()[2]] for line in open(
        "temp", 'r').read().splitlines() if len(line) > 0])

    num_proposed = len(y_pred[y_pred > 1])
    num_correct = (np.logical_and(y_true == y_pred,
                                  y_true > 1)).astype(np.int).sum()
    num_gold = len(y_true[y_true > 1])

    print(f"num_proposed:{num_proposed}")
    print(f"num_correct:{num_correct}")
    print(f"num_gold:{num_gold}")
    try:
        precision = num_correct / num_proposed
    except ZeroDivisionError:
        precision = 1.0

    try:
        recall = num_correct / num_gold
    except ZeroDivisionError:
        recall = 1.0

    try:
        f1 = 2*precision*recall / (precision + recall)
    except ZeroDivisionError:
        if precision*recall == 0:
            f1 = 1.0
        else:
            f1 = 0

    final = f + ".P%.2f_R%.2f_F%.2f" % (precision, recall, f1)
    with open(final, 'w') as fout:
        result = open("temp", "r").read()
        fout.write(f"{result}\n")

        fout.write(f"precision={precision}\n")
        fout.write(f"recall={recall}\n")
        fout.write(f"f1={f1}\n")

    os.remove("temp")

    print("precision=%.4f" % precision)
    print("recall=%.4f" % recall)
    print("f1=%.4f" % f1)
    return precision, recall, f1

In [20]:
for ep in range(10):
    ensemble_model.train()
    for i, batch in enumerate(train_iter):
        words, x, is_heads, _, y, seqlens = batch
        
        max_len = x.shape[1]
        batch_size = x.shape[0]
        idx = i * batch_size
        
        if idx+batch_size > gcn_train.shape[0]:
            break
        gcn_tensor = gcn_train[idx:idx+batch_size]
        padded = np.zeros((batch_size, max_len, 256))

        for ix in range(batch_size):
            is_heads[ix][0] = is_heads[ix][-1] = 0
            num_word = sum(is_heads[ix])
            indexs = [head == 1 for head in is_heads[ix]] + [False] * (max_len - len(is_heads[ix]))
            padded[ix][indexs]=gcn_tensor[ix][:num_word]
        
        optimizer.zero_grad()

        logits = ensemble_model(x, padded)

        logits = logits.view(-1, logits.shape[-1])
        y = y.to(device).view(-1)
        
        loss = criterion(logits, y)
        loss.backward()

        optimizer.step()

#         if i % 10 == 0:  # monitoring
#             print(f"step: {i}, loss: {loss.item()}")
        
    print(f"=========eval at epoch={ep}=========")
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    fname = os.path.join(logdir, str(ep))
    precision, recall, f1 = eval(ensemble_model, eval_iter, gcn_val, fname)

    torch.save(model, f"{fname}.pt")
    print(f"weights were saved to {fname}.pt")
    print(f"=========test at epoch={ep}=========")
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    fname = os.path.join(logdir, str(ep)+'_test_')
    precision, recall, f1 = eval(ensemble_model, test_iter, gcn_test, fname)


processed 55044 tokens with 5942 phrases; found: 5960 phrases; correct: 5860.
accuracy:  99.13%; (non-O)
accuracy:  99.82%; precision:  98.32%; recall:  98.62%; FB1:  98.47
              LOC: precision:  98.81%; recall:  99.62%; FB1:  99.21  1852
             MISC: precision:  96.29%; recall:  95.77%; FB1:  96.03  917
              ORG: precision:  97.99%; recall:  98.28%; FB1:  98.14  1345
              PER: precision:  99.08%; recall:  99.29%; FB1:  99.19  1846
num_proposed:8594
num_correct:8528
num_gold:8603
precision=0.9923
recall=0.9913
f1=0.9918
weights were saved to checkpoints/02/0.pt
processed 50350 tokens with 5648 phrases; found: 5688 phrases; correct: 5107.
accuracy:  92.25%; (non-O)
accuracy:  98.32%; precision:  89.79%; recall:  90.42%; FB1:  90.10
              LOC: precision:  90.69%; recall:  92.81%; FB1:  91.73  1707
             MISC: precision:  79.20%; recall:  78.63%; FB1:  78.91  697
              ORG: precision:  88.30%; recall:  87.24%; FB1:  87.76  1641
      

processed 55044 tokens with 5942 phrases; found: 5954 phrases; correct: 5903.
accuracy:  99.62%; (non-O)
accuracy:  99.89%; precision:  99.14%; recall:  99.34%; FB1:  99.24
              LOC: precision:  99.78%; recall:  99.78%; FB1:  99.78  1837
             MISC: precision:  98.37%; recall:  98.37%; FB1:  98.37  922
              ORG: precision:  98.01%; recall:  98.96%; FB1:  98.48  1354
              PER: precision:  99.73%; recall:  99.67%; FB1:  99.70  1841
num_proposed:8606
num_correct:8570
num_gold:8603
precision=0.9958
recall=0.9962
f1=0.9960
weights were saved to checkpoints/02/7.pt
processed 50350 tokens with 5648 phrases; found: 5711 phrases; correct: 5135.
accuracy:  92.79%; (non-O)
accuracy:  98.36%; precision:  89.91%; recall:  90.92%; FB1:  90.41
              LOC: precision:  92.05%; recall:  92.33%; FB1:  92.19  1673
             MISC: precision:  78.95%; recall:  79.63%; FB1:  79.29  708
              ORG: precision:  87.91%; recall:  88.86%; FB1:  88.38  1679
      

In [22]:
# print(f"=========eval=========")
# if not os.path.exists(logdir):
#     os.makedirs(logdir)
# fname = os.path.join(logdir)
# precision, recall, f1 = eval(ensemble_model, eval_iter, gcn_val, fname)

# torch.save(ensemble_model, f"{fname}.pt")
# print(f"weights were saved to {fname}.pt")
# print(f"=========test=========")
# if not os.path.exists(logdir):
#     os.makedirs(logdir)
# fname = os.path.join(logdir,'_test_')
# precision, recall, f1 = eval(ensemble_model, test_iter, gcn_test, fname)