# A Novel Cascade Binary Tagging Framework for Relational Triple Extraction ACL 2020

# 1 前言

### 1.1 课程回顾

<img src='imgs/overall_for_code.png' width="800" height="800" align="bottom">

### 1.2 模型结构

<img src="./imgs/casrel.png"  width="600" height="600" align="bottom" />

### 1.3 代码结构展示
<img src="./imgs/directory.jpg"  width="300" height="300" align="bottom" />

# 2 准备工作
### 2.1 项目环境配置

* Python3.8
* jupyter notebook
* torch            1.6.0+cu10.2
* numpy            1.18.5
* transformers       3.4.0

代码运行环境建议使用Visual Studio Code(VScode)

### 2.2 数据集下载
NYT数据集下载地址：https://drive.google.com/file/d/10f24s9gM7NdyO3z5OqQxJgYud4NnCJg3/view <br> 
WEBNLG数据集下载地址：https://drive.google.com/file/d/1zISxYa-8ROe2Zv8iRc82jY9QsQrfY1Vj/view <br>

# 3 项目代码结构（VScode中演示）

>1）是什么？

　　我们首先会在VScode环境中让代码跑一下，直观感受到项目的训练，并展示前向推断的输出，让大家看到模型的效果。
>2）怎么构成的？

　　然后介绍项目代码的构成，介绍项目有哪些文件夹，包含哪些文件，这些文件构成了什么功能模块如：数据预处理模块，模型设计模块，损失函数模块，推断与评估模块。
>3）小结

　　在主文件中再过一下启动训练的流程。

# 4 算法模块及细节（jupyter和VScode中演示）

　　在jupyter notebook中细致地讲解每一个模块。
  
　　以实现模块功能为目的，来讲解每个函数的执行流程，呈现中间数据，方便同学们理解学习。
  
　　内容分为以下几个模块：**超参数设置，数据读取，数据预处理，模型训练，模型评价**。

### 4.1 超参数设置

In [3]:
import os
# Expose only GPU index 7 to this process; set before torch is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import config
import argparse
import torch
import numpy as np
import random


