# A Novel Cascade Binary Tagging Framework for Relational Triple Extraction

In [2]:
import os
from google.colab import drive
drive.mount('/content/gdrive')

os.chdir('/content/gdrive/MyDrive/CasRel')

Mounted at /content/gdrive


In [3]:
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import config
import time
import argparse
import torch
import numpy as np
import random
from random import choice
import pickle
import json
from torch.utils.data import DataLoader, Dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

! pip install transformers
import transformers
from transformers import BertModel,BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 4.7 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 72.8 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 86.1 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.10.1 tokenizers-0.13.2 transformers-4.24.0


Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

### 4.1 Hyperparameter Setup

In [4]:
seed = 1234
BERT_MAX_LEN = 512
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='Casrel', help='name of the model')
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--multi_gpu', type=bool, default=False)
parser.add_argument('--dataset', type=str, default='NYT')
parser.add_argument('--batch_size', type=int, default=6)  
parser.add_argument('--max_epoch', type=int, default=300)
parser.add_argument('--test_epoch', type=int, default=1)
parser.add_argument('--train_prefix', type=str, default='train_triples')
parser.add_argument('--dev_prefix', type=str, default='dev_triples')
parser.add_argument('--test_prefix', type=str, default='test_triples')
parser.add_argument('--max_len', type=int, default=100)
parser.add_argument('--rel_num', type=int, default=25)
parser.add_argument('--period', type=int, default=1000)
parser.add_argument('--debug', type=bool, default=False)
args = parser.parse_args(args=[])

con = config.Config(args)

In [5]:
con.__dict__

{'args': Namespace(batch_size=6, dataset='NYT', debug=False, dev_prefix='dev_triples', lr=1e-05, max_epoch=300, max_len=100, model_name='Casrel', multi_gpu=False, period=1000, rel_num=25, test_epoch=1, test_prefix='test_triples', train_prefix='train_triples'),
 'multi_gpu': False,
 'learning_rate': 1e-05,
 'batch_size': 6,
 'max_epoch': 300,
 'max_len': 100,
 'rel_num': 25,
 'dataset': 'NYT',
 'root': '/content/gdrive/MyDrive/CasRel',
 'data_path': '/content/gdrive/MyDrive/CasRel/data/NYT',
 'checkpoint_dir': '/content/gdrive/MyDrive/CasRel/checkpoint/NYT',
 'log_dir': '/content/gdrive/MyDrive/CasRel/log/NYT',
 'result_dir': '/content/gdrive/MyDrive/CasRel/result/NYT',
 'train_prefix': 'train_triples',
 'dev_prefix': 'dev_triples',
 'test_prefix': 'test_triples',
 'model_save_name': 'Casrel_DATASET_NYT_LR_1e-05_BS_6',
 'log_save_name': 'LOG_Casrel_DATASET_NYT_LR_1e-05_BS_6',
 'result_save_name': 'RESULT_Casrel_DATASET_NYT_LR_1e-05_BS_6.json',
 'period': 1000,
 'test_epoch': 1,
 'debug'

# Sequence Framework Functions


In [6]:
def find_head_idx(source, target):
    target_len = len(target)
    for i in range(len(source)):
        if source[i: i + target_len] == target:
            return i
    return -1

In [7]:
id2rel = {}
rel2id = json.load(open(os.path.join(con.data_path, 'rel2id.json')))
for t in rel2id.keys():
  id2rel[str(rel2id[t])] = t
id2rel

js_object = json.dumps(id2rel, indent=4)
with open('id2rel.json','w') as f:
  f.write(js_object)

In [8]:
class CMEDDataset(Dataset):
    def __init__(self, config, prefix, is_test, tokenizer):
        self.config = config
        self.prefix = prefix
        self.is_test = is_test
        self.tokenizer = tokenizer
        if self.config.debug:
            self.json_data = pickle.load(open(os.path.join(self.config.data_path, prefix + '.pkl'), 'rb'))[:500]
        else:
            self.json_data = pickle.load(open(os.path.join(self.config.data_path, prefix + '.pkl'), 'rb'))
        self.rel2id = json.load(open(os.path.join(con.data_path, 'rel2id.json')))

    def __len__(self):
        return len(self.json_data)

    def __getitem__(self, idx):
        ins_json_data = self.json_data[idx]
        text = ins_json_data['text']
        text = ' '.join(text.split()[:self.config.max_len])
        tokens = self.tokenizer.tokenize(text)
        if len(tokens) > BERT_MAX_LEN:
            tokens = tokens[: BERT_MAX_LEN]
        text_len = len(tokens)

        if not self.is_test:
            s2ro_map = {}
            for triple in ins_json_data['triple_list']:
                triple = (self.tokenizer.tokenize(triple[0]), triple[1], self.tokenizer.tokenize(triple[2]))
                sub_head_idx = find_head_idx(tokens, triple[0])
                obj_head_idx = find_head_idx(tokens, triple[2])
                if sub_head_idx != -1 and obj_head_idx != -1:
                    sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1)
                    if sub not in s2ro_map:
                        s2ro_map[sub] = []
                    s2ro_map[sub].append((obj_head_idx, obj_head_idx + len(triple[2]) - 1, self.rel2id[triple[1]]))

            if s2ro_map:
                # token_ids, segment_ids = self.tokenizer.encode(first=text)
                tokenizer_out = self.tokenizer(text)
                token_ids = tokenizer_out['input_ids']
                segment_ids = tokenizer_out['attention_mask']
                masks = segment_ids
                if len(token_ids) > text_len:
                    token_ids = token_ids[:text_len]
                    masks = masks[:text_len]
                token_ids = np.array(token_ids)
                masks = np.array(masks) + 1
                sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)
                for s in s2ro_map:
                    sub_heads[s[0]] = 1
                    sub_tails[s[1]] = 1
                sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys()))
                sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len)
                sub_head[sub_head_idx] = 1
                sub_tail[sub_tail_idx] = 1
                obj_heads, obj_tails = np.zeros((text_len, self.config.rel_num)), np.zeros((text_len, self.config.rel_num))
                for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []):
                    obj_heads[ro[0]][ro[2]] = 1
                    obj_tails[ro[1]][ro[2]] = 1
                return token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, ins_json_data['triple_list'], tokens
            else:
                return None
        else:
            # token_ids, segment_ids = self.tokenizer.encode(first=text)
            tokenizer_out = self.tokenizer(text)
            token_ids = tokenizer_out['input_ids']
            segment_ids = tokenizer_out['attention_mask']
            masks = segment_ids
            if len(token_ids) > text_len:
                token_ids = token_ids[:text_len]
                masks = masks[:text_len]
            token_ids = np.array(token_ids)
            masks = np.array(masks) + 1
            sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)
            sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len)
            obj_heads, obj_tails = np.zeros((text_len, self.config.rel_num)), np.zeros((text_len, self.config.rel_num))
            return token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, ins_json_data['triple_list'], tokens

