In [2]:
# !pip install dgl-cu102==0.4.3
# !pip install transformers==3.1.0


In [1]:
import time
import matplotlib
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from config import *
from data import DGLREDataset, DGLREDataloader, BERTDGLREDataset
from models.GAIN import GAIN_GloVe, GAIN_BERT
import test
from utils import Accuracy, get_cuda, logging, print_params


In [2]:
class GAIN_Original:
    """Hard-coded option object for the ``GAIN_BERT_base`` setup.

    A plain attribute container: in this notebook an instance is passed as
    ``opt`` to ``GAIN_BERT`` and to the dataset constructors, and its
    ``batch_size`` / ``test_batch_size`` / ``negativa_alpha`` fields are read
    when building the dataloaders.

    Args:
        **overrides: optional ``name=value`` pairs applied after the defaults
            below, e.g. ``GAIN_Original(relation_nums=10)``. Names that are
            not among the defaults are accepted and simply become extra
            attributes. Calling ``GAIN_Original()`` with no arguments yields
            exactly the default settings.
    """

    def __init__(self, **overrides):
        # Imported locally so the class does not rely on another cell (or a
        # star-import) having put `np` into the kernel namespace first.
        import numpy as np

        self.activation = 'relu'
        self.batch_size = 5
        self.bert_fix = False
        self.bert_hid_size = 768
        self.bert_path = '../PLM/bert-base-uncased'
        self.checkpoint_dir = 'checkpoint'
        self.clip = -1
        self.coslr = True
        # A single row of six zeros (float64).
        # NOTE(review): looks like a placeholder embedding table — confirm.
        self.data_word_vec = np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
        self.dev_set = '../data/dev.json'
        self.dev_set_save = '../data/prepro_data/dev_BERT.pkl'
        self.dropout = 0.6
        self.entity_id_pad = 0
        self.entity_id_size = 20
        self.entity_type_num = 7
        self.entity_type_pad = 0
        self.entity_type_size = 20
        self.epoch = 300
        self.fig_result_dir = 'fig_result'
        self.finetune_word = False
        self.gcn_dim = 808
        self.gcn_layers = 2
        self.input_theta = -1
        self.k_fold = 'none'
        self.log_step = 20
        self.lr = 0.001
        self.lstm_dropout = 0.1
        self.lstm_hidden_size = 32
        self.max_entity_num = 80
        self.mention_drop = False
        self.model_name = 'GAIN_BERT_base'
        # Spelling kept as-is: the dataloader is called with a keyword
        # argument of this exact name.
        self.negativa_alpha = 4.0
        self.nlayers = 1
        self.pre_train_word = False
        self.pretrain_model = ''
        self.relation_nums = 97
        self.save_model_freq = 3
        self.test_batch_size = 16
        self.test_epoch = 5
        self.test_set = '../data/test.json'
        self.test_set_save = '../data/prepro_data/test_BERT.pkl'
        self.train_set = '../data/train_annotated.json'
        self.train_set_save = '../data/prepro_data/train_BERT.pkl'
        self.use_entity_id = True
        self.use_entity_type = True
        self.use_model = 'bert'
        self.vocabulary_size = 200000
        self.weight_decay = 0.0001
        self.word_emb_size = 10
        self.word_pad = 0

        # Caller-supplied values win over the defaults above.
        for name, value in overrides.items():
            setattr(self, name, value)

                                

In [3]:
import numpy as np
from models.GAIN import GAIN_GloVe, GAIN_BERT
# Build the model from the hard-coded options defined above.
opt = GAIN_Original()
model_original = GAIN_BERT(opt)

# Relative path: resolved against the notebook's working directory.
pretrain_model = "GAIN_BERT_base_best.pt"
# map_location remaps every stored tensor onto the CPU while loading.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code — only load checkpoints that come from a trusted source.
chkpt = torch.load(pretrain_model, map_location=torch.device('cpu'))

# The state dict is stored under the 'checkpoint' key of the loaded object.
model_original.load_state_dict(chkpt['checkpoint'])


<All keys matched successfully>

## Load the DocRED dataset

In [13]:
# Datasets: word2id / ner2id / rel2id are not defined in this notebook's
# cells (presumably provided by `from config import *`).
train_set = BERTDGLREDataset(
    opt.train_set,
    opt.train_set_save,
    word2id,
    ner2id,
    rel2id,
    dataset_type='train',
    opt=opt,
)
# The dev split is built with the `instance_in_train` attribute of the
# training split.
dev_set = BERTDGLREDataset(
    opt.dev_set,
    opt.dev_set_save,
    word2id,
    ner2id,
    rel2id,
    dataset_type='dev',
    instance_in_train=train_set.instance_in_train,
    opt=opt,
)

