# Joint Extraction of Entities and Relations Based on a Novel Tagging Scheme ACL 2017

# 1 Introduction

### 1.1 Course Review

<img src='imgs/overall.png' width="800" height="800" align="bottom">

### 1.2 Model Architecture

##### 1.2.1 Tagging Scheme

<img src='imgs/over_all_1.png' width="800" height="800" align="bottom">

##### 1.2.2 Model

<img src='imgs/over_all_2.png' width="800" height="800" align="bottom">

### 1.3 Code Structure
<img src="./imgs/dir.png"  width="200" height="200" align="bottom" />

# 2 Preparation
### 2.1 Environment Setup

* Python3.8
* jupyter notebook
* torch            1.6.0+cu10.2
* numpy            1.18.5

We recommend running the code in Visual Studio Code (VScode).

### 2.2 Dataset Download
* Dataset download link:<br>
 https://pan.baidu.com/s/12maQjrRjv52dPcTA4dRtyw <br>
 Extraction code: bv71

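# 3 Project Code Structure (demonstrated in VScode)

>1) What is it?

We first run the code in VScode to get an intuitive feel for how the project trains, and show the output of forward inference so you can see what the model produces.

>2) How is it organized?

Next we introduce how the project code is organized: which folders it has, which files they contain, and which functional modules those files form, namely data preprocessing, model design, loss function, and inference and evaluation.

>3) Summary

Finally, we walk through the training launch flow in the main file once more.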

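# 4 Algorithm Modules and Details (demonstrated in Jupyter and VScode)

We go through each module in detail in Jupyter Notebook.

With the goal of implementing each module's functionality, we walk through the execution flow of every function and show the intermediate data, to make the code easier to follow.

The material is split into the following modules: **hyperparameter settings, data loading and preprocessing, model definition, model training, and model evaluation**.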

### 4.1 Hyperparameter Settings

In [4]:
import argparse
from utils import *
from torch.utils.data.dataset import *
from torch.utils.data.sampler import *
from torch.nn.utils.rnn import *
import bisect
from model import *
import torch
import os
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time

torch.Size([47463, 300])


In [14]:
parser = argparse.ArgumentParser(description="Joint Extraction of Entities and Relations")
parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                    help='batch size (default: 32)')
parser.add_argument('--cuda', action='store_false',
                    help='use CUDA (default: True); passing --cuda turns it off')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='dropout applied to layers (default: 0.5)')
parser.add_argument('--emb_dropout', type=float, default=0.25,
                    help='dropout applied to the embedded layer (default: 0.25)')
parser.add_argument('--clip', type=float, default=0.35,
                    help='gradient clip, -1 means no clip (default: 0.35)')
parser.add_argument('--epochs', type=int, default=30,
                    help='upper epoch limit (default: 30)')
parser.add_argument('--char_kernel_size', type=int, default=3,
                    help='character-level kernel size (default: 3)')
parser.add_argument('--word_kernel_size', type=int, default=3,
                    help='word-level kernel size (default: 3)')
parser.add_argument('--emsize', type=int, default=50,
                    help='size of character embeddings (default: 50)')
parser.add_argument('--char_layers', type=int, default=3,
                    help='# of character-level convolution layers (default: 3)')
parser.add_argument('--word_layers', type=int, default=3,
                    help='# of word-level convolution layers (default: 3)')
parser.add_argument('--char_nhid', type=int, default=50,
                    help='number of hidden units per character-level convolution layer (default: 50)')
parser.add_argument('--word_nhid', type=int, default=300,
                    help='number of hidden units per word-level convolution layer (default: 300)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='report interval (default: 100)')
parser.add_argument('--lr', type=float, default=4,
                    help='initial learning rate (default: 4)')
parser.add_argument('--optim', type=str, default='SGD',
                    help='optimizer type (default: SGD)')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed (default: 1111)')
parser.add_argument('--save', type=str, default='model.pt',
                    help='path to save the final model')
parser.add_argument('--weight', type=float, default=10.0,
                    help='manual rescaling weight given to each tag except "O"')



In [19]:
args = parser.parse_args(args=[])

In [20]:
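torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run without --cuda (the flag disables CUDA)")

print(args)
# fall back to the CPU when no GPU is available, even if args.cuda is True
device = torch.device("cuda" if args.cuda and torch.cuda.is_available() else "cpu")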


Namespace(batch_size=32, char_kernel_size=3, char_layers=3, char_nhid=50, clip=0.35, cuda=True, dropout=0.5, emb_dropout=0.25, emsize=50, epochs=30, log_interval=100, lr=4, optim='SGD', save='model.pt', seed=1111, weight=10.0, word_kernel_size=3, word_layers=3, word_nhid=300)


### 4.2 Data Preprocessing

In [198]:
import json

In [205]:
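import string  # used below; Index comes from `from utils import *` (its definition is shown in the next cell)

class Charset(Index):
    def __init__(self):
        super().__init__()
        for char in string.printable[0:-6]:  # digits, ASCII letters and punctuation; drops the 6 trailing whitespace characters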
            self.add(char)
        self.add("<pad>")
        self.add("<unk>")

    @staticmethod
    def type(char):
        if char in string.digits:
            return "Digits"
        if char in string.ascii_lowercase:
            return "Lower Case"
        if char in string.ascii_uppercase:
            return "Upper Case"
        if char in string.punctuation:
            return "Punctuation"
        return "Other"

    def __getitem__(self, key):
        if isinstance(key, str) and key not in self.key2idx:
            return self.key2idx["<unk>"]
        return super().__getitem__(key)

In [206]:
class Index(object):
    def __init__(self):
        self.key2idx = {}
        self.idx2key = []

    def add(self, key):
        if key not in self.key2idx:
            self.key2idx[key] = len(self.idx2key)
            self.idx2key.append(key)
        return self.key2idx[key]

    def __getitem__(self, key):
        if isinstance(key, str):
            return self.key2idx[key]
        if isinstance(key, int):
            return self.idx2key[key]

    def __len__(self):
        return len(self.idx2key)

    def save(self, f):
        with open(f, 'wt', encoding='utf-8') as fout:
            for index, key in enumerate(self.idx2key):
                fout.write(key + '\t' + str(index) + '\n')

    def load(self, f):
        with open(f, 'rt', encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                key = line.split()[0]
                self.add(key)

In [207]:
class Vocabulary(Index):
    def __init__(self):
        super().__init__()
        self.add("<pad>")
        self.add("<unk>")

    def __getitem__(self, key):
        if isinstance(key, str) and key not in self.key2idx:
            return self.key2idx["<unk>"]
        return super().__getitem__(key)

In [240]:
test = []

In [203]:
charset = Charset()
vocab = Vocabulary()
vocab.load("data/NYT_CoType/vocab.txt")
relation_labels = Index()
entity_labels = Index()
tag_set = Index()
tag_set.add("O")

0

In [231]:
MAX_SENT_LENGTH = 70
MAX_TOKEN_LENGTH = 20

In [193]:
fin = open('data/NYT_CoType/test.json', 'rt', encoding='utf-8')

In [194]:
line = fin.readline()
line

'{"sentId": 33, "articleId": "2", "relationMentions": [{"em1Text": "Tim Pawlenty", "em2Text": "Minnesota", "label": "/people/person/place_lived"}, {"em1Text": "Minnesota", "em2Text": "Tim Pawlenty", "label": "None"}], "entityMentions": [{"start": 0, "text": "Tim Pawlenty", "label": "PERSON"}, {"start": 1, "text": "Minnesota", "label": "LOCATION"}], "sentText": "Gov. Tim Pawlenty of Minnesota ordered the state health department this month to monitor day-to-day operations at the Minneapolis Veterans Home after state inspectors found that three men had died there in the previous month because of neglect or medical errors .\\r\\n"}\n'

In [195]:
line = line.strip()

In [196]:
# prepare_data_set
num_overlap = 0
overlap = False

In [199]:
sentence = json.loads(line)
sentence

{'sentId': 33,
 'articleId': '2',
 'relationMentions': [{'em1Text': 'Tim Pawlenty',
   'em2Text': 'Minnesota',
   'label': '/people/person/place_lived'},
  {'em1Text': 'Minnesota', 'em2Text': 'Tim Pawlenty', 'label': 'None'}],
 'entityMentions': [{'start': 0, 'text': 'Tim Pawlenty', 'label': 'PERSON'},
  {'start': 1, 'text': 'Minnesota', 'label': 'LOCATION'}],
 'sentText': 'Gov. Tim Pawlenty of Minnesota ordered the state health department this month to monitor day-to-day operations at the Minneapolis Veterans Home after state inspectors found that three men had died there in the previous month because of neglect or medical errors .\r\n'}

In [200]:
def make_tag_set(tag_set, relation_label):
    if relation_label == "None":
        return
    for pos in "BIES":
        for role in "12":
            tag_set.add("-".join([pos, relation_label, role]))#pos-relation_label-role
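To see where the model's 193 output classes come from (the `torch.Size([32, 10, 193])` shapes in 4.5), the self-contained sketch below rebuilds the tag strings for one relation with the same loop order as `make_tag_set`. The count of 24 non-`None` relation types is taken from the `relation_labels.txt` listing in 4.3; the function name `relation_tags_for` is illustrative, not part of the project.

```python
def relation_tags_for(relation_label):
    """Tag strings that make_tag_set would add for one relation label."""
    if relation_label == "None":
        return []  # the "None" relation gets no tags
    # Same loop order as make_tag_set: position outer, role inner.
    return ["-".join([pos, relation_label, role]) for pos in "BIES" for role in "12"]


tags = relation_tags_for("/people/person/place_lived")
print(len(tags))   # 8
print(tags[:3])    # ['B-/people/person/place_lived-1', 'B-/people/person/place_lived-2', 'I-/people/person/place_lived-1']

num_relations = 24             # non-"None" entries in relation_labels.txt
print(1 + 8 * num_relations)   # 193: "O" plus 8 tags per relation
```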

In [209]:
for relation_mention in sentence["relationMentions"]:
    relation_labels.add(relation_mention["label"])
    make_tag_set(tag_set, relation_mention["label"])

In [210]:
relation_labels.idx2key

['/people/person/place_lived', 'None']

In [211]:
for entity_mention in sentence["entityMentions"]:
    entity_labels.add(entity_mention["label"])

In [212]:
entity_labels.idx2key

['PERSON', 'LOCATION']

In [213]:
sentence_text = sentence["sentText"].strip().strip('"').split()
sentence_text

['Gov.',
 'Tim',
 'Pawlenty',
 'of',
 'Minnesota',
 'ordered',
 'the',
 'state',
 'health',
 'department',
 'this',
 'month',
 'to',
 'monitor',
 'day-to-day',
 'operations',
 'at',
 'the',
 'Minneapolis',
 'Veterans',
 'Home',
 'after',
 'state',
 'inspectors',
 'found',
 'that',
 'three',
 'men',
 'had',
 'died',
 'there',
 'in',
 'the',
 'previous',
 'month',
 'because',
 'of',
 'neglect',
 'or',
 'medical',
 'errors',
 '.']

In [215]:
length_sent = len(sentence_text)
length_sent

42

In [216]:
lower_sentence_text = [token.lower() for token in sentence_text]
lower_sentence_text

['gov.',
 'tim',
 'pawlenty',
 'of',
 'minnesota',
 'ordered',
 'the',
 'state',
 'health',
 'department',
 'this',
 'month',
 'to',
 'monitor',
 'day-to-day',
 'operations',
 'at',
 'the',
 'minneapolis',
 'veterans',
 'home',
 'after',
 'state',
 'inspectors',
 'found',
 'that',
 'three',
 'men',
 'had',
 'died',
 'there',
 'in',
 'the',
 'previous',
 'month',
 'because',
 'of',
 'neglect',
 'or',
 'medical',
 'errors',
 '.']

In [228]:
vocab.key2idx

{'<pad>': 0,
 '<unk>': 1,
 ',': 2,
 'the': 3,
 'and': 4,
 'of': 5,
 '.': 6,
 'in': 7,
 'a': 8,
 'to': 9,
 "''": 10,
 "'s": 11,
 'for': 12,
 'that': 13,
 'on': 14,
 'at': 15,
 'with': 16,
 'is': 17,
 'by': 18,
 'from': 19,
 'new': 20,
 'was': 21,
 'as': 22,
 'his': 23,
 'he': 24,
 'who': 25,
 'said': 26,
 ';': 27,
 '-rrb-': 28,
 '-lrb-': 29,
 'an': 30,
 '--': 31,
 'has': 32,
 'it': 33,
 'mr.': 34,
 'york': 35,
 'have': 36,
 'had': 37,
 'be': 38,
 'united': 39,
 'but': 40,
 ':': 41,
 'are': 42,
 'its': 43,
 'not': 44,
 'which': 45,
 'states': 46,
 'about': 47,
 'this': 48,
 'her': 49,
 'one': 50,
 'will': 51,
 'after': 52,
 'their': 53,
 'president': 54,
 'when': 55,
 'like': 56,
 'last': 57,
 'two': 58,
 'were': 59,
 'would': 60,
 'they': 61,
 'city': 62,
 'or': 63,
 'been': 64,
 'more': 65,
 '$': 66,
 'years': 67,
 'first': 68,
 'other': 69,
 'she': 70,
 'up': 71,
 'also': 72,
 'where': 73,
 'iraq': 74,
 'university': 75,
 'than': 76,
 'john': 77,
 'i': 78,
 'former': 79,
 'year': 80,


In [229]:
def prepare_sequence(seq, to_idx):
    return [to_idx[key] for key in seq]

In [217]:
sentence_idx = prepare_sequence(lower_sentence_text, vocab)
sentence_idx

[670,
 1158,
 40526,
 5,
 1018,
 2037,
 3,
 81,
 626,
 395,
 48,
 228,
 9,
 6048,
 10868,
 801,
 15,
 3,
 2131,
 2866,
 113,
 52,
 81,
 5755,
 346,
 13,
 97,
 427,
 37,
 300,
 88,
 7,
 3,
 1713,
 228,
 126,
 5,
 15086,
 63,
 680,
 7733,
 6]

In [232]:
tokens_idx = []
for token in sentence_text:
    if len(token) <= MAX_TOKEN_LENGTH:
        tokens_idx.append(prepare_sequence(token, charset) + [charset["<pad>"]]*(MAX_TOKEN_LENGTH-len(token)))
    else:
        # keep the first 13 and the last 7 characters, so the token is exactly MAX_TOKEN_LENGTH (20) long
        tokens_idx.append(prepare_sequence(token[0:13] + token[-7:], charset))
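The 94s that fill the character rows in the outputs further down are the `<pad>` id. A self-contained check, assuming `Charset` is built exactly as above (`string.printable[0:-6]`, then `<pad>` and `<unk>`); `char_ids` and `encode_token` are illustrative names, not project functions:

```python
import string

MAX_TOKEN_LENGTH = 20

# Same inventory as Charset: digits (0-9), lowercase (10-35), uppercase (36-61),
# punctuation (62-93); the 6 trailing whitespace characters of string.printable are dropped.
char_ids = {c: i for i, c in enumerate(string.printable[0:-6])}
char_ids["<pad>"] = len(char_ids)  # 94
char_ids["<unk>"] = len(char_ids)  # 95


def encode_token(token):
    """Character ids for one token, mirroring the tokens_idx loop above."""
    if len(token) > MAX_TOKEN_LENGTH:
        token = token[0:13] + token[-7:]  # first 13 + last 7 characters
    ids = [char_ids.get(c, char_ids["<unk>"]) for c in token]
    return ids + [char_ids["<pad>"]] * (MAX_TOKEN_LENGTH - len(ids))


print(encode_token("Gov.")[:6])  # [42, 24, 31, 75, 94, 94]
print(encode_token("Tim")[:4])   # [55, 18, 22, 94]
```

These match the first character rows of `test_data[0]` shown in 4.3.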

In [234]:
tags_idx = [tag_set["O"]] * length_sent
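This demo fills `tags_idx` with "O" only, although the sentence carries a `/people/person/place_lived` relation. Below is a hedged sketch of how tags could be derived from `relationMentions` under the paper's scheme, not the project's actual preprocessing. It assumes each entity is located by its first exact token match (the `start` field in `entityMentions` appears to be an entity index rather than a token offset) and lets a later relation overwrite an earlier one on shared tokens, which is exactly where this tagging scheme struggles with overlapping triplets. `find_span`, `span_tags` and `relation_tags` are hypothetical helpers.

```python
def find_span(tokens, entity_text):
    """(start, end) of the first exact token match of entity_text, end exclusive; None if absent."""
    words = entity_text.split()
    n = len(words)
    if n == 0:
        return None
    for i in range(len(tokens) - n + 1):
        if tokens[i:i + n] == words:
            return i, i + n
    return None


def span_tags(start, end, relation_label, role):
    """BIES tags for one entity span, in make_tag_set's pos-relation_label-role format."""
    if end - start == 1:
        return ["-".join(["S", relation_label, role])]
    middle = ["-".join(["I", relation_label, role])] * (end - start - 2)
    return ["-".join(["B", relation_label, role])] + middle + ["-".join(["E", relation_label, role])]


def relation_tags(tokens, relation_mentions):
    """Tag strings for one sentence; tokens outside every relation stay "O"."""
    tags = ["O"] * len(tokens)
    for mention in relation_mentions:
        if mention["label"] == "None":
            continue  # "None" relations get no tags, as in make_tag_set
        for text, role in ((mention["em1Text"], "1"), (mention["em2Text"], "2")):
            span = find_span(tokens, text)
            if span is not None:
                start, end = span
                # a later mention overwrites earlier tags on the same tokens
                tags[start:end] = span_tags(start, end, mention["label"], role)
    return tags


tokens = ["Gov.", "Tim", "Pawlenty", "of", "Minnesota", "ordered"]
mentions = [
    {"em1Text": "Tim Pawlenty", "em2Text": "Minnesota", "label": "/people/person/place_lived"},
    {"em1Text": "Minnesota", "em2Text": "Tim Pawlenty", "label": "None"},
]
print(relation_tags(tokens, mentions))
# ['O', 'B-/people/person/place_lived-1', 'E-/people/person/place_lived-1', 'O', 'S-/people/person/place_lived-2', 'O']
```

Mapping the result through `tag_set` (`[tag_set[t] for t in tags]`) would give ids in the same format as `tags_idx`.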

In [241]:
test.append((sentence_idx, tokens_idx, tags_idx))

In [239]:
import pickle

def save(obj, path):
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

In [242]:
# after looping over every line of the file, save the processed data
save(test, 'data/NYT_CoType/test.pk')
relation_labels.save('data/NYT_CoType/relation_labels.txt')
entity_labels.save('data/NYT_CoType/entity_labels.txt')
tag_set.save("data/NYT_CoType/tag2id.txt")

### 4.3 Data Loading

In [21]:
charset = Charset()

In [22]:
vocab = Vocabulary()

In [23]:
vocab.load("data/NYT_CoType/vocab.txt")

In [33]:
vocab.key2idx

{'<pad>': 0,
 '<unk>': 1,
 ',': 2,
 'the': 3,
 'and': 4,
 'of': 5,
 '.': 6,
 'in': 7,
 'a': 8,
 'to': 9,
 "''": 10,
 "'s": 11,
 'for': 12,
 'that': 13,
 'on': 14,
 'at': 15,
 'with': 16,
 'is': 17,
 'by': 18,
 'from': 19,
 'new': 20,
 'was': 21,
 'as': 22,
 'his': 23,
 'he': 24,
 'who': 25,
 'said': 26,
 ';': 27,
 '-rrb-': 28,
 '-lrb-': 29,
 'an': 30,
 '--': 31,
 'has': 32,
 'it': 33,
 'mr.': 34,
 'york': 35,
 'have': 36,
 'had': 37,
 'be': 38,
 'united': 39,
 'but': 40,
 ':': 41,
 'are': 42,
 'its': 43,
 'not': 44,
 'which': 45,
 'states': 46,
 'about': 47,
 'this': 48,
 'her': 49,
 'one': 50,
 'will': 51,
 'after': 52,
 'their': 53,
 'president': 54,
 'when': 55,
 'like': 56,
 'last': 57,
 'two': 58,
 'were': 59,
 'would': 60,
 'they': 61,
 'city': 62,
 'or': 63,
 'been': 64,
 'more': 65,
 '$': 66,
 'years': 67,
 'first': 68,
 'other': 69,
 'she': 70,
 'up': 71,
 'also': 72,
 'where': 73,
 'iraq': 74,
 'university': 75,
 'than': 76,
 'john': 77,
 'i': 78,
 'former': 79,
 'year': 80,


In [24]:
tag_set = Index()   # tag set: "O" plus the BIES-relation-role tags (not entity types)

In [25]:
tag_set.load("data/NYT_CoType/tag2id.txt")

In [26]:
relation_labels = Index()      # relation types

In [27]:
relation_labels.load('data/NYT_CoType/relation_labels.txt')

In [35]:
relation_labels.key2idx

{'/people/person/nationality': 0,
 '/location/country/capital': 1,
 '/location/location/contains': 2,
 '/people/deceased_person/place_of_death': 3,
 '/people/person/children': 4,
 '/people/person/place_of_birth': 5,
 '/people/person/place_lived': 6,
 '/location/country/administrative_divisions': 7,
 '/location/administrative_division/country': 8,
 '/business/person/company': 9,
 '/location/neighborhood/neighborhood_of': 10,
 '/business/company/place_founded': 11,
 '/business/company/founders': 12,
 '/sports/sports_team_location/teams': 13,
 '/sports/sports_team/location': 14,
 '/business/company_shareholder/major_shareholder_of': 15,
 '/business/company/major_shareholders': 16,
 '/people/ethnicity/people': 17,
 '/people/person/ethnicity': 18,
 '/business/company/advisors': 19,
 '/people/person/religion': 20,
 '/people/ethnicity/geographic_distribution': 21,
 '/people/person/profession': 22,
 '/business/company/industry': 23,
 'None': 24}

In [34]:
tag_set.key2idx

{'O': 0,
 'B-/people/person/nationality-1': 1,
 'B-/people/person/nationality-2': 2,
 'I-/people/person/nationality-1': 3,
 'I-/people/person/nationality-2': 4,
 'E-/people/person/nationality-1': 5,
 'E-/people/person/nationality-2': 6,
 'S-/people/person/nationality-1': 7,
 'S-/people/person/nationality-2': 8,
 'B-/location/country/capital-1': 9,
 'B-/location/country/capital-2': 10,
 'I-/location/country/capital-1': 11,
 'I-/location/country/capital-2': 12,
 'E-/location/country/capital-1': 13,
 'E-/location/country/capital-2': 14,
 'S-/location/country/capital-1': 15,
 'S-/location/country/capital-2': 16,
 'B-/location/location/contains-1': 17,
 'B-/location/location/contains-2': 18,
 'I-/location/location/contains-1': 19,
 'I-/location/location/contains-2': 20,
 'E-/location/location/contains-1': 21,
 'E-/location/location/contains-2': 22,
 'S-/location/location/contains-1': 23,
 'S-/location/location/contains-2': 24,
 'B-/people/deceased_person/place_of_death-1': 25,
 'B-/people/d

In [36]:
train_data = load('data/NYT_CoType/train.pk')  # load() comes from utils; presumably the pickle counterpart of save()

In [37]:
test_data = load('data/NYT_CoType/test.pk')

In [40]:
test_data[0]

([670,
  1158,
  40526,
  5,
  1018,
  2037,
  3,
  81,
  626,
  395,
  48,
  228,
  9,
  6048,
  10868,
  801,
  15,
  3,
  2131,
  2866,
  113,
  52,
  81,
  5755,
  346,
  13,
  97,
  427,
  37,
  300,
  88,
  7,
  3,
  1713,
  228,
  126,
  5,
  15086,
  63,
  680,
  7733,
  6],
 [[42,
   24,
   31,
   75,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [55,
   18,
   22,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [51,
   10,
   32,
   21,
   14,
   23,
   29,
   34,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [24,
   15,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [48,
   18,
   23,
   23,
   14,
   28,
   24,
   29,
   10,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [24,
   2

In [41]:
val_size = int(0.01 * len(train_data))
train_data, val_data = random_split(train_data, [len(train_data)-val_size, val_size])

In [46]:
val_data[0][0]

[34,
 44107,
 21,
 266,
 84,
 30,
 1513,
 187,
 7,
 3,
 538,
 4,
 21,
 868,
 7,
 19121,
 2,
 14,
 115,
 123,
 6]

In [47]:
def group(data, breakpoints):
    groups = [[] for _ in range(len(breakpoints)+1)]
    for idx, item in enumerate(data):
        i = bisect.bisect_left(breakpoints, len(item[0]))
        groups[i].append(idx)
    data_groups = [Subset(data, g) for g in groups]
    return data_groups
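`bisect_left` puts a length equal to a breakpoint into the lower group, so with the breakpoints used below, group 0 holds sentences of length up to 10, group 1 lengths 11 to 20, and so on, with group 6 taking everything longer than 60. A minimal check of that mapping (`bucket_of` is an illustrative name):

```python
import bisect

BREAKPOINTS = [10, 20, 30, 40, 50, 60]  # the breakpoints passed to group() below


def bucket_of(length, breakpoints=BREAKPOINTS):
    """Group index that group() assigns to a sentence of this length."""
    return bisect.bisect_left(breakpoints, length)


for n in [4, 10, 11, 42, 60, 61, 70]:
    print(n, "->", bucket_of(n))
# 4 -> 0, 10 -> 0, 11 -> 1, 42 -> 4, 60 -> 5, 61 -> 6, 70 -> 6
```

This is why the first training batch in 4.5 has a maximum length of 10.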

In [48]:
train_data_groups = group(train_data, [10, 20, 30, 40, 50, 60])
val_data_groups = group(val_data, [10, 20, 30, 40, 50, 60])
test_data_groups = group(test_data, [10, 20, 30, 40, 50, 60])

In [49]:
word_embeddings = torch.tensor(np.load("data/NYT_CoType/word2vec.vectors.npy"))
word_embedding_size = word_embeddings.size(1)
pad_embedding = torch.empty(1, word_embedding_size).uniform_(-0.5, 0.5)
unk_embedding = torch.empty(1, word_embedding_size).uniform_(-0.5, 0.5)

In [50]:
word_embeddings = torch.cat([pad_embedding, unk_embedding, word_embeddings])

In [51]:
word_embeddings.shape

torch.Size([47465, 300])
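Prepending the two random rows keeps the embedding matrix aligned with `Vocabulary`, which reserves ids 0 and 1 for `<pad>` and `<unk>`; that is why 47463 rows become 47465. Assuming the rows of `word2vec.vectors.npy` follow the line order of `vocab.txt`, the word on line j (0-based) gets id j + 2 and row j + 2. A small NumPy sketch of the same layout, with a toy 3-word matrix standing in for the real file:

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 4
pretrained = rng.normal(size=(3, dim)).astype(np.float32)  # toy stand-in for word2vec.vectors.npy
pad_row = rng.uniform(-0.5, 0.5, size=(1, dim)).astype(np.float32)
unk_row = rng.uniform(-0.5, 0.5, size=(1, dim)).astype(np.float32)

# Same layout as torch.cat([pad_embedding, unk_embedding, word_embeddings]).
embeddings = np.concatenate([pad_row, unk_row, pretrained], axis=0)
print(embeddings.shape)                                # (5, 4)
print(np.array_equal(embeddings[2 + 1], pretrained[1]))  # True: pretrained row j sits at row j + 2
```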

### 4.4 Model Definition

In [52]:
char_channels = [args.emsize] + [args.char_nhid] * args.char_layers
word_channels = [word_embedding_size + args.char_nhid] + [args.word_nhid] * args.word_layers

In [53]:
char_channels

[50, 50, 50, 50]

In [54]:
word_channels

[350, 300, 300, 300]

In [55]:
if os.path.exists(args.save):
    # resume from a previously saved model object; map_location lets a GPU-saved model load on the CPU
    model = torch.load(args.save, map_location=device)
else:
    model = Model(charset_size=len(charset), char_embedding_size=args.emsize, char_channels=char_channels,
                  char_padding_idx=charset["<pad>"], char_kernel_size=args.char_kernel_size, weight=word_embeddings,
                  word_embedding_size=word_embedding_size, word_channels=word_channels,
                  word_kernel_size=args.word_kernel_size, num_tag=len(tag_set), dropout=args.dropout,
                  emb_dropout=args.emb_dropout).to(device)    # Model is defined in model.py (walked through in VScode)

In [57]:
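# Note: this gives every tag, including "O", the weight args.weight; the --weight help text
# ("each tag except O") suggests setting weight[tag_set["O"]] = 1.0 after this line.
weight = [args.weight] * len(tag_set)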
weight

[10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,
 10.0,

In [58]:
weight = torch.tensor(weight).to(device)

In [59]:
# reduction='sum' is the current spelling of the deprecated size_average=False
criterion = nn.NLLLoss(weight=weight, reduction='sum')
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)
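`NLLLoss` expects log-probabilities (so the model presumably ends in a `log_softmax`); with per-class `weight` and a summed reduction, each token contributes `-weight[target] * log_prob[target]`. A pure-Python sketch of that formula (`weighted_nll_sum` is an illustrative name). Because `[args.weight] * len(tag_set)` gives every tag, "O" included, the same weight 10, the weighted loss above is simply 10 times the unweighted one; the weighting only changes what the model learns once "O" gets a different weight:

```python
import math


def weighted_nll_sum(log_probs, targets, weight):
    """Sum over tokens of -weight[target] * log_prob[target] (NLLLoss with reduction='sum')."""
    return -sum(weight[t] * row[t] for row, t in zip(log_probs, targets))


# Two tokens, three tags (tag 0 = "O"); each row holds log-probabilities.
log_probs = [[math.log(0.7), math.log(0.2), math.log(0.1)],
             [math.log(0.5), math.log(0.4), math.log(0.1)]]
targets = [0, 1]

plain = weighted_nll_sum(log_probs, targets, [1.0, 1.0, 1.0])
uniform = weighted_nll_sum(log_probs, targets, [10.0, 10.0, 10.0])
o_down = weighted_nll_sum(log_probs, targets, [1.0, 10.0, 10.0])
print(round(uniform / plain, 6))  # 10.0: a uniform weight only rescales the loss
print(o_down > plain)             # True: errors on non-"O" tags now cost more
```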



### 4.5 Model Training

In [60]:
best_val_loss = None
lr = args.lr
all_val_loss = []
all_precision = []
all_recall = []
all_f1 = []

In [62]:
class GroupBatchRandomSampler(object):
    def __init__(self, data_groups, batch_size, drop_last):
        self.batch_indices = []
        for data_group in data_groups:
            self.batch_indices.extend(list(BatchSampler(SubsetRandomSampler(data_group.indices),
                                                        batch_size, drop_last=drop_last)))

    def __iter__(self):
        return (self.batch_indices[i] for i in torch.randperm(len(self.batch_indices)))

    def __len__(self):
        return len(self.batch_indices)
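`GroupBatchRandomSampler` cuts each length bucket into shuffled batches, so a batch never mixes buckets (which keeps padding small), and then shuffles the order of all batches in `__iter__`. Because the batches are built once in `__init__`, their membership only changes when a new sampler is created, as the `#train()` cell below does. The pure-Python analogue below uses `random` in place of torch's samplers; the name `group_batch_indices` and the `seed` argument are illustrative:

```python
import random


def group_batch_indices(groups, batch_size, drop_last=False, seed=None):
    """Pure-Python analogue of GroupBatchRandomSampler.

    groups: list of index lists, one per length bucket.
    Each bucket is shuffled and cut into batches, so a batch never mixes buckets;
    the resulting batches are then shuffled across buckets.
    """
    rng = random.Random(seed)
    batches = []
    for group in groups:
        indices = list(group)
        rng.shuffle(indices)                 # like SubsetRandomSampler
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            if drop_last and len(batch) < batch_size:
                continue                     # like BatchSampler(drop_last=True)
            batches.append(batch)
    rng.shuffle(batches)                     # like torch.randperm in __iter__
    return batches


groups = [[0, 1, 2, 3, 4], [10, 11, 12]]
for batch in group_batch_indices(groups, batch_size=2, seed=0):
    print(batch)
```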

In [63]:
#train()
model.train()
total_loss = 0
count = 0
sampler = GroupBatchRandomSampler(train_data_groups, args.batch_size, drop_last=False)

In [66]:
len(sampler.batch_indices)

7093

In [69]:
batch_indices = sampler.batch_indices[0]
batch_indices

[202991,
 52820,
 98562,
 225404,
 211013,
 130004,
 110339,
 128196,
 78333,
 186118,
 60559,
 74588,
 42842,
 163144,
 177977,
 67215,
 23235,
 52733,
 17306,
 29014,
 155352,
 9532,
 41117,
 10049,
 103004,
 143077,
 74849,
 135856,
 215603,
 132474,
 9404,
 132062]

##### Function get_batch

In [71]:
data = train_data
batch = [data[idx] for idx in batch_indices]

In [72]:
len(batch)

32

In [73]:
batch[0]

([1, 109, 3092, 2, 1006, 2, 283, 1, 6],
 [[1, 0, 8, 2, 5, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94],
  [40,
   10,
   28,
   29,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [37,
   24,
   30,
   21,
   14,
   31,
   10,
   27,
   13,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [73,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [38,
   21,
   14,
   31,
   14,
   21,
   10,
   23,
   13,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [73,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [50,
   17,
   18,
   24,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94,
   94],
  [4, 4, 1, 0, 6, 94, 94, 94, 94, 9

In [74]:
sorted_batch = sorted(batch, key=lambda x: len(x[0]), reverse=True)

In [75]:
sentences, tokens, tags = zip(*sorted_batch)

In [76]:
sentences

([33, 21, 3, 27051, 805, 2, 44, 27051, 4532, 6],
 [6759, 11, 68, 2342, 9, 780, 21, 8, 1413, 6],
 [57, 5, 3, 12754, 221, 5, 3625, 4, 946, 6],
 [33, 2446, 11057, 119, 291, 2, 207, 4, 91, 6],
 [2445, 4744, 2, 9799, 27413, 17794, 2, 27489, 2, 2154],
 [24, 21, 7684, 4, 672, 7, 1907, 2, 202, 6],
 [3035, 32, 3625, 20066, 4, 16193, 37, 2373, 7742, 6],
 [1, 109, 3092, 2, 1006, 2, 283, 1, 6],
 [2028, 3323, 47010, 2, 459, 19, 1, 211, 6],
 [4389, 41, 77, 6563, 2, 1609, 556, 7057, 6],
 [963, 795, 170, 2, 10371, 2, 2152, 4, 6985],
 [2146, 4, 1992, 26, 61, 1, 460, 69, 6],
 [3400, 216, 516, 126, 5, 3, 634, 6433, 6],
 [1543, 11105, 1999, 7808, 15, 177, 5884, 6, 10],
 [659, 2, 3025, 2, 120, 2, 494, 4, 2311],
 [764, 2, 629, 4, 1990, 1, 31, 1524, 6],
 [7, 3, 966, 2241, 20516, 757, 19, 516, 6],
 [24, 17, 1, 1, 2, 44, 1, 6],
 [1539, 1321, 83, 11, 11954, 4136, 103, 6],
 [32011, 217, 252, 2, 4849, 2, 1052, 6],
 [323, 5939, 2, 7567, 268, 25690, 2, 304],
 [371, 5, 3526, 2, 944, 4, 792, 6],
 [6632, 1537, 5, 523,

In [78]:
tags

([0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 24, 0, 23, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 24, 0, 23, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 24, 0, 23, 0],
 [0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 81, 85, 0, 88, 0],
 [0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0],
 [24, 0, 23, 0])

In [79]:
padded_sentences, lengths = pad_packed_sequence(pack_sequence([torch.LongTensor(_) for _ in sentences]),
                                                    batch_first=True, padding_value=vocab["<pad>"])
padded_tokens, _ = pad_packed_sequence(pack_sequence([torch.LongTensor(_) for _ in tokens]),
                                       batch_first=True, padding_value=charset["<pad>"])
padded_tags, _ = pad_packed_sequence(pack_sequence([torch.LongTensor(_) for _ in tags]),
                                     batch_first=True, padding_value=tag_set["O"])

In [80]:
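padded_sentences = padded_sentences.to(device)  # .to() returns a new tensor, so the result has to be assigned
padded_tokens = padded_tokens.to(device)
padded_tags = padded_tags.to(device)
lengths = lengths.to(device)
lengths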

tensor([10, 10, 10, 10, 10, 10, 10,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  8,
         8,  8,  8,  8,  8,  8,  7,  7,  7,  7,  6,  6,  5,  4],
       device='cuda:0')

In [88]:
def get_batch(batch_indices, data):
    batch = [data[idx] for idx in batch_indices]
    sorted_batch = sorted(batch, key=lambda x: len(x[0]), reverse=True)
    sentences, tokens, tags = zip(*sorted_batch)

    padded_sentences, lengths = pad_packed_sequence(pack_sequence([torch.LongTensor(_) for _ in sentences]),
                                                    batch_first=True, padding_value=vocab["<pad>"])
    padded_tokens, _ = pad_packed_sequence(pack_sequence([torch.LongTensor(_) for _ in tokens]),
                                           batch_first=True, padding_value=charset["<pad>"])
    padded_tags, _ = pad_packed_sequence(pack_sequence([torch.LongTensor(_) for _ in tags]),
                                         batch_first=True, padding_value=tag_set["O"])

    return padded_sentences.to(device), padded_tokens.to(device), padded_tags.to(device), lengths.to(device)
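Here `pad_packed_sequence(pack_sequence(...))` serves purely as a padding routine. `pack_sequence` expects sequences sorted by decreasing length by default (`enforce_sorted=True`), which is why the batch is sorted first, and sorting the `(sentence, tokens, tags)` triples together keeps the three aligned. A pure-Python sketch of what each call returns (`pad_sorted` is an illustrative name; for the character tensor, each padded position would be a whole row of `<pad>` ids rather than one scalar). `torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=...)` would produce the same padded tensor more directly, though it does not return the lengths:

```python
def pad_sorted(seqs, pad_value):
    """Right-pad length-sorted sequences to the longest one and return (padded, lengths).

    Assumes seqs are already sorted by decreasing length, as get_batch ensures.
    """
    lengths = [len(s) for s in seqs]
    assert lengths == sorted(lengths, reverse=True), "sort by length first"
    max_len = lengths[0] if lengths else 0
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths


batch = [[5, 6], [1, 2, 3, 4], [7, 8, 9]]
sorted_batch = sorted(batch, key=len, reverse=True)  # same idea as the sort in get_batch
padded, lengths = pad_sorted(sorted_batch, pad_value=0)
print(padded)   # [[1, 2, 3, 4], [7, 8, 9, 0], [5, 6, 0, 0]]
print(lengths)  # [4, 3, 2]
```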


In [83]:
sentences, tokens, targets, lengths = padded_sentences.to(device), padded_tokens.to(device), padded_tags.to(device), lengths.to(device)

In [84]:
optimizer.zero_grad()

In [85]:
output = model(sentences, tokens)  # forward pass; Model is defined in model.py (walked through in VScode)

In [87]:
output.shape

torch.Size([32, 10, 193])

In [89]:
targets.shape

torch.Size([32, 10])

In [90]:
lengths

tensor([10, 10, 10, 10, 10, 10, 10,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  8,
         8,  8,  8,  8,  8,  8,  7,  7,  7,  7,  6,  6,  5,  4],
       device='cuda:0')

In [91]:
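# lengths.cpu(): newer PyTorch versions require the lengths argument on the CPU; this also works on 1.6
output = pack_padded_sequence(output, lengths.cpu(), batch_first=True).data
targets = pack_padded_sequence(targets, lengths.cpu(), batch_first=True).data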

In [92]:
targets.shape

torch.Size([265])

In [93]:
output.shape

torch.Size([265, 193])
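`pack_padded_sequence(...).data` keeps only the real (non-padded) positions, so its first dimension equals the sum of `lengths`, which is 265 for this batch. The rows come out time-step-major (every sequence's step 0, then every sequence's step 1, and so on), but since `output` and `targets` are packed with the same lengths they stay aligned for the loss. A pure-Python sketch of that ordering (`packed_order` is an illustrative name):

```python
def packed_order(padded, lengths):
    """Flatten a padded batch the way PackedSequence.data is laid out.

    Assumes lengths are sorted in decreasing order (as after get_batch).
    """
    flat = []
    for t in range(max(lengths, default=0)):
        for row, n in zip(padded, lengths):
            if t < n:  # skip padded positions
                flat.append(row[t])
    return flat


padded = [[1, 2, 3], [4, 5, 0], [6, 0, 0]]
lengths = [3, 2, 1]
print(packed_order(padded, lengths))  # [1, 4, 6, 2, 5, 3]

# The lengths tensor printed above: 7x10, 10x9, 7x8, 4x7, 2x6, then 5 and 4.
batch_lengths = [10] * 7 + [9] * 10 + [8] * 7 + [7] * 4 + [6] * 2 + [5, 4]
print(len(batch_lengths), sum(batch_lengths))  # 32 265
```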

In [94]:
loss = criterion(output, targets)
loss

tensor(14105.5801, device='cuda:0', grad_fn=<NllLossBackward>)

In [95]:
loss.backward()
if args.clip > 0:
    nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
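`clip_grad_norm_` rescales all gradients by one common factor when their global L2 norm exceeds `max_norm` (here `args.clip = 0.35`). A pure-Python sketch of that rule on flat lists; the small `eps` guards against division by zero, similar to the term PyTorch adds, and `clip_by_global_norm` is an illustrative name:

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale every gradient by the same factor if their joint L2 norm exceeds max_norm.

    grads: list of flat gradient lists, one per parameter. Returns (clipped, total_norm).
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1:
        grads = [[g * clip_coef for g in grad] for grad in grads]
    return grads, total_norm


clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=0.35)
print(norm)     # 5.0
print(clipped)  # roughly [[0.21], [0.28]]
```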

### 4.6 Model Evaluation

In [96]:
# val_loss, precision, recall, f1 = evaluate(val_data_groups)
# evaluate
model.eval()
total_loss = 0
count = 0
TP = 0
TP_FP = 0
TP_FN = 0

In [97]:
val_data_sampler = GroupBatchRandomSampler(val_data_groups, args.batch_size, drop_last=False)

In [99]:
batch_indices = val_data_sampler.batch_indices[0]
batch_indices

[1375,
 1353,
 42,
 173,
 389,
 885,
 96,
 2068,
 676,
 1429,
 620,
 652,
 58,
 1802,
 1403,
 1952,
 1201,
 1194]

In [100]:
len(batch_indices)

18

In [141]:
sentences, tokens, targets, lengths = get_batch(batch_indices, val_data)

In [142]:
output = model(sentences, tokens)  # a full evaluate() would normally run this under torch.no_grad()

In [143]:
output.shape

torch.Size([18, 10, 193])

##### Function measure

In [144]:
# tp, tp_fp, tp_fn = measure(output, targets, lengths)
assert output.size(0) == targets.size(0) and targets.size(0) == lengths.size(0)
tp = 0
tp_fp = 0
tp_fn = 0
batch_size = output.size(0)
output_ = torch.argmax(output, dim=-1)

In [145]:
# for i in range(batch_size):
i = 0

In [146]:
length = lengths[i]
length

tensor(10, device='cuda:0')

In [147]:
output_.shape

torch.Size([18, 10])

In [148]:
out = output_[i][:length].tolist()
out

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [149]:
targets.shape

torch.Size([18, 10])

In [150]:
targets

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0, 39,  0, 40,  0],
        [ 0,  0,  0,  0, 23,  0,  0, 24,  0,  0],
        [ 0,  0,  0,  0,  0, 18, 22,  0, 23,  0],
        [ 0,  0,  0, 24,  0,  0,  0, 23,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0, 24,  0, 23,  0,  0,  0,  0,  0],
        [ 0,  0, 87,  0, 88,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0, 24,  0, 23,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [18, 20, 22,  0,  0, 23,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0]], device='cuda:0')

In [151]:
target = targets[i][:length].tolist()
target

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

##### Function get_triplets

In [158]:
# get_triplets
# out_triplets = get_triplets(out)
tags = out

In [159]:
temp = {}
triplets = []

In [160]:
for idx, tag in enumerate(tags):
    if tag == tag_set["O"]:
        continue
    pos, relation_label, role = tag_set[tag].split("-")
    if pos == "B" or pos == "S":
        if relation_label not in temp:
            temp[relation_label] = [[], []]
        temp[relation_label][int(role) - 1].append(idx)

In [161]:
for relation_label in temp:
    role1, role2 = temp[relation_label]
    if role1 and role2:
        len1, len2 = len(role1), len(role2)
        if len1 > len2:
            for e2 in role2:
                idx = np.argmin([abs(e2 - e1) for e1 in role1])
                e1 = role1[idx]
                triplets.append((e1, relation_label, e2))
                del role1[idx]
        else:
            for e1 in role1:
                idx = np.argmin([abs(e2 - e1) for e2 in role2])
                e2 = role2[idx]
                triplets.append((e1, relation_label, e2))
                del role2[idx]
# return triplets

In [162]:
def get_triplets(tags):
    temp = {}
    triplets = []
    for idx, tag in enumerate(tags):
        if tag == tag_set["O"]:
            continue
        pos, relation_label, role = tag_set[tag].split("-")
        if pos == "B" or pos == "S":
            if relation_label not in temp:
                temp[relation_label] = [[], []]
            temp[relation_label][int(role) - 1].append(idx)
    for relation_label in temp:
        role1, role2 = temp[relation_label]
        if role1 and role2:
            len1, len2 = len(role1), len(role2)
            if len1 > len2:
                for e2 in role2:
                    idx = np.argmin([abs(e2 - e1) for e1 in role1])
                    e1 = role1[idx]
                    triplets.append((e1, relation_label, e2))
                    del role1[idx]
            else:
                for e1 in role1:
                    idx = np.argmin([abs(e2 - e1) for e2 in role2])
                    e2 = role2[idx]
                    triplets.append((e1, relation_label, e2))
                    del role2[idx]
    return triplets
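`get_triplets` keeps only the head token (B or S) of each entity and, per relation, pairs each head of the less frequent role with the nearest unused head of the other role. The self-contained version below takes the id-to-tag list as an argument and rebuilds it for the first five relations in the `tag_set` order shown in 4.3, so ids 1 to 40 match; `RELATIONS`, `IDX2TAG` and `triplets_from_ids` are illustrative names. It decodes the row `[18, 20, 22, 0, 0, 23, 0, 0, 0, 0]` from the validation `targets` above. Because only head positions are compared, a predicted triplet counts as correct even if its predicted entity end (E tag) differs from the gold one.

```python
RELATIONS = [  # first five relations, in the tag_set order shown in 4.3
    "/people/person/nationality",
    "/location/country/capital",
    "/location/location/contains",
    "/people/deceased_person/place_of_death",
    "/people/person/children",
]
IDX2TAG = ["O"] + ["-".join([pos, rel, role]) for rel in RELATIONS for pos in "BIES" for role in "12"]


def triplets_from_ids(tag_ids, idx2tag=IDX2TAG):
    """Same pairing rule as get_triplets, but with the id->tag list passed in."""
    heads = {}  # relation -> [role-1 head positions, role-2 head positions]
    for idx, tag_id in enumerate(tag_ids):
        if idx2tag[tag_id] == "O":
            continue
        pos, relation_label, role = idx2tag[tag_id].split("-")
        if pos in ("B", "S"):  # only entity heads are kept
            heads.setdefault(relation_label, [[], []])[int(role) - 1].append(idx)
    triplets = []
    for relation_label, (role1, role2) in heads.items():
        if not (role1 and role2):
            continue
        if len(role1) > len(role2):
            for e2 in role2:  # pair each role-2 head with its nearest role-1 head
                j = min(range(len(role1)), key=lambda k: abs(e2 - role1[k]))
                triplets.append((role1.pop(j), relation_label, e2))
        else:
            for e1 in role1:  # pair each role-1 head with its nearest role-2 head
                j = min(range(len(role2)), key=lambda k: abs(role2[k] - e1))
                triplets.append((e1, relation_label, role2.pop(j)))
    return triplets


print(triplets_from_ids([18, 20, 22, 0, 0, 23, 0, 0, 0, 0]))
# [(5, '/location/location/contains', 0)]
```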

In [163]:
out_triplets = get_triplets(out)
tp_fp += len(out_triplets)

In [164]:
target_triplets = get_triplets(target)
tp_fn += len(target_triplets)

In [165]:
for target_triplet in target_triplets:
    for out_triplet in out_triplets:
        if out_triplet == target_triplet:
            tp += 1

In [166]:
def measure(output, targets, lengths):
    assert output.size(0) == targets.size(0) and targets.size(0) == lengths.size(0)
    tp = 0
    tp_fp = 0
    tp_fn = 0
    batch_size = output.size(0)
    output = torch.argmax(output, dim=-1)
    for i in range(batch_size):
        length = lengths[i]
        out = output[i][:length].tolist()
        target = targets[i][:length].tolist()
        out_triplets = get_triplets(out)
        tp_fp += len(out_triplets)
        target_triplets = get_triplets(target)
        tp_fn += len(target_triplets)
        for target_triplet in target_triplets:
            for out_triplet in out_triplets:
                if out_triplet == target_triplet:
                    tp += 1
    return tp, tp_fp, tp_fn

In [167]:
TP += tp
TP_FP += tp_fp
TP_FN += tp_fn
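The cells here stop at accumulating `TP`, `TP_FP` and `TP_FN`. Since `evaluate` returns `precision, recall, f1` (see the commented call at the top of 4.6), it presumably turns these counters into micro-averaged scores roughly as in this sketch; the function name and the zero-division guards are assumptions, not taken from the project code:

```python
def precision_recall_f1(tp, tp_fp, tp_fn):
    """Micro-averaged P/R/F1 from the counters accumulated by measure().

    tp_fp = number of predicted triplets, tp_fn = number of gold triplets.
    Returns 0.0 instead of dividing by zero.
    """
    precision = tp / tp_fp if tp_fp else 0.0
    recall = tp / tp_fn if tp_fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


print(precision_recall_f1(8, 10, 16))  # precision 0.8, recall 0.5, F1 about 0.615
```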

In [168]:
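# lengths.cpu(): newer PyTorch versions require the lengths argument on the CPU; this also works on 1.6
output = pack_padded_sequence(output, lengths.cpu(), batch_first=True).data
targets = pack_padded_sequence(targets, lengths.cpu(), batch_first=True).data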

In [169]:
targets.shape

torch.Size([154])

In [170]:
output.shape

torch.Size([154, 193])

In [172]:
loss = criterion(output, targets)
loss

tensor(3865.6277, device='cuda:0', grad_fn=<NllLossBackward>)

In [173]:
total_loss += loss.item()
count += len(targets)

# 5 Code Walkthrough and Review of Details (demonstrated in VScode)

We review the training flow once more in the training script in VScode.

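# 6 Assignments

`[Question]` Think about the weaknesses of this paper's model and what could be improved. Can you come up with other new frameworks for joint entity and relation extraction?

`[Coding]` Reimplement the model part of this paper's code.

`[Drawing]` Without looking at the paper's original figure, draw the model architecture based on your own understanding.

`[Summary]` Review the paper, from the background it describes, to the proposed model, to the experimental evidence. Think about how the paper ties these parts together, and learn from its writing style and line of reasoning.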

---