Skip to content

Commit

Permalink
#502 consider skip_extra_existed option
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Sep 27, 2023
1 parent b84d545 commit 273f409
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions arekit/contrib/utils/data/writers/json_opennre.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class OpenNREJsonWriter(BaseWriter):
During the dataset reading stage via OpenNRE, these linkages are automatically grouped into bags.
"""

def __init__(self, text_columns, encoding="utf-8", na_value="NA", keep_extra_columns=True):
def __init__(self, text_columns, encoding="utf-8", na_value="NA", keep_extra_columns=True,
skip_extra_existed=True):
""" text_columns: list
column names that are expected to be joined into a single (token) column.
"""
Expand All @@ -38,12 +39,13 @@ def __init__(self, text_columns, encoding="utf-8", na_value="NA", keep_extra_col
self.__target_f = None
self.__keep_extra_columns = keep_extra_columns
self.__na_value = na_value
self.__skip_extra_existed = skip_extra_existed

def extension(self):
""" Returns the ".jsonl" file extension string.
"""
return ".jsonl"

@staticmethod
def __format_row(row, na_value, text_columns, keep_extra_columns):
def __format_row(row, na_value, text_columns, keep_extra_columns, skip_extra_existed):
""" Formatting that is compatible with the OpenNRE.
"""
assert(isinstance(na_value, str))
Expand Down Expand Up @@ -75,8 +77,11 @@ def __format_row(row, na_value, text_columns, keep_extra_columns):
if key not in formatted_data and key not in text_columns:
formatted_data[key] = value
else:
raise Exception(f"key `{key}` is already exist in formatted data "
f"or a part of the text columns list: f{text_columns}")
info = f"key `{key}` is already exist in formatted data "\
f"or a part of the text columns list: {text_columns}"
logger.info(info)
if not skip_extra_existed:
raise Exception(info)

return formatted_data

Expand All @@ -100,7 +105,8 @@ def commit_line(self, storage):

bag = self.__format_row(row_data, text_columns=self.__text_columns,
keep_extra_columns=self.__keep_extra_columns,
na_value=self.__na_value)
na_value=self.__na_value,
skip_extra_existed=self.__skip_extra_existed)

self.__write_bag(bag=bag, json_file=self.__target_f)

Expand All @@ -121,7 +127,8 @@ def write_all(self, storage, target):
for row_index, row in storage:
self.__write_bag(bag=self.__format_row(row, text_columns=self.__text_columns,
keep_extra_columns=self.__keep_extra_columns,
na_value=self.__na_value),
na_value=self.__na_value,
skip_extra_existed=self.__skip_extra_existed),
json_file=json_file)

logger.info("Saving completed!")

0 comments on commit 273f409

Please sign in to comment.