Skip to content

Commit

Permalink
#502 related.
Browse files Browse the repository at this point in the history
#376 related to bag_id.
  • Loading branch information
nicolay-r committed Jul 22, 2023
1 parent 539c8f5 commit 5444404
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions arekit/contrib/utils/data/writers/json_opennre.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@ class OpenNREJsonWriter(BaseWriter):
During the dataset reading stage via OpenNRE, these linkages automaticaly groups into bags.
"""

EXTRA_KEYS_TEMPLATE = "_{}"

def __init__(self, text_columns, encoding="utf-8"):
""" text_columns: list
column names that expected to be joined into a single (token) column.
"""
assert(isinstance(text_columns, list))
assert(isinstance(encoding, str))
self.__encoding = encoding
self.__text_columns = text_columns
self.__encoding = encoding
self.__target_f = None

@staticmethod
Expand All @@ -41,7 +47,7 @@ def __format_row(row, text_columns):
sample_id = row[const.ID]
s_ind = int(row[const.S_IND])
t_ind = int(row[const.T_IND])
bag_id = sample_id[0:sample_id.find('_i')]
bag_id = str(row[const.OPINION_ID])

# Gather tokens.
tokens = []
Expand All @@ -50,7 +56,7 @@ def __format_row(row, text_columns):
tokens.extend(row[text_col].split())

# Filtering JSON row.
return {
formatted_data = {
"id": bag_id,
"id_orig": sample_id,
"token": tokens,
Expand All @@ -59,6 +65,13 @@ def __format_row(row, text_columns):
"relation": str(int(row[const.LABEL_UINT])) if const.LABEL_UINT in row else "NA"
}

# Register extra fields.
for key, value in row.items():
if key not in formatted_data and key not in text_columns:
formatted_data[OpenNREJsonWriter.EXTRA_KEYS_TEMPLATE.format(key)] = value

return formatted_data

def open_target(self, target):
os.makedirs(dirname(target), exist_ok=True)
self.__target_f = open(target, "w")
Expand Down

0 comments on commit 5444404

Please sign in to comment.