# A Novel Cascade Binary Tagging Framework for Relational Triple Extraction ACL 2020

# 1前言

### 1,1课程回顾

<img src='imgs/overall_for_code.png' width="800" height="800" align="bottom">

### 1.2 模型结构

<img src="./imgs/casrel.png"  width="600" height="600" align="bottom" />

### 1.3 代码结构展示
<img src="./imgs/directory.jpg"  width="300" height="300" align="bottom" />

# 2 准备工作
### 2.1项目环境配置

* Python3.8
* jupyter notebook
* torch            1.6.0+cu10.2
* numpy            1.18.5
* transformers       3.4.0

代码运行环境建议使用Visual Studio Code(VScode)

### 2.2 数据集下载
NYT数据集下载地址：https://drive.google.com/file/d/10f24s9gM7NdyO3z5OqQxJgYud4NnCJg3/view <br> 
WEBNLG数据集下载地址：https://drive.google.com/file/d/1zISxYa-8ROe2Zv8iRc82jY9QsQrfY1Vj/view <br>

# 3 项目代码结构（VScode中演示）

>1）是什么？

　　我们首先会在VScode环境中让代码跑一下，直观感受到项目的训练，并展示前向推断的输出，让大家看到模型的效果。
>2）怎么构成的？

　　然后介绍项目代码的构成，介绍项目有哪些文件夹，包含哪些文件，这些文件构成了什么功能模块如：数据预处理模块，模型设计模块，损失函数模块，推断与评估模块。
>3）小结

　　在主文件中在过一下启动训练的流程。

# 4 算法模块及细节（jupyter和VScode中演示）

　　在jupyter notebook中细致地讲解每一个模块。
  
　　以实现模块功能为目的，来讲解每个函数的执行流程，呈现中间数据，方便同学们理解学习。
  
　　内容分为以下几个模块：**超参数设置，数据读取与处理，模型定义，模型训练，模型评价**。

### 4.1 超参数设置

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import config
import argparse
import torch
import numpy as np
import random


seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='Casrel', help='name of the model')
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--multi_gpu', type=bool, default=False)
parser.add_argument('--dataset', type=str, default='NYT')
parser.add_argument('--batch_size', type=int, default=6)
parser.add_argument('--max_epoch', type=int, default=15) #300
parser.add_argument('--test_epoch', type=int, default=1)
parser.add_argument('--train_prefix', type=str, default='train_triples')
parser.add_argument('--dev_prefix', type=str, default='dev_triples')
parser.add_argument('--test_prefix', type=str, default='test_triples')
parser.add_argument('--max_len', type=int, default=150)
parser.add_argument('--rel_num', type=int, default=25)
parser.add_argument('--period', type=int, default=50)
parser.add_argument('--debug', type=bool, default=False)
args = parser.parse_args(args=[])

con = config.Config(args)

In [2]:
def print_config(con):
    for key in con.__dict__:
        print(key, end=' = ')
        print(con.__dict__[key])

In [3]:
print_config(con)

args = Namespace(batch_size=6, dataset='NYT', debug=False, dev_prefix='dev_triples', lr=1e-05, max_epoch=15, max_len=150, model_name='Casrel', multi_gpu=False, period=50, rel_num=25, test_epoch=1, test_prefix='test_triples', train_prefix='train_triples')
multi_gpu = False
learning_rate = 1e-05
batch_size = 6
max_epoch = 15
max_len = 150
rel_num = 25
dataset = NYT
root = /home/niuhao/project/RE/CasRel_2020
data_path = /home/niuhao/project/RE/CasRel_2020/data/NYT
checkpoint_dir = /home/niuhao/project/RE/CasRel_2020/checkpoint/NYT
log_dir = /home/niuhao/project/RE/CasRel_2020/log/NYT
result_dir = /home/niuhao/project/RE/CasRel_2020/result/NYT
train_prefix = train_triples
dev_prefix = dev_triples
test_prefix = test_triples
model_save_name = Casrel_DATASET_NYT_LR_1e-05_BS_6
log_save_name = LOG_Casrel_DATASET_NYT_LR_1e-05_BS_6
result_save_name = RESULT_Casrel_DATASET_NYT_LR_1e-05_BS_6.json
period = 50
test_epoch = 1
debug = False


### 4.2 数据读取与处理
* 数据处理细节
* 构建dataset类

#### 4.2.1 数据处理细节

In [4]:
import pickle
import json

In [5]:
train_data = pickle.load(open(os.path.join(con.data_path, con.train_prefix + '.pkl'), 'rb'))

In [6]:
rel2id = json.load(open(os.path.join(con.data_path, 'rel2id.json')))[1]

In [7]:
len(train_data)

56195

In [8]:
train_data[0]

{'text': 'Massachusetts ASTON MAGNA Great Barrington ; also at Bard College , Annandale-on-Hudson , N.Y. , July 1-Aug .',
 'triple_list': [['Annandale-on-Hudson',
   '/location/location/contains',
   'College']]}

In [9]:
text = train_data[1]['text']
text

'North Carolina EASTERN MUSIC FESTIVAL Greensboro , June 25-July 30 .'

In [10]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')



In [11]:
text = ' '.join(text.split())
text

'North Carolina EASTERN MUSIC FESTIVAL Greensboro , June 25-July 30 .'