In [9]:
def cmed_collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    batch.sort(key=lambda x: x[2], reverse=True)
    token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, triples, tokens = zip(*batch)
    cur_batch = len(batch)
    max_text_len = max(text_len)
    batch_token_ids = torch.LongTensor(cur_batch, max_text_len).zero_()
    batch_masks = torch.LongTensor(cur_batch, max_text_len).zero_()
    batch_sub_heads = torch.Tensor(cur_batch, max_text_len).zero_()
    batch_sub_tails = torch.Tensor(cur_batch, max_text_len).zero_()
    batch_sub_head = torch.Tensor(cur_batch, max_text_len).zero_()
    batch_sub_tail = torch.Tensor(cur_batch, max_text_len).zero_()
    batch_obj_heads = torch.Tensor(cur_batch, max_text_len, con.rel_num).zero_()
    batch_obj_tails = torch.Tensor(cur_batch, max_text_len, con.rel_num).zero_()

    for i in range(cur_batch):
        batch_token_ids[i, :text_len[i]].copy_(torch.from_numpy(token_ids[i]))
        batch_masks[i, :text_len[i]].copy_(torch.from_numpy(masks[i]))
        batch_sub_heads[i, :text_len[i]].copy_(torch.from_numpy(sub_heads[i]))
        batch_sub_tails[i, :text_len[i]].copy_(torch.from_numpy(sub_tails[i]))
        batch_sub_head[i, :text_len[i]].copy_(torch.from_numpy(sub_head[i]))
        batch_sub_tail[i, :text_len[i]].copy_(torch.from_numpy(sub_tail[i]))
        batch_obj_heads[i, :text_len[i], :].copy_(torch.from_numpy(obj_heads[i]))
        batch_obj_tails[i, :text_len[i], :].copy_(torch.from_numpy(obj_tails[i]))

    return {'token_ids': batch_token_ids,
            'mask': batch_masks,
            'sub_heads': batch_sub_heads,
            'sub_tails': batch_sub_tails,
            'sub_head': batch_sub_head,
            'sub_tail': batch_sub_tail,
            'obj_heads': batch_obj_heads,
            'obj_tails': batch_obj_tails,
            'triples': triples,
            'tokens': tokens}

