import json
import logging
import os
from os.path import dirname

from arekit.common.data import const
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter

logger = logging.getLogger(__name__)

class OpenNREJsonWriter(BaseWriter):
    """ This is a bag-based writer for the samples.
        Project page: https://github.com/thunlp/OpenNRE

        Every bag is presented as follows:
            {
                'text' or 'token': ...,
                'h': {'pos': [start, end], 'id': ...},
                't': {'pos': [start, end], 'id': ...},
                'id': "id_of_the_text_opinion"
            }

        For linked opinions (i0, i1, etc.) we adopt the id of the first
        opinion in the linkage. During the dataset reading stage in OpenNRE,
        these linkages are automatically grouped into bags.
    """

    def __init__(self, text_columns, encoding="utf-8"):
        assert(isinstance(text_columns, list))
        assert(isinstance(encoding, str))
        self.__text_columns = text_columns
        self.__encoding = encoding
        self.__target_f = None

    @staticmethod
    def __format_row(row, text_columns):
        """ Formatting that is compatible with OpenNRE.
        """
        sample_id = row[const.ID]
        s_ind = int(row[const.S_IND])
        t_ind = int(row[const.T_IND])

        # Strip the linkage suffix ("_i<N>") so that every linked opinion
        # shares the id of the first opinion, i.e. a common bag id.
        bag_id = sample_id[0:sample_id.find('_i')]

        # Gather tokens across all the text columns.
        tokens = []
        for text_col in text_columns:
            tokens.extend(row[text_col].split())

        # Compose the JSON row.
        return {
            "id": bag_id,
            "id_orig": sample_id,
            "token": tokens,
            "h": {"pos": [s_ind, s_ind + 1], "id": str(bag_id + "s")},
            "t": {"pos": [t_ind, t_ind + 1], "id": str(bag_id + "t")},
            "relation": str(int(row[const.LABEL])) if const.LABEL in row else "NA"
        }
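
    # Example (hypothetical values): a row with id "s0_i1", source index 0,
    # target index 3 and a single text column "Paris is in France" yields the
    # h-span [0, 1] ("Paris") and the t-span [3, 4] ("France") under bag "s0";
    # note that both spans always cover exactly one token, [ind, ind + 1].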

    def open_target(self, target):
        os.makedirs(dirname(target), exist_ok=True)
        self.__target_f = open(target, "w", encoding=self.__encoding)

    def close_target(self):
        self.__target_f.close()

    def commit_line(self, storage):
        assert(isinstance(storage, RowCacheStorage))

        # Collect the existing columns.
        row_data = {}
        for col_name in storage.iter_column_names():
            if col_name not in storage.RowCache:
                continue
            row_data[col_name] = storage.RowCache[col_name]

        self.__write_bag(bag=self.__format_row(row_data, text_columns=self.__text_columns),
                         json_file=self.__target_f)

    @staticmethod
    def __write_bag(bag, json_file):
        assert(isinstance(bag, dict))
        # One JSON bag per line, i.e. the JSONL layout consumed by OpenNRE.
        json.dump(bag, json_file, separators=(",", ":"), ensure_ascii=False)
        json_file.write("\n")

    def write_all(self, storage, target):
        assert(isinstance(storage, BaseRowsStorage))
        assert(isinstance(target, str))

        logger.info("Saving... {rows}: {filepath}".format(rows=len(storage), filepath=target))

        os.makedirs(os.path.dirname(target), exist_ok=True)
        with open(target, "w", encoding=self.__encoding) as json_file:
            for row_index, row in storage:
                self.__write_bag(bag=self.__format_row(row, text_columns=self.__text_columns),
                                 json_file=json_file)

        logger.info("Saving completed!")