In [12]:
tokens = tokenizer.tokenize(text)
tokens

['north',
 'carolina',
 'eastern',
 'music',
 'festival',
 'greensboro',
 ',',
 'june',
 '25',
 '-',
 'july',
 '30',
 '.']

In [13]:
s2ro_map = {}

In [14]:
train_data[6]['triple_list']

[['Weiner', '/people/person/place_lived', 'Queens'],
 ['Weiner', '/people/person/place_lived', 'Brooklyn']]

In [15]:
triple = train_data[1]['triple_list'][0]
triple

['Carolina', '/location/location/contains', 'Greensboro']

In [16]:
tokenizer.tokenize(triple[0])

['carolina']

In [17]:
triple = (tokenizer.tokenize(triple[0]), triple[1], tokenizer.tokenize(triple[2]))
triple

(['carolina'], '/location/location/contains', ['greensboro'])

In [18]:
source = tokens
target = triple[0]
target_len = len(target)
for i in range(len(source)):
    if source[i: i + target_len] == target:
        print(i)

1


In [19]:
def find_head_idx(source, target):
    target_len = len(target)
    for i in range(len(source)):
        if source[i: i + target_len] == target:
            return i
    return -1

In [20]:
sub_head_idx = find_head_idx(tokens, triple[0])
print(sub_head_idx)
obj_head_idx = find_head_idx(tokens, triple[2])
obj_head_idx

1


5

In [21]:
sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1)
sub

(1, 1)

In [22]:
if sub not in s2ro_map:
    s2ro_map[sub] = []
s2ro_map

{(1, 1): []}

In [23]:
s2ro_map[sub].append((obj_head_idx, obj_head_idx + len(triple[2]) - 1, rel2id[triple[1]]))
s2ro_map

{(1, 1): [(5, 5, 22)]}

In [24]:
out = tokenizer(text)
out

{'input_ids': [101, 2167, 3792, 2789, 2189, 2782, 27905, 1010, 2238, 2423, 1011, 2251, 2382, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [25]:
token_ids = out['input_ids']

In [26]:
token_ids = np.array(token_ids)
token_ids

array([  101,  2167,  3792,  2789,  2189,  2782, 27905,  1010,  2238,
        2423,  1011,  2251,  2382,  1012,   102])

In [27]:
text_len = len(tokens)
text_len

13

In [28]:
sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)

In [29]:
s2ro_map

{(1, 1): [(5, 5, 22)]}

In [30]:
s = [s for s in s2ro_map][0]

In [31]:
sub_heads[s[0]] = 1

In [32]:
sub_tails[s[1]] = 1

In [33]:
sub_heads

array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [34]:
sub_tails

array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [35]:
from random import choice

In [36]:
obj_heads, obj_tails = np.zeros((text_len, con.rel_num)), np.zeros((text_len, con.rel_num))

In [37]:
con.rel_num

25

In [38]:
obj_heads.shape

(13, 25)

In [39]:
sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys()))

In [40]:
sub_head_idx

1

In [41]:
for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []):
    obj_heads[ro[0]][ro[2]] = 1
    obj_tails[ro[1]][ro[2]] = 1

In [42]:
obj_tails[5]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 1., 0., 0.])

#### 4.2.2 将上述功能构建为类
* 将相似类型的功能整合到一个类里，使代码结构更清晰，并且方面以后调佣，且有利于debug

In [43]:
from torch.utils.data import DataLoader, Dataset

In [44]:
class CMEDDataset(Dataset):
    def __init__(self, config, prefix, is_test, tokenizer):
        self.config = config
        self.prefix = prefix
        self.is_test = is_test
        self.tokenizer = tokenizer
        if self.config.debug:
            self.json_data = pickle.load(open(os.path.join(self.config.data_path, prefix + '.pkl'), 'rb'))[:500]
        else:
            self.json_data = pickle.load(open(os.path.join(self.config.data_path, prefix + '.pkl'), 'rb'))
        self.rel2id = json.load(open(os.path.join(self.config.data_path, 'rel2id.json')))[1]

    def __len__(self):
        return len(self.json_data)

    def __getitem__(self, idx):
        ins_json_data = self.json_data[idx]
        text = ins_json_data['text']
        text = ' '.join(text.split()[:self.config.max_len])
        tokens = self.tokenizer.tokenize(text)
        if len(tokens) > BERT_MAX_LEN:
            tokens = tokens[: BERT_MAX_LEN]
        text_len = len(tokens)

        if not self.is_test:
            s2ro_map = {}
            for triple in ins_json_data['triple_list']:
                triple = (self.tokenizer.tokenize(triple[0]), triple[1], self.tokenizer.tokenize(triple[2]))
                sub_head_idx = find_head_idx(tokens, triple[0])
                obj_head_idx = find_head_idx(tokens, triple[2])
                if sub_head_idx != -1 and obj_head_idx != -1:
                    sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1)
                    if sub not in s2ro_map:
                        s2ro_map[sub] = []
                    s2ro_map[sub].append((obj_head_idx, obj_head_idx + len(triple[2]) - 1, self.rel2id[triple[1]]))

            if s2ro_map:
                # token_ids, segment_ids = self.tokenizer.encode(first=text)
                tokenizer_out = self.tokenizer(text)
                token_ids = tokenizer_out['input_ids']
                segment_ids = tokenizer_out['attention_mask']
                masks = segment_ids
                if len(token_ids) > text_len:
                    token_ids = token_ids[:text_len]
                    masks = masks[:text_len]
                token_ids = np.array(token_ids)
                masks = np.array(masks) + 1
                sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)
                for s in s2ro_map:
                    sub_heads[s[0]] = 1
                    sub_tails[s[1]] = 1
                sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys()))
                sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len)
                sub_head[sub_head_idx] = 1
                sub_tail[sub_tail_idx] = 1
                obj_heads, obj_tails = np.zeros((text_len, self.config.rel_num)), np.zeros((text_len, self.config.rel_num))
                for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []):
                    obj_heads[ro[0]][ro[2]] = 1
                    obj_tails[ro[1]][ro[2]] = 1
                return token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, ins_json_data['triple_list'], tokens
            else:
                return None
        else:
            # token_ids, segment_ids = self.tokenizer.encode(first=text)
            tokenizer_out = self.tokenizer(text)
            token_ids = tokenizer_out['input_ids']
            segment_ids = tokenizer_out['attention_mask']
            masks = segment_ids
            if len(token_ids) > text_len:
                token_ids = token_ids[:text_len]
                masks = masks[:text_len]
            token_ids = np.array(token_ids)
            masks = np.array(masks) + 1
            sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)
            sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len)
            obj_heads, obj_tails = np.zeros((text_len, self.config.rel_num)), np.zeros((text_len, self.config.rel_num))
            return token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, ins_json_data['triple_list'], tokens