# Dataloaders: the training loader shuffles and receives `negativa_alpha`;
# the dev loader uses the test batch size.
train_loader = DGLREDataloader(
    train_set,
    batch_size=opt.batch_size,
    shuffle=True,
    negativa_alpha=opt.negativa_alpha,
)
dev_loader = DGLREDataloader(
    dev_set,
    batch_size=opt.test_batch_size,
    dataset_type='dev',
)

Reading data from ../data/train_annotated.json.
load preprocessed data from ../data/prepro_data/train_BERT.pkl.
Reading data from ../data/dev.json.
load preprocessed data from ../data/prepro_data/dev_BERT.pkl.


In [14]:
def train(model, train_loader):
    """Run one forward pass on the first batch of ``train_loader`` and print it.

    Despite its name, this function performs no loss computation, backward
    pass or optimizer step: it moves the model with ``get_cuda``, switches it
    to training mode, feeds it the first batch and prints the model output.
    If the loader yields no batch, nothing is printed.

    Args:
        model: module that supports ``.train()`` and is callable with the
            keyword arguments ``words``, ``src_lengths``, ``mask``,
            ``entity_type``, ``entity_id``, ``mention_id``, ``distance``,
            ``entity2mention_table``, ``graphs``, ``h_t_pairs``,
            ``relation_mask``, ``path_table``, ``entity_graphs`` and
            ``ht_pair_distance``.
        train_loader: iterable of batch dicts containing the keys read below.

    Returns:
        None. The model output is written to stdout via ``print``.
    """
    model = get_cuda(model)
    model.train()

    for batch in train_loader:
        predict = model(
            words=batch['context_idxs'],
            src_lengths=batch['context_word_length'],
            mask=batch['context_word_mask'],
            entity_type=batch['context_ner'],
            entity_id=batch['context_pos'],
            mention_id=batch['context_mention'],
            distance=None,
            entity2mention_table=batch['entity2mention_table'],
            graphs=batch['graphs'],
            h_t_pairs=batch['h_t_pairs'],
            relation_mask=batch['relation_mask'],
            path_table=batch['path_table'],
            entity_graphs=batch['entity_graphs'],
            ht_pair_distance=batch['ht_pair_distance'],
        )
        print(predict)
        # Only the first batch is inspected.
        break

In [15]:
# Smoke test: push the first DocRED batch through the checkpoint-initialised model.
train(model_original, train_loader)