In [10]:
def get_loader(config, prefix, is_test=False, num_workers=6, collate_fn=cmed_collate_fn):
    dataset = CMEDDataset(config, prefix, is_test, tokenizer)
    if not is_test:
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn)
    else:
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn)
    return data_loader

In [11]:
class DataPreFetcher(object):
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        try:
            self.next_data = next(self.loader)
        except StopIteration:
            self.next_data = None
            return
        with torch.cuda.stream(self.stream):
            for k, v in self.next_data.items():
                if isinstance(v, torch.Tensor):
                    self.next_data[k] = self.next_data[k].cuda(non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        data = self.next_data
        self.preload()
        return data

# Model setup

In [12]:
class Casrel(nn.Module):
    def __init__(self, config):
        super(Casrel, self).__init__()
        self.config = config
        self.bert_dim = 768
        self.bert_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.sub_heads_linear = nn.Linear(self.bert_dim, 1)
        self.sub_tails_linear = nn.Linear(self.bert_dim, 1)
        self.obj_heads_linear = nn.Linear(self.bert_dim, self.config.rel_num)
        self.obj_tails_linear = nn.Linear(self.bert_dim, self.config.rel_num)

    def get_objs_for_specific_sub(self, sub_head_mapping, sub_tail_mapping, encoded_text):
        # [batch_size, 1, bert_dim]
        sub_head = torch.matmul(sub_head_mapping, encoded_text)
        # [batch_size, 1, bert_dim]
        sub_tail = torch.matmul(sub_tail_mapping, encoded_text)
        # [batch_size, 1, bert_dim]
        sub = (sub_head + sub_tail) / 2
        # [batch_size, seq_len, bert_dim]
        encoded_text = encoded_text + sub
        # [batch_size, seq_len, rel_num]
        pred_obj_heads = self.obj_heads_linear(encoded_text)
        pred_obj_heads = torch.sigmoid(pred_obj_heads)
        # [batch_size, seq_len, rel_num]
        pred_obj_tails = self.obj_tails_linear(encoded_text)
        pred_obj_tails = torch.sigmoid(pred_obj_tails)
        return pred_obj_heads, pred_obj_tails

    def get_encoded_text(self, token_ids, mask):
        # [batch_size, seq_len, bert_dim(768)]
        encoded_text = self.bert_encoder(token_ids, attention_mask=mask)[0]
        return encoded_text

    def get_subs(self, encoded_text):
        # [batch_size, seq_len, 1]
        pred_sub_heads = self.sub_heads_linear(encoded_text)
        pred_sub_heads = torch.sigmoid(pred_sub_heads)
        # [batch_size, seq_len, 1]
        pred_sub_tails = self.sub_tails_linear(encoded_text)
        pred_sub_tails = torch.sigmoid(pred_sub_tails)
        return pred_sub_heads, pred_sub_tails

    def forward(self, data):
        # [batch_size, seq_len]
        token_ids = data['token_ids']
        # [batch_size, seq_len]
        mask = data['mask']
        # [batch_size, seq_len, bert_dim(768)]
        encoded_text = self.get_encoded_text(token_ids, mask)
        # [batch_size, seq_len, 1]
        pred_sub_heads, pred_sub_tails = self.get_subs(encoded_text)
        # [batch_size, 1, seq_len]
        sub_head_mapping = data['sub_head'].unsqueeze(1)
        # [batch_size, 1, seq_len]
        sub_tail_mapping = data['sub_tail'].unsqueeze(1)
        # [batch_size, seq_len, rel_num]
        pred_obj_heads, pred_obj_tails = self.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping, encoded_text)
        return pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ori_model = Casrel(con)