In [45]:
def cmed_collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    batch.sort(key=lambda x: x[2], reverse=True)
    token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, triples, tokens = zip(*batch)
    cur_batch = len(batch)
    max_text_len = max(text_len)
    batch_token_ids = torch.LongTensor(cur_batch, max_text_len).zero_()
    batch_masks = torch.LongTensor(cur_batch, max_text_len).zero_()
    batch_sub_heads = torch.Tensor(cur_batch, max_text_len).zero_()
    batch_sub_tails = torch.Tensor(cur_batch, max_text_len).zero_()
    batch_sub_head = torch.Tensor(cur_batch, max_text_len).zero_()
    batch_sub_tail = torch.Tensor(cur_batch, max_text_len).zero_()
    batch_obj_heads = torch.Tensor(cur_batch, max_text_len, 44).zero_()
    batch_obj_tails = torch.Tensor(cur_batch, max_text_len, 44).zero_()

    for i in range(cur_batch):
        batch_token_ids[i, :text_len[i]].copy_(torch.from_numpy(token_ids[i]))
        batch_masks[i, :text_len[i]].copy_(torch.from_numpy(masks[i]))
        batch_sub_heads[i, :text_len[i]].copy_(torch.from_numpy(sub_heads[i]))
        batch_sub_tails[i, :text_len[i]].copy_(torch.from_numpy(sub_tails[i]))
        batch_sub_head[i, :text_len[i]].copy_(torch.from_numpy(sub_head[i]))
        batch_sub_tail[i, :text_len[i]].copy_(torch.from_numpy(sub_tail[i]))
        batch_obj_heads[i, :text_len[i], :].copy_(torch.from_numpy(obj_heads[i]))
        batch_obj_tails[i, :text_len[i], :].copy_(torch.from_numpy(obj_tails[i]))

    return {'token_ids': batch_token_ids,
            'mask': batch_masks,
            'sub_heads': batch_sub_heads,
            'sub_tails': batch_sub_tails,
            'sub_head': batch_sub_head,
            'sub_tail': batch_sub_tail,
            'obj_heads': batch_obj_heads,
            'obj_tails': batch_obj_tails,
            'triples': triples,
            'tokens': tokens}

In [46]:
def get_loader(config, prefix, is_test=False, num_workers=0, collate_fn=cmed_collate_fn):
    dataset = CMEDDataset(config, prefix, is_test, tokenizer)
    if not is_test:
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn)
    else:
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn)
    return data_loader