tensor([[[  -4.4753,  -43.2044,  -42.7028,  ...,  -31.9513,  -31.4520,
           -37.2508],
         [ -12.3107,  -54.4923,  -43.5470,  ...,  -41.5890,  -38.9031,
           -36.9920],
         [  -9.3263,  -67.7657,  -43.3337,  ...,  -58.3543,  -62.6477,
           -62.0271],
         ...,
         [ -19.6678, -133.6537, -139.1118,  ..., -141.7457, -149.0825,
          -150.9542],
         [ -11.6704, -151.5786, -153.3741,  ..., -148.8346, -162.9838,
          -169.6308],
         [ -19.8514, -153.4541, -135.7840,  ..., -181.8170, -167.2855,
          -192.5784]],

        [[  -8.0435,  -26.3091,    9.1803,  ...,  -24.8662,  -44.5049,
           -46.8739],
         [ -12.1401,   13.1404,  -34.0412,  ...,  -68.4048,  -68.7761,
           -74.8332],
         [ -19.3975,   21.9022,  -40.7421,  ...,  -60.0891,  -66.3379,
           -70.8801],
         ...,
         [ -29.8872, -261.7500, -197.6829,  ..., -282.1135, -296.5481,
          -286.5006],
         [ -14.4501, -280.8423, -218.345

# Switch the options from DocRED to SemEval

In [47]:
import copy
import json

# Copy rather than alias: with `optSemEval = opt` both names refer to the same
# object, so every assignment below would also overwrite the DocRED options
# (paths, relation_nums) that `opt` is supposed to keep.
optSemEval = copy.copy(opt)

# NOTE(review): `train` is not one of the GAIN_Original fields and no cell shown
# in this notebook reads it — confirm whether it is needed.
optSemEval.train = "../SemEval2DocRED/train_annotated.json"
optSemEval.train_set = "../SemEval2DocRED/train_annotated.json"
optSemEval.train_set_save = "../SemEval2DocRED/train_BERT.pkl"

optSemEval.dev_set = "../SemEval2DocRED/dev.json"
optSemEval.dev_set_save = "../SemEval2DocRED/dev_BERT.pkl"

# SemEval-specific label and vocabulary mappings are attached to the options object.
with open("../SemEval2DocRED/DocRED_baseline_metadata/rel2id.json", encoding="utf-8") as rel_file:
    optSemEval.rel2id = json.load(rel_file)
with open("../SemEval2DocRED/DocRED_baseline_metadata/word2id.json", encoding="utf-8") as word_file:
    optSemEval.word2id = json.load(word_file)

# NOTE(review): presumably equals the number of entries in rel2id — TODO confirm.
optSemEval.relation_nums = 10

In [58]:
# Re-applies the value already assigned at the end of the previous cell.
optSemEval.relation_nums = 10

In [59]:
# Display the NER-tag-to-id mapping currently in the kernel namespace.
ner2id

{'BLANK': 0, 'ORG': 1, 'LOC': 2, 'TIME': 3, 'PER': 4, 'MISC': 5, 'NUM': 6}

In [61]:
# SemEval datasets. A single-entry NER map {"None": 0} is passed in place of
# `ner2id`. These assignments rebind train_set / dev_set / train_loader /
# dev_loader, replacing the objects built earlier under the same names.
train_set = BERTDGLREDataset(
    optSemEval.train_set,
    optSemEval.train_set_save,
    optSemEval.word2id,
    {"None": 0},
    optSemEval.rel2id,
    dataset_type='train',
    opt=optSemEval,
)
dev_set = BERTDGLREDataset(
    optSemEval.dev_set,
    optSemEval.dev_set_save,
    optSemEval.word2id,
    {"None": 0},
    optSemEval.rel2id,
    dataset_type='dev',
    instance_in_train=train_set.instance_in_train,
    opt=optSemEval,
)

# Dataloaders: shuffled training loader, dev loader with the test batch size.
train_loader = DGLREDataloader(
    train_set,
    batch_size=optSemEval.batch_size,
    shuffle=True,
    negativa_alpha=optSemEval.negativa_alpha,
)
dev_loader = DGLREDataloader(
    dev_set,
    batch_size=optSemEval.test_batch_size,
    dataset_type='dev',
)

Reading data from ../SemEval2DocRED/train_annotated.json.
load preprocessed data from ../SemEval2DocRED/train_BERT.pkl.
Reading data from ../SemEval2DocRED/dev.json.
../PLM/bert-base-uncased
loading..
finish reading ../SemEval2DocRED/dev.json and save preprocessed data to ../SemEval2DocRED/dev_BERT.pkl.


In [62]:
# New GAIN_BERT instance built from the SemEval options; no checkpoint is loaded into it here.
model_semEval = GAIN_BERT(optSemEval)


In [67]:
# Reuse the `bert` sub-module of the checkpoint-initialised model. This assigns
# the very same module object (no copy), so the freeze below also applies to
# model_original.bert.
model_semEval.bert = model_original.bert

# Freeze every parameter of the shared module. Iterating parameters() directly
# also covers parameters registered on the module itself, which a loop over
# children() would skip.
for param in model_semEval.bert.parameters():
    param.requires_grad = False

In [68]:
# Smoke test: push the first SemEval batch through the model that reuses the BERT sub-module.
train(model_semEval, train_loader)

tensor([[[ 0.0820,  0.0555,  0.1060,  ..., -0.1069, -0.2659,  0.2644],
         [-0.3910, -0.1363, -0.2390,  ...,  0.1367, -0.2156,  0.2055],
         [-0.8999, -0.0288, -0.0096,  ...,  0.1641, -0.0861,  0.4132],
         ...,
         [-1.1530, -0.0763,  0.2963,  ...,  0.4006, -0.6051,  1.2431],
         [-0.0294,  0.0123, -0.9427,  ..., -0.1362, -0.7850,  0.4324],
         [-0.4760,  0.3644, -0.7466,  ..., -0.0844,  0.3149, -0.0189]],

        [[ 0.1789,  0.4055,  0.0378,  ..., -0.0627, -0.3748,  0.2635],
         [-0.0723, -0.4669, -0.1513,  ...,  0.0471,  0.1920,  0.3797],
         [ 0.2763,  0.4738,  0.2147,  ...,  0.7542, -1.3142,  0.2215],
         ...,
         [-0.3469,  0.5495,  0.1965,  ...,  0.7479,  0.0550, -0.0234],
         [ 0.1589,  0.3491, -0.0970,  ...,  0.0991, -0.3902,  0.2634],
         [ 0.9001, -0.4232, -0.0416,  ...,  1.1210, -0.0956, -0.2757]],

        [[-0.2735,  0.0868, -0.0593,  ...,  0.2635, -0.0318,  0.2504],
         [-0.1859, -0.1550, -0.0871,  ...,  0