def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` is not usable with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This parser accepts the usual
    spellings and keeps the old behaviour for the empty string (False).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Seed every random number generator used below so runs are repeatable.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='Casrel', help='name of the model')
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--multi_gpu', type=_str2bool, default=False)
parser.add_argument('--dataset', type=str, default='NYT')
parser.add_argument('--batch_size', type=int, default=6)
# NOTE(review): the original line carried the remark "300" — presumably the
# epoch count of a full training run; confirm.
parser.add_argument('--max_epoch', type=int, default=15)
parser.add_argument('--test_epoch', type=int, default=1)
parser.add_argument('--train_prefix', type=str, default='train_triples')
parser.add_argument('--dev_prefix', type=str, default='dev_triples')
parser.add_argument('--test_prefix', type=str, default='test_triples')
parser.add_argument('--max_len', type=int, default=150)
parser.add_argument('--rel_num', type=int, default=44)
parser.add_argument('--period', type=int, default=50)
parser.add_argument('--debug', type=_str2bool, default=False)
# An empty argument list is parsed on purpose, so every option takes its default.
args = parser.parse_args(args=[])

con = config.Config(args)

In [2]:
def print_config(con):
    """Print each instance attribute of ``con`` on its own line as ``name = value``."""
    for name, value in vars(con).items():
        print(name, value, sep=' = ')

In [3]:
print_config(con)

args = Namespace(batch_size=6, dataset='NYT', debug=False, dev_prefix='dev_triples', lr=1e-05, max_epoch=15, max_len=150, model_name='Casrel', multi_gpu=False, period=50, rel_num=44, test_epoch=1, test_prefix='test_triples', train_prefix='train_triples')
multi_gpu = False
learning_rate = 1e-05
batch_size = 6
max_epoch = 15
max_len = 150
rel_num = 44
dataset = NYT
root = /home/niuhao/project/RE/CasRel_2020
data_path = /home/niuhao/project/RE/CasRel_2020/data/NYT
checkpoint_dir = /home/niuhao/project/RE/CasRel_2020/checkpoint/NYT
log_dir = /home/niuhao/project/RE/CasRel_2020/log/NYT
result_dir = /home/niuhao/project/RE/CasRel_2020/result/NYT
train_prefix = train_triples
dev_prefix = dev_triples
test_prefix = test_triples
model_save_name = Casrel_DATASET_NYT_LR_1e-05_BS_6
log_save_name = LOG_Casrel_DATASET_NYT_LR_1e-05_BS_6
result_save_name = RESULT_Casrel_DATASET_NYT_LR_1e-05_BS_6.json
period = 50
test_epoch = 1
debug = False


### 4.2 数据读取

In [7]:
import pickle
import json

In [5]:
# Load the pickled training examples; the context manager closes the file handle.
# NOTE: unpickling can execute arbitrary code — only load dataset files you trust.
with open(os.path.join(con.data_path, con.train_prefix + '.pkl'), 'rb') as f_train:
    train_data = pickle.load(f_train)

In [9]:
# rel2id.json holds a sequence; element 1 is used here as the relation-name -> id mapping.
# The context manager closes the file handle once it has been parsed.
with open(os.path.join(con.data_path, 'rel2id.json')) as f_rel:
    rel2id = json.load(f_rel)[1]

In [10]:
len(train_data)

56195

In [11]:
train_data[0]

{'text': 'Massachusetts ASTON MAGNA Great Barrington ; also at Bard College , Annandale-on-Hudson , N.Y. , July 1-Aug .',
 'triple_list': [['Annandale-on-Hudson',
   '/location/location/contains',
   'College']]}

In [40]:
text = train_data[1]['text']
text

'North Carolina EASTERN MUSIC FESTIVAL Greensboro , June 25-July 30 .'

In [41]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [42]:
text = ' '.join(text.split())
text

'North Carolina EASTERN MUSIC FESTIVAL Greensboro , June 25-July 30 .'

In [43]:
tokens = tokenizer.tokenize(text)
tokens

['north',
 'carolina',
 'eastern',
 'music',
 'festival',
 'greensboro',
 ',',
 'june',
 '25',
 '-',
 'july',
 '30',
 '.']

In [20]:
s2ro_map = {}

In [27]:
train_data[6]['triple_list']

[['Weiner', '/people/person/place_lived', 'Queens'],
 ['Weiner', '/people/person/place_lived', 'Brooklyn']]

In [35]:
triple = train_data[1]['triple_list'][0]
triple

['Carolina', '/location/location/contains', 'Greensboro']

In [38]:
tokenizer.tokenize(triple[0])

['carolina']

In [39]:
triple = (tokenizer.tokenize(triple[0]), triple[1], tokenizer.tokenize(triple[2]))
triple

(['carolina'], '/location/location/contains', ['greensboro'])

In [45]:
source = tokens
target = triple[0]
target_len = len(target)
# Slide a window of len(target) tokens over the sentence and print every
# start position at which the window equals the target token list.
match_positions = (pos for pos in range(len(source))
                   if source[pos: pos + target_len] == target)
for i in match_positions:
    print(i)

1


In [46]:
def find_head_idx(source, target):
    """Return the first index at which ``target`` occurs as a contiguous run in ``source``.

    Both arguments are sequences that support slicing and equality (token
    lists in this notebook). Returns -1 when no position matches.
    """
    window = len(target)
    candidates = (
        start
        for start in range(len(source))
        if source[start: start + window] == target
    )
    # The first hit wins; -1 is the "not found" sentinel.
    return next(candidates, -1)

In [48]:
sub_head_idx = find_head_idx(tokens, triple[0])
print(sub_head_idx)
obj_head_idx = find_head_idx(tokens, triple[2])
obj_head_idx

1


5

In [49]:
sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1)
sub

(1, 1)

In [51]:
if sub not in s2ro_map:
    s2ro_map[sub] = []
s2ro_map

{(1, 1): []}

In [52]:
s2ro_map[sub].append((obj_head_idx, obj_head_idx + len(triple[2]) - 1, rel2id[triple[1]]))
s2ro_map

{(1, 1): [(5, 5, 22)]}

In [56]:
import transformers

In [57]:
transformers.__version__

'2.8.0'

In [60]:
token_ids = tokenizer.encode(text)
token_ids

[101,
 2167,
 3792,
 2789,
 2189,
 2782,
 27905,
 1010,
 2238,
 2423,
 1011,
 2251,
 2382,
 1012,
 102]

In [61]:
token_ids = np.array(token_ids)
token_ids

array([  101,  2167,  3792,  2789,  2189,  2782, 27905,  1010,  2238,
        2423,  1011,  2251,  2382,  1012,   102])

In [64]:
text_len = len(tokens)
text_len

13

In [65]:
sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)

In [66]:
s2ro_map

{(1, 1): [(5, 5, 22)]}

In [103]:
s = [s for s in s2ro_map][0]

In [104]:
sub_heads[s[0]] = 1

In [105]:
sub_tails[s[1]] = 1

In [106]:
sub_heads

array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [107]:
sub_tails

array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [109]:
from random import choice

In [111]:
obj_heads, obj_tails = np.zeros((text_len, con.rel_num)), np.zeros((text_len, con.rel_num))

In [113]:
con.rel_num

44

In [114]:
obj_heads.shape

(13, 44)

In [115]:
sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys()))

In [118]:
sub_head_idx

1

In [116]:
# For every (object head, object tail, relation id) linked to the sampled
# subject span, set the head and tail token rows to 1 in that relation's column.
for obj_head_pos, obj_tail_pos, rel_id in s2ro_map.get((sub_head_idx, sub_tail_idx), []):
    obj_heads[obj_head_pos, rel_id] = 1
    obj_tails[obj_tail_pos, rel_id] = 1

In [121]:
obj_tails[5]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

### 4.3.3 将上述功能构建为类
* 将相似类型的功能整合到一个类里，使代码结构更清晰，并且方便以后调用，且有利于debug

### 4.3 定义模型

In [4]:
import transformers
transformers.__version__



'3.4.0'

In [13]:
from transformers import BertModel
import torch
import torch.nn as nn
import torch.optim as optim
import os
import data_loader
import torch.nn.functional as F
import numpy as np
import json
import time

In [7]:
class Casrel(nn.Module):
    """BERT encoder with four linear tagging heads.

    Two heads emit one sigmoid score per token (subject head / subject tail).
    The other two emit ``config.rel_num`` sigmoid scores per token (object
    head / object tail), computed after a subject representation has been
    added to every token vector.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert_dim = 768
        self.bert_encoder = BertModel.from_pretrained('bert-base-uncased')
        hidden, n_rel = self.bert_dim, self.config.rel_num
        self.sub_heads_linear = nn.Linear(hidden, 1)
        self.sub_tails_linear = nn.Linear(hidden, 1)
        self.obj_heads_linear = nn.Linear(hidden, n_rel)
        self.obj_tails_linear = nn.Linear(hidden, n_rel)

    def get_encoded_text(self, token_ids, mask):
        """Run the encoder and return element 0 of its output.

        Shape per the original comments: [batch_size, seq_len, bert_dim(768)].
        """
        encoder_outputs = self.bert_encoder(token_ids, attention_mask=mask)
        return encoder_outputs[0]

    def get_subs(self, encoded_text):
        """Return sigmoid subject-head and subject-tail scores.

        Each result is [batch_size, seq_len, 1] per the original comments.
        """
        head_scores = torch.sigmoid(self.sub_heads_linear(encoded_text))
        tail_scores = torch.sigmoid(self.sub_tails_linear(encoded_text))
        return head_scores, tail_scores

    def get_objs_for_specific_sub(self, sub_head_mapping, sub_tail_mapping, encoded_text):
        """Return sigmoid object-head and object-tail scores for the given subject.

        The mappings are multiplied with ``encoded_text`` to pick out the
        subject's head and tail vectors ([batch_size, 1, bert_dim] per the
        original comments); their mean is added to every token vector before
        the object heads are applied. Results are [batch_size, seq_len, rel_num].
        """
        head_vec = torch.matmul(sub_head_mapping, encoded_text)
        tail_vec = torch.matmul(sub_tail_mapping, encoded_text)
        subject_vec = (head_vec + tail_vec) / 2
        # Broadcast the single subject vector over the sequence dimension.
        conditioned = encoded_text + subject_vec
        head_scores = torch.sigmoid(self.obj_heads_linear(conditioned))
        tail_scores = torch.sigmoid(self.obj_tails_linear(conditioned))
        return head_scores, tail_scores

    def forward(self, data):
        """Score subjects and, for the subject given in ``data``, objects.

        Reads ``data['token_ids']``, ``data['mask']``, ``data['sub_head']`` and
        ``data['sub_tail']``; returns ``(pred_sub_heads, pred_sub_tails,
        pred_obj_heads, pred_obj_tails)``.
        """
        encoded_text = self.get_encoded_text(data['token_ids'], data['mask'])
        pred_sub_heads, pred_sub_tails = self.get_subs(encoded_text)
        # Insert a dimension at position 1: [batch_size, seq_len] -> [batch_size, 1, seq_len]
        # (shapes per the original comments).
        head_mapping = data['sub_head'].unsqueeze(1)
        tail_mapping = data['sub_tail'].unsqueeze(1)
        pred_obj_heads, pred_obj_tails = self.get_objs_for_specific_sub(
            head_mapping, tail_mapping, encoded_text)
        return pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails

In [8]:
ori_model = Casrel(con)
ori_model.cuda()

Casrel(
  (bert_encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tru

In [12]:
# Build the optimizer: Adam receives only the parameters that require gradients.
trainable_params = [p for p in ori_model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=con.learning_rate)

# Wrap the model in DataParallel only when multi-GPU training is requested;
# `ori_model` keeps pointing at the unwrapped module.
model = nn.DataParallel(ori_model) if con.multi_gpu else ori_model

### 4.4 模型训练

In [5]:
# define the loss function
def loss(gold, pred, mask):
    """Return the masked binary cross-entropy between ``pred`` and ``gold``.

    Args:
        gold: target tensor with the same shape as ``pred`` after a trailing
            size-1 dimension (if any) has been squeezed away.
        pred: predicted probabilities.
        mask: weights multiplied into the element-wise loss. When its shape
            differs from the loss it gets a trailing dimension so it can
            broadcast.

    Returns:
        A scalar tensor: the sum of the masked element-wise losses divided by
        the sum of ``mask``.
    """
    pred = pred.squeeze(-1)
    los = F.binary_cross_entropy(pred, gold, reduction='none')
    if los.shape != mask.shape:
        mask = mask.unsqueeze(-1)
    los = torch.sum(los * mask) / torch.sum(mask)
    # Return the computed value; the original returned the function object `loss`.
    return los

In [16]:
# Make sure the checkpoint directory exists. makedirs also creates missing
# parent directories, which os.mkdir would fail on.
os.makedirs(con.checkpoint_dir, exist_ok=True)

# Make sure the log directory exists.
os.makedirs(con.log_dir, exist_ok=True)

# Build the data loaders; the dev loader is created with is_test=True.
train_data_loader = data_loader.get_loader(con, prefix=con.train_prefix)
dev_data_loader = data_loader.get_loader(con, prefix=con.dev_prefix, is_test=True)


In [None]:
# NOTE(review): this cell refers to `self` (self.config, self.logging, self.test),
# which is not defined anywhere in the cell — presumably the loop was lifted from
# a method of a training/framework class. Bind those names before running this
# cell; confirm against the project's main training file.
model.train()
global_step = 0
loss_sum = 0

# Best dev-set scores seen so far and the epoch they were reached in.
best_f1_score = 0
best_precision = 0
best_recall = 0

best_epoch = 0
init_time = time.time()   # start of the whole run
start_time = time.time()  # start of the current logging period

# the training loop
for epoch in range(self.config.max_epoch):
    train_data_prefetcher = data_loader.DataPreFetcher(train_data_loader)
    data = train_data_prefetcher.next()
    # The loop ends when the prefetcher hands back None.
    while data is not None:
        pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails = model(data)

        # One masked loss per tagger; the total is their plain sum.
        sub_heads_loss = loss(data['sub_heads'], pred_sub_heads, data['mask'])
        sub_tails_loss = loss(data['sub_tails'], pred_sub_tails, data['mask'])
        obj_heads_loss = loss(data['obj_heads'], pred_obj_heads, data['mask'])
        obj_tails_loss = loss(data['obj_tails'], pred_obj_tails, data['mask'])
        total_loss = (sub_heads_loss + sub_tails_loss) + (obj_heads_loss + obj_tails_loss)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        global_step += 1
        loss_sum += total_loss.item()

        # Every `period` steps: log the mean loss and time per batch, then reset.
        if global_step % self.config.period == 0:
            cur_loss = loss_sum / self.config.period
            elapsed = time.time() - start_time
            self.logging("epoch: {:3d}, step: {:4d}, speed: {:5.2f}ms/b, train loss: {:5.3f}".
                         format(epoch, global_step, elapsed * 1000 / self.config.period, cur_loss))
            loss_sum = 0
            start_time = time.time()

        data = train_data_prefetcher.next()

    # Evaluate on the dev loader every `test_epoch` epochs.
    if (epoch + 1) % self.config.test_epoch == 0:
        eval_start_time = time.time()
        model.eval()
        # call the test function
        precision, recall, f1_score = self.test(dev_data_loader, model)
        model.train()
        self.logging('epoch {:3d}, eval time: {:5.2f}s, f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}'.
                     format(epoch, time.time() - eval_start_time, f1_score, precision, recall))

        if f1_score > best_f1_score:
            best_f1_score = f1_score
            best_epoch = epoch
            best_precision = precision
            best_recall = recall
            self.logging("saving the model, epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".
                         format(best_epoch, best_f1_score, precision, recall))
            # save the best model (skipped in debug mode); the unwrapped
            # `ori_model` is saved rather than `model`
            path = os.path.join(self.config.checkpoint_dir, self.config.model_save_name)
            if not self.config.debug:
                torch.save(ori_model.state_dict(), path)

    # manually release the unused cache
    torch.cuda.empty_cache()

self.logging("finish training")
# The recall field now uses {:4.2f} like the other scores; the original
# "{:4.2}" lacked the "f" and therefore used the general float format.
self.logging("best epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}, total time: {:5.2f}s".
             format(best_epoch, best_f1_score, best_precision, best_recall, time.time() - init_time))

### 4.5 模型评价
现在假设训练好了一个模型（或者模型训练了一个epoch），我们想看看模型现在的性能，那么就需要对模型进行评价

**目录**
* 模型预测
* 计算评价指标

In [None]:
def test(self, test_data_loader, model, output=False, h_bar=0.5, t_bar=0.5):
    """Decode triples for every batch of ``test_data_loader`` and score them.

    Args:
        self: object whose ``config`` supplies ``data_path`` and, when
            ``output`` is true, ``result_dir`` and ``result_save_name``.
        test_data_loader: loader handed to ``data_loader.DataPreFetcher``.
        model: object providing ``get_encoded_text``, ``get_subs`` and
            ``get_objs_for_specific_sub``.
        output: when true, one JSON line per batch is written to the result file.
        h_bar: value a head score must exceed to count as a head.
        t_bar: value a tail score must exceed to count as a tail.

    Returns:
        ``(precision, recall, f1_score)`` from exact matches between predicted
        and gold (subject, relation, object) tuples.

    Only element 0 of each batch is read, so the loader is presumably
    configured with batch size 1 — TODO confirm against ``data_loader``.
    """
    orders = ['subject', 'relation', 'object']

    def to_tup(triple_list):
        # Gold triples must be hashable tuples for the set comparison below.
        return [tuple(triple) for triple in triple_list]

    fw = None
    try:
        if output:
            # Make sure the result dir exists (missing parents are created too).
            os.makedirs(self.config.result_dir, exist_ok=True)

            path = os.path.join(self.config.result_dir, self.config.result_save_name)

            # UTF-8 is explicit because the records are dumped with ensure_ascii=False.
            fw = open(path, 'w', encoding='utf-8')

        test_data_prefetcher = data_loader.DataPreFetcher(test_data_loader)
        data = test_data_prefetcher.next()
        # Element 0 of rel2id.json is used as the id -> relation-name mapping
        # (it is indexed below with the relation id as a string).
        with open(os.path.join(self.config.data_path, 'rel2id.json'), encoding='utf-8') as f_rel:
            id2rel = json.load(f_rel)[0]
        correct_num, predict_num, gold_num = 0, 0, 0

        while data is not None:
            with torch.no_grad():
                token_ids = data['token_ids']
                tokens = data['tokens'][0]
                mask = data['mask']
                encoded_text = model.get_encoded_text(token_ids, mask)
                pred_sub_heads, pred_sub_tails = model.get_subs(encoded_text)
                # Token positions whose head / tail score exceeds the threshold.
                sub_heads, sub_tails = np.where(pred_sub_heads.cpu()[0] > h_bar)[0], np.where(pred_sub_tails.cpu()[0] > t_bar)[0]
                subjects = []
                for sub_head in sub_heads:
                    # Pair each head with the first tail at or after it.
                    sub_tail = sub_tails[sub_tails >= sub_head]
                    if len(sub_tail) > 0:
                        sub_tail = sub_tail[0]
                        # The slice excludes the tail position itself.
                        subject = tokens[sub_head: sub_tail]
                        subjects.append((subject, sub_head, sub_tail))
                if subjects:
                    triple_list = []
                    # [subject_num, seq_len, bert_dim]
                    repeated_encoded_text = encoded_text.repeat(len(subjects), 1, 1)
                    # [subject_num, 1, seq_len]
                    sub_head_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
                    sub_tail_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
                    # One-hot rows marking each subject's head and tail position.
                    for subject_idx, subject in enumerate(subjects):
                        sub_head_mapping[subject_idx][0][subject[1]] = 1
                        sub_tail_mapping[subject_idx][0][subject[2]] = 1
                    # Match the device and dtype of the encoded text.
                    sub_tail_mapping = sub_tail_mapping.to(repeated_encoded_text)
                    sub_head_mapping = sub_head_mapping.to(repeated_encoded_text)
                    pred_obj_heads, pred_obj_tails = model.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping, repeated_encoded_text)
                    for subject_idx, subject in enumerate(subjects):
                        sub = subject[0]
                        # NOTE: lstrip("##") strips every leading '#' character,
                        # not only a literal "##" prefix.
                        sub = ''.join([i.lstrip("##") for i in sub])
                        # '[unused1]' is replaced by a space — presumably the
                        # tokenizer's word-boundary marker; confirm.
                        sub = ' '.join(sub.split('[unused1]'))
                        # Each result is a pair of arrays: (token positions, relation ids).
                        obj_heads, obj_tails = np.where(pred_obj_heads.cpu()[subject_idx] > h_bar), np.where(pred_obj_tails.cpu()[subject_idx] > t_bar)
                        for obj_head, rel_head in zip(*obj_heads):
                            for obj_tail, rel_tail in zip(*obj_tails):
                                # Take the first tail of the same relation at or
                                # after the head, then move on to the next head.
                                if obj_head <= obj_tail and rel_head == rel_tail:
                                    rel = id2rel[str(int(rel_head))]
                                    obj = tokens[obj_head: obj_tail]
                                    obj = ''.join([i.lstrip("##") for i in obj])
                                    obj = ' '.join(obj.split('[unused1]'))
                                    triple_list.append((sub, rel, obj))
                                    break
                    # Duplicate predictions collapse into one.
                    pred_triples = set(triple_list)
                else:
                    pred_triples = set()
                gold_triples = set(to_tup(data['triples'][0]))

                correct_num += len(pred_triples & gold_triples)
                predict_num += len(pred_triples)
                gold_num += len(gold_triples)

                if fw is not None:
                    result = json.dumps({
                        # 'text': ' '.join(tokens),
                        'triple_list_gold': [
                            dict(zip(orders, triple)) for triple in gold_triples
                        ],
                        'triple_list_pred': [
                            dict(zip(orders, triple)) for triple in pred_triples
                        ],
                        'new': [
                            dict(zip(orders, triple)) for triple in pred_triples - gold_triples
                        ],
                        'lack': [
                            dict(zip(orders, triple)) for triple in gold_triples - pred_triples
                        ]
                    }, ensure_ascii=False)
                    fw.write(result + '\n')

                data = test_data_prefetcher.next()
    finally:
        # Close the result file even when decoding raises.
        if fw is not None:
            fw.close()

    print("correct_num: {:3d}, predict_num: {:3d}, gold_num: {:3d}".format(correct_num, predict_num, gold_num))

    # The small epsilons avoid division by zero when nothing was predicted or labelled.
    precision = correct_num / (predict_num + 1e-10)
    recall = correct_num / (gold_num + 1e-10)
    f1_score = 2 * precision * recall / (precision + recall + 1e-10)
    return precision, recall, f1_score