In [47]:
class DataPreFetcher(object):
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        try:
            self.next_data = next(self.loader)
        except StopIteration:
            self.next_data = None
            return
        with torch.cuda.stream(self.stream):
            for k, v in self.next_data.items():
                if isinstance(v, torch.Tensor):
                    self.next_data[k] = self.next_data[k].cuda(non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        data = self.next_data
        self.preload()
        return data

### 4.3 定义模型

In [48]:
import transformers
transformers.__version__

'3.4.0'

In [49]:
from transformers import BertModel
import torch
import torch.nn as nn
import torch.optim as optim
import os
import data_loader
import torch.nn.functional as F
import numpy as np
import json
import time

In [50]:
class Casrel(nn.Module):
    def __init__(self, config):
        super(Casrel, self).__init__()
        self.config = config
        self.bert_dim = 768
        self.bert_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.sub_heads_linear = nn.Linear(self.bert_dim, 1)
        self.sub_tails_linear = nn.Linear(self.bert_dim, 1)
        self.obj_heads_linear = nn.Linear(self.bert_dim, self.config.rel_num)
        self.obj_tails_linear = nn.Linear(self.bert_dim, self.config.rel_num)

    def get_objs_for_specific_sub(self, sub_head_mapping, sub_tail_mapping, encoded_text):
        # [batch_size, 1, bert_dim]
        sub_head = torch.matmul(sub_head_mapping, encoded_text)
        # [batch_size, 1, bert_dim]
        sub_tail = torch.matmul(sub_tail_mapping, encoded_text)
        # [batch_size, 1, bert_dim]
        sub = (sub_head + sub_tail) / 2
        # [batch_size, seq_len, bert_dim]
        encoded_text = encoded_text + sub
        # [batch_size, seq_len, rel_num]
        pred_obj_heads = self.obj_heads_linear(encoded_text)
        pred_obj_heads = torch.sigmoid(pred_obj_heads)
        # [batch_size, seq_len, rel_num]
        pred_obj_tails = self.obj_tails_linear(encoded_text)
        pred_obj_tails = torch.sigmoid(pred_obj_tails)
        return pred_obj_heads, pred_obj_tails

    def get_encoded_text(self, token_ids, mask):
        # [batch_size, seq_len, bert_dim(768)]
        encoded_text = self.bert_encoder(token_ids, attention_mask=mask)[0]
        return encoded_text

    def get_subs(self, encoded_text):
        # [batch_size, seq_len, 1]
        pred_sub_heads = self.sub_heads_linear(encoded_text)
        pred_sub_heads = torch.sigmoid(pred_sub_heads)
        # [batch_size, seq_len, 1]
        pred_sub_tails = self.sub_tails_linear(encoded_text)
        pred_sub_tails = torch.sigmoid(pred_sub_tails)
        return pred_sub_heads, pred_sub_tails

    def forward(self, data):
        # [batch_size, seq_len]
        token_ids = data['token_ids']
        # [batch_size, seq_len]
        mask = data['mask']
        # [batch_size, seq_len, bert_dim(768)]
        encoded_text = self.get_encoded_text(token_ids, mask)
        # [batch_size, seq_len, 1]
        pred_sub_heads, pred_sub_tails = self.get_subs(encoded_text)
        # [batch_size, 1, seq_len]
        sub_head_mapping = data['sub_head'].unsqueeze(1)
        # [batch_size, 1, seq_len]
        sub_tail_mapping = data['sub_tail'].unsqueeze(1)
        # [batch_size, seq_len, rel_num]
        pred_obj_heads, pred_obj_tails = self.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping, encoded_text)
        return pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails

In [51]:
ori_model = Casrel(con)
ori_model.cuda()

