# Author: lqxu

import os
from typing import *

import torch
from torch import Tensor
from torchtext import transforms
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule

from core.utils import ROOT_DIR
from core.utils import get_default_tokenizer
from scheme import GPLinkerREScheme

# Reuse the data-processing results from CasRel's data_modules module; only the necessary variables are copied over here.
hf_data_dir: str = os.path.join(ROOT_DIR, "examples/relation_extraction/CasRel/output", "hf_dataset")

relation_labels = [
    # There are 48 relation labels in total; given the training difficulty, labels with
    # fewer than 1000 occurrences in the training set were dropped, leaving these 34.
    '主演', '作者', '歌手', '导演', '父亲', '成立日期', '妻子', '丈夫', '国籍', '母亲', '作词', '作曲', '毕业院校',
    '所属专辑', '董事长', '朝代', '嘉宾', '出品公司', '编剧', '上映时间', '饰演', '简称', '主持人', '配音', '获奖',
    '主题曲', '校长', '总部地点', '主角', '创始人', '票房', '制片人', '号', '祖籍'
]
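# English glosses of the labels above, in order: starred-in, author, singer, director,
# father, founding-date, wife, husband, nationality, mother, lyricist, composer,
# alma-mater, album, chairman, dynasty, guest, production-company, screenwriter,
# release-date, plays-role, abbreviation, host, voice-actor, award-received,
# theme-song, principal, headquarters-location, protagonist, founder, box-office,
# producer, art-name/pseudonym, ancestral-home.
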
max_num_tokens = 192
tokenizer = get_default_tokenizer()
tokenizer_kwargs = {
    "max_length": max_num_tokens, "truncation": True,
    "return_attention_mask": False, "return_token_type_ids": False,
}

scheme = GPLinkerREScheme(max_num_tokens=max_num_tokens, num_relations=len(relation_labels))
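# Note on the label layout (inferred from the shape comments in DataCollate.__call__
# below; the exact cell semantics live in GPLinkerREScheme): scheme.encode(sro_set)
# is expected to return subject/object grids of shape [n_tokens, n_tokens] marking
# entity spans, plus head/tail grids of shape [n_relations, n_tokens, n_tokens] that
# pair the subject's head/tail token with the object's head/tail token per relation.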


class DataCollate:
    def __init__(self, is_train_stage: bool):
        self.scheme = scheme
        self.max_num_tokens = max_num_tokens
        # Convert each sample's token-id list to a LongTensor and pad it to max_num_tokens.
        self.transforms = transforms.Sequential(
            transforms.ToTensor(padding_value=0),
            transforms.PadTransform(max_num_tokens, pad_value=0)
        )
        self.is_train_stage = is_train_stage

    def __call__(self, batch: List[Dict[str, List[Any]]]) -> Dict[str, Tensor]:
        # "text" already holds token ids produced by the CasRel preprocessing step.
        input_ids = self.transforms([sample["text"] for sample in batch])
        subject_tensor, object_tensor, head_tensor, tail_tensor, sro_sets = [], [], [], [], []
        for sample in batch:
            # De-duplicate the (subject, relation, object) triples of this sample.
            sro_set = {tuple(sro) for sro in sample["sro_list"]}
            st, ot, ht, tt = self.scheme.encode(sro_set)
            subject_tensor.append(st)
            object_tensor.append(ot)
            head_tensor.append(ht)
            tail_tensor.append(tt)
            sro_sets.append(sro_set)
        subject_tensor = torch.stack(subject_tensor, dim=0)  # [batch_size, n_tokens, n_tokens]
        object_tensor = torch.stack(object_tensor, dim=0)  # [batch_size, n_tokens, n_tokens]
        head_tensor = torch.stack(head_tensor, dim=0)  # [batch_size, n_relations, n_tokens, n_tokens]
        tail_tensor = torch.stack(tail_tensor, dim=0)  # [batch_size, n_relations, n_tokens, n_tokens]
        # Stack the subject and object grids along a new axis: [batch_size, 2, n_tokens, n_tokens]
        entity_tensor = torch.cat([subject_tensor.unsqueeze(1), object_tensor.unsqueeze(1)], dim=1)
        ret = {
            "input_ids": input_ids, "entity_tensor": entity_tensor,
            "head_tensor": head_tensor, "tail_tensor": tail_tensor
        }
        if not self.is_train_stage:
            # Keep the gold triple sets so the evaluation loop can score predictions.
            ret["sro_sets"] = sro_sets
        return ret


class DuIEDataModule(LightningDataModule):
    def __init__(self, batch_size: int):
        super(DuIEDataModule, self).__init__()
        # hyper-parameters
        self.batch_size = batch_size
        # other settings
        self.hf_dataset = None

    def prepare_data(self):
        if not os.path.exists(hf_data_dir):
            raise ValueError("The dataset path does not exist; run init_hf_dataset from the CasRel example first!")

    def setup(self, stage: str) -> None:
        from datasets import DatasetDict
        self.hf_dataset = DatasetDict.load_from_disk(hf_data_dir)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.hf_dataset["train"],
            batch_size=self.batch_size, shuffle=True, num_workers=8,
            collate_fn=DataCollate(is_train_stage=True)
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.hf_dataset["dev"],
            # No backward pass during evaluation, so a larger batch size fits in memory.
            batch_size=self.batch_size * 4, shuffle=False, num_workers=8,
            collate_fn=DataCollate(is_train_stage=False)
        )

    # The test split is evaluated with the same loader configuration as dev.
    test_dataloader = val_dataloader
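

# A minimal smoke-test sketch, runnable from the project root. Assumptions: the
# CasRel hf_dataset has already been built via init_hf_dataset, and batch_size=16
# is an arbitrary illustrative choice.
if __name__ == "__main__":
    dm = DuIEDataModule(batch_size=16)
    dm.prepare_data()
    dm.setup(stage="fit")
    batch = next(iter(dm.train_dataloader()))
    # Expected shapes: input_ids [16, 192], entity_tensor [16, 2, 192, 192],
    # head_tensor / tail_tensor [16, 34, 192, 192].
    for name, value in batch.items():
        print(name, tuple(value.shape))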