ori_model.to(device)

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Casrel(
  (bert_encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tru

In [14]:
# define the optimizer
optimizer = optim.Adam(filter(lambda p: p.requires_grad, ori_model.parameters()), lr=con.learning_rate)

# whether use multi GPU
if con.multi_gpu:
    model = nn.DataParallel(ori_model)
else:
    model = ori_model

# define the loss function
def loss(gold, pred, mask):
    pred = pred.squeeze(-1)
    los = F.binary_cross_entropy(pred, gold, reduction='none')
    if los.shape != mask.shape:
        mask = mask.unsqueeze(-1)
    los = torch.sum(los * mask) / torch.sum(mask)
    return los


def logging(s, print_=True, log_=True):
    if print_:
        print(s)
    if log_:
        with open(os.path.join(con.log_dir, con.log_save_name), 'a+') as f_log:
            f_log.write(s + '\n')

In [18]:
id2rel

{'1': '/business/company/founders',
 '2': '/people/person/place_of_birth',
 '3': '/people/deceased_person/place_of_death',
 '4': '/business/company_shareholder/major_shareholder_of',
 '5': '/people/ethnicity/people',
 '6': '/location/neighborhood/neighborhood_of',
 '7': '/sports/sports_team/location',
 '9': '/business/company/industry',
 '10': '/business/company/place_founded',
 '11': '/location/administrative_division/country',
 '0': 'None',
 '12': '/sports/sports_team_location/teams',
 '13': '/people/person/nationality',
 '14': '/people/person/religion',
 '15': '/business/company/advisors',
 '16': '/people/person/ethnicity',
 '17': '/people/ethnicity/geographic_distribution',
 '8': '/business/person/company',
 '19': '/business/company/major_shareholders',
 '18': '/people/person/place_lived',
 '20': '/people/person/profession',
 '21': '/location/country/capital',
 '22': '/location/location/contains',
 '23': '/location/country/administrative_divisions',
 '24': '/people/person/children'

In [23]:
def test(test_data_loader, model, output=False, h_bar=0.5, t_bar=0.5):

    if output:
        # check the result dir
        if not os.path.exists(con.result_dir):
            os.mkdir(con.result_dir)

        path = os.path.join(con.result_dir, con.result_save_name)

        fw = open(path, 'w')

    orders = ['subject', 'relation', 'object']

    def to_tup(triple_list):
        ret = []
        for triple in triple_list:
            ret.append(tuple(triple))
        return ret

    test_data_prefetcher = DataPreFetcher(test_data_loader)
    data = test_data_prefetcher.next()
    id2rel = json.load(open(os.path.join(con.data_path, 'rel2id.json')))
    correct_num, predict_num, gold_num = 0, 0, 0

    while data is not None:
        with torch.no_grad():
            token_ids = data['token_ids']
            tokens = data['tokens'][0]
            mask = data['mask']
            encoded_text = model.get_encoded_text(token_ids, mask)
            pred_sub_heads, pred_sub_tails = model.get_subs(encoded_text)
            print('sub_heads:',pred_sub_heads)
            print('sub_tails:',pred_sub_tails)
            sub_heads, sub_tails = np.where(pred_sub_heads.cpu()[0] > h_bar)[0], np.where(pred_sub_tails.cpu()[0] > t_bar)[0]
            subjects = []
            for sub_head in sub_heads:
                print('for sub head in sub heads')
                sub_tail = sub_tails[sub_tails >= sub_head]
                if len(sub_tail) > 0:
                    sub_tail = sub_tail[0]
                    subject = tokens[sub_head: sub_tail]
                    subjects.append((subject, sub_head, sub_tail))
                    print('subject appended.')
            if subjects:
                print('subject if statement')
                triple_list = []
                # [subject_num, seq_len, bert_dim]
                repeated_encoded_text = encoded_text.repeat(len(subjects), 1, 1)
                # [subject_num, 1, seq_len]
                sub_head_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
                sub_tail_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
                for subject_idx, subject in enumerate(subjects):
                    sub_head_mapping[subject_idx][0][subject[1]] = 1
                    sub_tail_mapping[subject_idx][0][subject[2]] = 1
                sub_tail_mapping = sub_tail_mapping.to(repeated_encoded_text)
                sub_head_mapping = sub_head_mapping.to(repeated_encoded_text)
                pred_obj_heads, pred_obj_tails = model.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping, repeated_encoded_text)
                for subject_idx, subject in enumerate(subjects):
                    sub = subject[0]
                    sub = ''.join([i.lstrip("##") for i in sub])
                    sub = ' '.join(sub.split('[unused1]'))
                    obj_heads, obj_tails = np.where(pred_obj_heads.cpu()[subject_idx] > h_bar), np.where(pred_obj_tails.cpu()[subject_idx] > t_bar)
                    for obj_head, rel_head in zip(*obj_heads):
                        for obj_tail, rel_tail in zip(*obj_tails):
                            if obj_head <= obj_tail and rel_head == rel_tail:
                                rel = id2rel[str(int(rel_head))]
                                obj = tokens[obj_head: obj_tail]
                                obj = ''.join([i.lstrip("##") for i in obj])
                                obj = ' '.join(obj.split('[unused1]'))
                                print('pred subject:', sub)
                                print('pred relation:',rel)
                                print('pred object:', obj)
                                triple_list.append((sub, rel, obj))
                                break
                triple_set = set()
                for s, r, o in triple_list:
                    triple_set.add((s, r, o))
                pred_list = list(triple_set)
            else:
                print('subject else statement')
                pred_list = []
            pred_triples = set(pred_list)
            gold_triples = set(to_tup(data['triples'][0]))

            correct_num += len(pred_triples & gold_triples)
            predict_num += len(pred_triples)
            gold_num += len(gold_triples)

            if output:
                result = json.dumps({
                    # 'text': ' '.join(tokens),
                    'triple_list_gold': [
                        dict(zip(orders, triple)) for triple in gold_triples
                    ],
                    'triple_list_pred': [
                        dict(zip(orders, triple)) for triple in pred_triples
                    ],
                    'new': [
                        dict(zip(orders, triple)) for triple in pred_triples - gold_triples
                    ],
                    'lack': [
                        dict(zip(orders, triple)) for triple in gold_triples - pred_triples
                    ]
                }, ensure_ascii=False)
                fw.write(result + '\n')

            data = test_data_prefetcher.next()

    print("correct_num: {:3d}, predict_num: {:3d}, gold_num: {:3d}".format(correct_num, predict_num, gold_num))

    precision = correct_num / (predict_num + 1e-10)
    recall = correct_num / (gold_num + 1e-10)
    f1_score = 2 * precision * recall / (precision + recall + 1e-10)
    return precision, recall, f1_score

# Training

In [16]:
# check the checkpoint dir
if not os.path.exists(con.checkpoint_dir):
    os.mkdir(con.checkpoint_dir)

# check the log dir
if not os.path.exists(con.log_dir):
    os.mkdir(con.log_dir)

# get the data loader
#train_data_loader = data_loader.get_loader(con, prefix=con.train_prefix)
train_data_loader = get_loader(con, prefix=con.train_prefix)
dev_data_loader = get_loader(con, prefix=con.dev_prefix, is_test=True)


In [82]:
model.train()
global_step = 0
loss_sum = 0

best_f1_score = 0
best_precision = 0
best_recall = 0

best_epoch = 0
init_time = time.time()
start_time = time.time()

for epoch in range(con.max_epoch):
    train_data_prefetcher = DataPreFetcher(train_data_loader)
    data = train_data_prefetcher.next()

    while data is not None:
        pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails = model(data)

        sub_heads_loss = loss(data['sub_heads'], pred_sub_heads, data['mask'])
        sub_tails_loss = loss(data['sub_tails'], pred_sub_tails, data['mask'])
        obj_heads_loss = loss(data['obj_heads'], pred_obj_heads, data['mask'])
        obj_tails_loss = loss(data['obj_tails'], pred_obj_tails, data['mask'])
        total_loss = (sub_heads_loss + sub_tails_loss) + (obj_heads_loss + obj_tails_loss)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        global_step += 1
        loss_sum += total_loss.item()

        if global_step % con.period == 0:
            cur_loss = loss_sum / con.period
            elapsed = time.time() - start_time
            #logging("epoch: {:3d}, step: {:4d}, speed: {:5.2f}ms/b, train loss: {:5.3f}".
                          format(epoch, global_step, elapsed * 1000 / con.period, cur_loss))
            loss_sum = 0
            start_time = time.time()

        data = train_data_prefetcher.next()
        


    if (epoch + 1) % con.test_epoch == 0:
        eval_start_time = time.time()
        model.eval()
        # call the test function
        precision, recall, f1_score = test(dev_data_loader, model)
        model.train()
        #logging('epoch {:3d}, eval time: {:5.2f}s, f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}'.
                      format(epoch, time.time() - eval_start_time, f1_score, precision, recall))

        if f1_score > best_f1_score:
            best_f1_score = f1_score
            best_epoch = epoch
            best_precision = precision
            best_recall = recall
            #logging("saving the model, epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".
                          format(best_epoch, best_f1_score, precision, recall))
            # save the best model
            path = os.path.join(con.checkpoint_dir, con.model_save_name)
            if not con.debug:
                torch.save(ori_model.state_dict(), path)

            # manually release the unused cache
    torch.cuda.empty_cache()

print("finish training")
print("best epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2}, total time: {:5.2f}s".
              format(best_epoch, best_f1_score, best_precision, best_recall, time.time() - init_time))

finish training
best epoch: 217, best f1: 77.73, precision: 90.20, recall: 9.1e+01, total time: 208936.00s


# Testing

In [83]:
path = os.path.join(con.checkpoint_dir, con.model_save_name)
model = model.load_state_dict(torch.load(path))
model.to(device)
model.eval()
test_data_loader = get_loader(con, prefix=con.test_prefix, is_test=True)
precision, recall, f1_score = test(test_data_loader, model, True)
print("f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".format(f1_score, precision, recall))

f1: 89.62, precision: 89.67, recall: 89.49