Casrel(
  (bert_encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tru

In [52]:
# define the optimizer
optimizer = optim.Adam(filter(lambda p: p.requires_grad, ori_model.parameters()), lr=con.learning_rate)

# whether use multi GPU
if con.multi_gpu:
    model = nn.DataParallel(ori_model)
else:
    model = ori_model

### 4.4 模型训练

In [53]:
# define the loss function
def loss(gold, pred, mask):
    pred = pred.squeeze(-1)
    los = F.binary_cross_entropy(pred, gold, reduction='none')
    if los.shape != mask.shape:
        mask = mask.unsqueeze(-1)
    los = torch.sum(los * mask) / torch.sum(mask)
    return los

In [54]:
# check the checkpoint dir
if not os.path.exists(con.checkpoint_dir):
    os.mkdir(con.checkpoint_dir)

# check the log dir
if not os.path.exists(con.log_dir):
    os.mkdir(con.log_dir)

# get the data loader
train_data_loader = data_loader.get_loader(con, prefix=con.train_prefix)
dev_data_loader = data_loader.get_loader(con, prefix=con.dev_prefix, is_test=True)


In [55]:
model.train()
global_step = 0
loss_sum = 0

best_f1_score = 0
best_precision = 0
best_recall = 0

best_epoch = 0
init_time = time.time()
start_time = time.time()

# the training loop
# for epoch in range(self.config.max_epoch):

In [56]:
train_data_prefetcher = data_loader.DataPreFetcher(train_data_loader)

In [57]:
data = train_data_prefetcher.next()
data

{'token_ids': tensor([[  101,  2096,  1996, 28709,  6790,  2001, 25369,  2041,  2005,  2049,
           1005,  1005, 15085,  7720,  8158,  1998,  3824,  6904, 18796,  3215,
           1010,  1005,  1005,  2429,  2000,  1996, 25022, 12380,  2015,  4773,
           2609,  1010, 24526,  1005,  1055,  4825,  1010,  1999,  2225,  2121,
           3077,  1010,  4058,  1010,  2363,  1996,  2087,  3784,  4494,  1997,
           2274, 13527],
         [  101, 20351,  2546,  1010,  2321,  1010,  2040,  2973,  2006,  2358,
           1012,  6017,  3927,  1999,  1996,  3134,  3077,  2930,  1997,  6613,
           1010,  2351,  2012, 26756,  2902,  2415,  2044,  2108, 13263,  1999,
           1996, 13878,  2076,  1037,  2954,  1999,  1996,  4257,  9711,  2078,
           2212,  2008,  2920,  2195, 12908,  1010,  1996,  2610,  2056,     0,
              0,     0],
         [  101,  6744,  1005,  1055,  3570,  2003,  2028,  1997,  1996,  2087,
           7591,  3980,  1999,  1996,  5611,  1011,  9302

In [58]:
pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails = model(data)
# print('pred_sub_heads:{}, pred_sub_tails: {},pred_obj_heads: {},pred_obj_tails: {}'.format(pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails))

In [59]:
sub_heads_loss = loss(data['sub_heads'], pred_sub_heads, data['mask'])
sub_heads_loss

tensor(0.6663, device='cuda:0', grad_fn=<DivBackward0>)

In [60]:
sub_heads_loss = loss(data['sub_heads'], pred_sub_heads, data['mask'])
sub_tails_loss = loss(data['sub_tails'], pred_sub_tails, data['mask'])
obj_heads_loss = loss(data['obj_heads'], pred_obj_heads, data['mask'])
obj_tails_loss = loss(data['obj_tails'], pred_obj_tails, data['mask'])
total_loss = (sub_heads_loss + sub_tails_loss) + (obj_heads_loss + obj_tails_loss)
total_loss

tensor(36.7705, device='cuda:0', grad_fn=<AddBackward0>)

In [61]:
optimizer.zero_grad()
total_loss.backward()
optimizer.step()

In [62]:
global_step += 1
loss_sum += total_loss.item()

if global_step % con.period == 0:
    cur_loss = loss_sum / self.config.period
    elapsed = time.time() - start_time
    logging("epoch: {:3d}, step: {:4d}, speed: {:5.2f}ms/b, train loss: {:5.3f}".
                 format(epoch, global_step, elapsed * 1000 / con.period, cur_loss))
    loss_sum = 0
    start_time = time.time()

### 4.5 模型评价
现在假设训练好了一个模型（或者模型训练了一个epoch），我们想看看模型现在的性能，那么就需要对模型进行评价

**目录**
* 预测头实体
* 预测尾实体和关系
* 计算评价指标
* 定义函数

In [63]:
# check the result dir
if not os.path.exists(con.result_dir):
    os.mkdir(con.result_dir)

path = os.path.join(con.result_dir, con.result_save_name)

fw = open(path, 'w')

In [64]:
orders = ['subject', 'relation', 'object']

In [65]:
def to_tup(triple_list):
    ret = []
    for triple in triple_list:
        ret.append(tuple(triple))
    return ret

In [66]:
test_data_prefetcher = data_loader.DataPreFetcher(dev_data_loader)
data = test_data_prefetcher.next()
id2rel = json.load(open(os.path.join(con.data_path, 'rel2id.json')))[0]

In [67]:
id2rel

{'1': '/business/company/founders',
 '2': '/people/person/place_of_birth',
 '3': '/people/deceased_person/place_of_death',
 '4': '/business/company_shareholder/major_shareholder_of',
 '5': '/people/ethnicity/people',
 '6': '/location/neighborhood/neighborhood_of',
 '7': '/sports/sports_team/location',
 '9': '/business/company/industry',
 '10': '/business/company/place_founded',
 '11': '/location/administrative_division/country',
 '0': 'None',
 '12': '/sports/sports_team_location/teams',
 '13': '/people/person/nationality',
 '14': '/people/person/religion',
 '15': '/business/company/advisors',
 '16': '/people/person/ethnicity',
 '17': '/people/ethnicity/geographic_distribution',
 '8': '/business/person/company',
 '19': '/business/company/major_shareholders',
 '18': '/people/person/place_lived',
 '20': '/people/person/profession',
 '21': '/location/country/capital',
 '22': '/location/location/contains',
 '23': '/location/country/administrative_divisions',
 '24': '/people/person/children'

In [68]:
correct_num, predict_num, gold_num = 0, 0, 0

#### 4.5.1 预测头实体

In [69]:
with torch.no_grad():
    token_ids = data['token_ids']
    tokens = data['tokens'][0]
    mask = data['mask']
    encoded_text = model.get_encoded_text(token_ids, mask)
    pred_sub_heads, pred_sub_tails = model.get_subs(encoded_text)

In [70]:
pred_sub_heads.shape

torch.Size([1, 30, 1])

In [71]:
pred_sub_tails

tensor([[[0.5773],
         [0.5536],
         [0.4459],
         [0.4521],
         [0.5084],
         [0.5354],
         [0.4858],
         [0.5011],
         [0.5999],
         [0.5540],
         [0.4826],
         [0.4457],
         [0.4614],
         [0.4807],
         [0.3697],
         [0.4443],
         [0.4665],
         [0.5019],
         [0.4635],
         [0.4928],
         [0.5673],
         [0.5024],
         [0.5179],
         [0.5031],
         [0.4897],
         [0.4830],
         [0.4964],
         [0.5143],
         [0.4795],
         [0.5240]]], device='cuda:0')

In [72]:
h_bar=0.5
t_bar=0.5

In [73]:
sub_heads, sub_tails = np.where(pred_sub_heads.cpu()[0] > h_bar)[0], np.where(pred_sub_tails.cpu()[0] > t_bar)[0]

In [74]:
sub_heads

array([ 1,  2,  3,  7, 11])

In [75]:
sub_tails

array([ 0,  1,  4,  5,  7,  8,  9, 17, 20, 21, 22, 23, 27, 29])

In [76]:
subjects = []
for sub_head in sub_heads:
    sub_tail = sub_tails[sub_tails >= sub_head]
    if len(sub_tail) > 0:
        sub_tail = sub_tail[0]
        subject = tokens[sub_head: sub_tail]
        subjects.append((subject, sub_head, sub_tail))

In [77]:
subjects

[([], 1, 1),
 ([',', 'north'], 2, 4),
 (['north'], 3, 4),
 ([], 7, 7),
 ([',', 'su', '##pp', '##lan', '##ted', 'a'], 11, 17)]

In [78]:
triple_list = []
# [subject_num, seq_len, bert_dim]
repeated_encoded_text = encoded_text.repeat(len(subjects), 1, 1)
repeated_encoded_text.shape

torch.Size([5, 30, 768])

In [79]:
sub_head_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
sub_tail_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()

In [80]:
sub_tail_mapping.shape

torch.Size([5, 1, 30])

In [81]:
subjects

[([], 1, 1),
 ([',', 'north'], 2, 4),
 (['north'], 3, 4),
 ([], 7, 7),
 ([',', 'su', '##pp', '##lan', '##ted', 'a'], 11, 17)]

In [82]:
for subject_idx, subject in enumerate(subjects):
    sub_head_mapping[subject_idx][0][subject[1]] = 1
    sub_tail_mapping[subject_idx][0][subject[2]] = 1

In [83]:
sub_tail_mapping

tensor([[[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [84]:
sub_tail_mapping = sub_tail_mapping.to(repeated_encoded_text)
sub_head_mapping = sub_head_mapping.to(repeated_encoded_text)

In [85]:
sub_tail_mapping

tensor([[[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
       device='cuda:0')

#### 4.5.2 预测尾实体和关系

In [86]:
pred_obj_heads, pred_obj_tails = model.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping, repeated_encoded_text)

In [87]:
pred_obj_heads.shape

torch.Size([5, 30, 25])

In [88]:
subs = [[subject_idx, subject] for subject_idx, subject in enumerate(subjects)]
subs

[[0, ([], 1, 1)],
 [1, ([',', 'north'], 2, 4)],
 [2, (['north'], 3, 4)],
 [3, ([], 7, 7)],
 [4, ([',', 'su', '##pp', '##lan', '##ted', 'a'], 11, 17)]]

In [89]:
subject = subs[0][1]
subject

([], 1, 1)

In [90]:
sub = subject[0]
sub = ''.join([i.lstrip("##") for i in sub])
sub

''

In [91]:
sub = ' '.join(sub.split('[unused1]'))
sub

''

In [92]:
pred_obj_heads[0].shape

torch.Size([30, 25])

In [93]:
obj_heads, obj_tails = np.where(pred_obj_heads.cpu()[subject_idx] > h_bar), np.where(pred_obj_tails.cpu()[subject_idx] > t_bar)

In [94]:
obj_heads[0].shape

(309,)

In [95]:
[[obj_head, rel_head] for obj_head, rel_head in zip(*obj_heads)]

[[0, 1],
 [0, 4],
 [0, 5],
 [0, 6],
 [0, 7],
 [0, 9],
 [0, 10],
 [0, 14],
 [0, 17],
 [0, 18],
 [0, 19],
 [1, 4],
 [1, 5],
 [1, 6],
 [1, 9],
 [1, 10],
 [1, 14],
 [1, 17],
 [1, 18],
 [1, 19],
 [1, 22],
 [2, 3],
 [2, 4],
 [2, 5],
 [2, 6],
 [2, 9],
 [2, 10],
 [2, 14],
 [2, 16],
 [2, 18],
 [2, 19],
 [2, 22],
 [3, 1],
 [3, 4],
 [3, 5],
 [3, 6],
 [3, 7],
 [3, 9],
 [3, 10],
 [3, 14],
 [3, 17],
 [3, 22],
 [4, 1],
 [4, 3],
 [4, 4],
 [4, 5],
 [4, 6],
 [4, 7],
 [4, 10],
 [4, 14],
 [4, 17],
 [4, 18],
 [4, 22],
 [5, 4],
 [5, 5],
 [5, 6],
 [5, 9],
 [5, 10],
 [5, 14],
 [5, 17],
 [5, 18],
 [5, 19],
 [5, 22],
 [6, 4],
 [6, 5],
 [6, 6],
 [6, 9],
 [6, 10],
 [6, 14],
 [6, 17],
 [6, 18],
 [6, 22],
 [7, 4],
 [7, 5],
 [7, 6],
 [7, 10],
 [7, 14],
 [7, 15],
 [7, 22],
 [8, 1],
 [8, 3],
 [8, 5],
 [8, 6],
 [8, 9],
 [8, 10],
 [8, 11],
 [8, 14],
 [8, 16],
 [8, 17],
 [8, 18],
 [8, 19],
 [8, 22],
 [9, 1],
 [9, 3],
 [9, 4],
 [9, 5],
 [9, 6],
 [9, 10],
 [9, 14],
 [9, 16],
 [9, 17],
 [9, 18],
 [9, 22],
 [10, 1],
 [10, 3]

In [96]:
[[obj_tail, rel_tail] for obj_tail, rel_tail in zip(*obj_tails)]

[[0, 0],
 [0, 1],
 [0, 2],
 [0, 3],
 [0, 4],
 [0, 6],
 [0, 7],
 [0, 12],
 [0, 13],
 [0, 14],
 [0, 15],
 [0, 16],
 [0, 18],
 [0, 20],
 [0, 21],
 [0, 23],
 [0, 24],
 [1, 0],
 [1, 1],
 [1, 2],
 [1, 3],
 [1, 4],
 [1, 6],
 [1, 7],
 [1, 12],
 [1, 13],
 [1, 14],
 [1, 15],
 [1, 18],
 [1, 20],
 [1, 21],
 [1, 23],
 [1, 24],
 [2, 1],
 [2, 3],
 [2, 4],
 [2, 5],
 [2, 6],
 [2, 7],
 [2, 9],
 [2, 11],
 [2, 12],
 [2, 13],
 [2, 14],
 [2, 15],
 [2, 18],
 [2, 21],
 [2, 23],
 [2, 24],
 [3, 0],
 [3, 1],
 [3, 2],
 [3, 3],
 [3, 4],
 [3, 5],
 [3, 6],
 [3, 7],
 [3, 9],
 [3, 10],
 [3, 12],
 [3, 13],
 [3, 14],
 [3, 15],
 [3, 16],
 [3, 18],
 [3, 20],
 [3, 21],
 [3, 23],
 [3, 24],
 [4, 1],
 [4, 2],
 [4, 3],
 [4, 4],
 [4, 6],
 [4, 7],
 [4, 9],
 [4, 12],
 [4, 13],
 [4, 14],
 [4, 15],
 [4, 16],
 [4, 18],
 [4, 20],
 [4, 21],
 [4, 23],
 [4, 24],
 [5, 1],
 [5, 2],
 [5, 3],
 [5, 4],
 [5, 5],
 [5, 7],
 [5, 9],
 [5, 12],
 [5, 13],
 [5, 14],
 [5, 15],
 [5, 18],
 [5, 20],
 [5, 21],
 [5, 23],
 [5, 24],
 [6, 1],
 [6, 2],
 [6, 3

In [97]:
for obj_head, rel_head in zip(*obj_heads):
    for obj_tail, rel_tail in zip(*obj_tails):
        if obj_head <= obj_tail and rel_head == rel_tail:
            rel = id2rel[str(int(rel_head))]
            obj = tokens[obj_head: obj_tail]
            obj = ''.join([i.lstrip("##") for i in obj])
            obj = ' '.join(obj.split('[unused1]'))
            triple_list.append((sub, rel, obj))
            break

In [98]:
triple_list

[('', '/business/company/founders', ''),
 ('', '/business/company_shareholder/major_shareholder_of', ''),
 ('', '/people/ethnicity/people', 'inqueens'),
 ('', '/location/neighborhood/neighborhood_of', ''),
 ('', '/sports/sports_team/location', ''),
 ('', '/business/company/industry', 'inqueens'),
 ('', '/business/company/place_founded', 'inqueens,'),
 ('', '/people/person/religion', ''),
 ('', '/people/person/place_lived', ''),
 ('',
  '/business/company/major_shareholders',
  'inqueens,northshoretowers,near'),
 ('', '/business/company_shareholder/major_shareholder_of', ''),
 ('', '/people/ethnicity/people', 'queens'),
 ('', '/location/neighborhood/neighborhood_of', ''),
 ('', '/business/company/industry', 'queens'),
 ('', '/business/company/place_founded', 'queens,'),
 ('', '/people/person/religion', ''),
 ('', '/people/person/place_lived', ''),
 ('', '/business/company/major_shareholders', 'queens,northshoretowers,near'),
 ('',
  '/location/location/contains',
  'queens,northshoretow

####  4.5.3 计算评价指标

In [99]:
triple_set = set()
for s, r, o in triple_list:
    triple_set.add((s, r, o))
pred_list = list(triple_set)
pred_list

[('', '/business/company/industry', 'queens'),
 ('', '/people/ethnicity/people', 'queens'),
 ('', '/location/neighborhood/neighborhood_of', 'quarry'),
 ('', '/business/company/place_founded', 'towers,nearthenassauborder,'),
 ('', '/people/person/ethnicity', 'gravel'),
 ('', '/location/neighborhood/neighborhood_of', 'housing'),
 ('', '/business/company/industry', 'the'),
 ('',
  '/location/location/contains',
  'northshoretowers,nearthenassauborder,supp'),
 ('', '/sports/sports_team/location', ''),
 ('', '/people/ethnicity/people', 'housingreplacedagravelquarryindouglas'),
 ('', '/business/company/place_founded', 'nassauborder,'),
 ('', '/sports/sports_team/location', 'ton'),
 ('', '/business/company_shareholder/major_shareholder_of', ''),
 ('', '/people/person/religion', 'course'),
 ('', '/people/ethnicity/people', 'quarryindouglas'),
 ('', '/people/ethnicity/people', ',nearthe'),
 ('', '/location/location/contains', ',nearthenassauborder,supp'),
 ('', '/location/location/contains', 't

In [100]:
data['triples']

[[['Douglaston', '/location/neighborhood/neighborhood_of', 'Queens'],
  ['Queens', '/location/location/contains', 'Douglaston']]]

In [101]:
pred_triples = set(pred_list)
gold_triples = set(to_tup(data['triples'][0]))

correct_num += len(pred_triples & gold_triples)
predict_num += len(pred_triples)
gold_num += len(gold_triples)

In [102]:
correct_num

0

In [103]:
precision = correct_num / (predict_num + 1e-10)
recall = correct_num / (gold_num + 1e-10)
f1_score = 2 * precision * recall / (precision + recall + 1e-10)
f1_score

0.0

#### 4.5.4  定义测试函数，方便调用

In [104]:
def test(test_data_loader, model, output=False, h_bar=0.5, t_bar=0.5):

    if output:
        # check the result dir
        if not os.path.exists(con.result_dir):
            os.mkdir(con.result_dir)

        path = os.path.join(con.result_dir, con.result_save_name)

        fw = open(path, 'w')

    orders = ['subject', 'relation', 'object']

    def to_tup(triple_list):
        ret = []
        for triple in triple_list:
            ret.append(tuple(triple))
        return ret

    test_data_prefetcher = data_loader.DataPreFetcher(test_data_loader)
    data = test_data_prefetcher.next()
    id2rel = json.load(open(os.path.join(con.data_path, 'rel2id.json')))[0]
    correct_num, predict_num, gold_num = 0, 0, 0

    while data is not None:
        with torch.no_grad():
            token_ids = data['token_ids']
            tokens = data['tokens'][0]
            mask = data['mask']
            encoded_text = model.get_encoded_text(token_ids, mask)
            pred_sub_heads, pred_sub_tails = model.get_subs(encoded_text)
            sub_heads, sub_tails = np.where(pred_sub_heads.cpu()[0] > h_bar)[0], np.where(pred_sub_tails.cpu()[0] > t_bar)[0]
            subjects = []
            for sub_head in sub_heads:
                sub_tail = sub_tails[sub_tails >= sub_head]
                if len(sub_tail) > 0:
                    sub_tail = sub_tail[0]
                    subject = tokens[sub_head: sub_tail]
                    subjects.append((subject, sub_head, sub_tail))
            if subjects:
                triple_list = []
                # [subject_num, seq_len, bert_dim]
                repeated_encoded_text = encoded_text.repeat(len(subjects), 1, 1)
                # [subject_num, 1, seq_len]
                sub_head_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
                sub_tail_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
                for subject_idx, subject in enumerate(subjects):
                    sub_head_mapping[subject_idx][0][subject[1]] = 1
                    sub_tail_mapping[subject_idx][0][subject[2]] = 1
                sub_tail_mapping = sub_tail_mapping.to(repeated_encoded_text)
                sub_head_mapping = sub_head_mapping.to(repeated_encoded_text)
                pred_obj_heads, pred_obj_tails = model.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping, repeated_encoded_text)
                for subject_idx, subject in enumerate(subjects):
                    sub = subject[0]
                    sub = ''.join([i.lstrip("##") for i in sub])
                    sub = ' '.join(sub.split('[unused1]'))
                    obj_heads, obj_tails = np.where(pred_obj_heads.cpu()[subject_idx] > h_bar), np.where(pred_obj_tails.cpu()[subject_idx] > t_bar)
                    for obj_head, rel_head in zip(*obj_heads):
                        for obj_tail, rel_tail in zip(*obj_tails):
                            if obj_head <= obj_tail and rel_head == rel_tail:
                                rel = id2rel[str(int(rel_head))]
                                obj = tokens[obj_head: obj_tail]
                                obj = ''.join([i.lstrip("##") for i in obj])
                                obj = ' '.join(obj.split('[unused1]'))
                                triple_list.append((sub, rel, obj))
                                break
                triple_set = set()
                for s, r, o in triple_list:
                    triple_set.add((s, r, o))
                pred_list = list(triple_set)
            else:
                pred_list = []
            pred_triples = set(pred_list)
            gold_triples = set(to_tup(data['triples'][0]))

            correct_num += len(pred_triples & gold_triples)
            predict_num += len(pred_triples)
            gold_num += len(gold_triples)

            if output:
                result = json.dumps({
                    # 'text': ' '.join(tokens),
                    'triple_list_gold': [
                        dict(zip(orders, triple)) for triple in gold_triples
                    ],
                    'triple_list_pred': [
                        dict(zip(orders, triple)) for triple in pred_triples
                    ],
                    'new': [
                        dict(zip(orders, triple)) for triple in pred_triples - gold_triples
                    ],
                    'lack': [
                        dict(zip(orders, triple)) for triple in gold_triples - pred_triples
                    ]
                }, ensure_ascii=False)
                fw.write(result + '\n')

            data = test_data_prefetcher.next()

    print("correct_num: {:3d}, predict_num: {:3d}, gold_num: {:3d}".format(correct_num, predict_num, gold_num))

    precision = correct_num / (predict_num + 1e-10)
    recall = correct_num / (gold_num + 1e-10)
    f1_score = 2 * precision * recall / (precision + recall + 1e-10)
    return precision, recall, f1_score

# 5 代码梳理及细节回顾(在VScode中演示)

　　在VScode环境中的训练文件里再回顾训练流程。

# 6 作业
  
`【思考题】`思考这篇文章的模型有什么不足，有什么可以改进的地方。

`【代码实践】`复现该文章代码的模型（CASREL）部分。

`【画图】`不看文章原图，按照自己的理解画出模型的结构图。

`【总结】`对这篇文章进行回顾，思考并学习文章写作总体结构，实验设计等内容。

---