Skip to content

Commit

Permalink
fixed #131
Browse files: browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jul 16, 2022
1 parent bfc62aa commit b7edea6
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 24 deletions.
2 changes: 1 addition & 1 deletion arekit/contrib/source/rusentrel/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def read_collection(cls, doc_id, synonyms, version=RuSentRelVersions.V11):
assert (isinstance(doc_id, int))

return RuSentRelIOUtils.read_from_zip(
inner_path=RuSentRelIOUtils.get_entity_innerpath(doc_id),
inner_path=RuSentRelIOUtils.get_entity_innerpath(index=doc_id, version=version),
process_func=lambda input_file: cls(
entities=BratAnnotationParser.parse_annotations(input_file)["entities"],
value_to_group_id_func=synonyms.get_synonym_group_index),
Expand Down
44 changes: 28 additions & 16 deletions arekit/contrib/source/rusentrel/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@


class RuSentRelVersions(Enum):
""" Original collection repository: https://github.com/nicolay-r/RuSentRel
Paper: https://arxiv.org/abs/1808.08932
"""
V11 = "v1_1"


class RuSentRelIOUtils(ZipArchiveUtils):

__sep_doc_id = 46
TEST_FOLDER = "test"
TRAIN_FOLDER = "train"
ETALON_FOLDER = "etalon"

@staticmethod
def get_archive_filepath(version):
Expand All @@ -21,20 +26,22 @@ def get_archive_filepath(version):
# region internal methods

@staticmethod
def get_sentiment_opin_filepath(index, prefix='art'):
root = RuSentRelIOUtils.__get_root_by_index(index, is_opinion=True)
return path.join(root, "{}{}.opin.txt".format(prefix, index))
def get_sentiment_opin_filepath(index, version, prefix='art'):
root = RuSentRelIOUtils.__get_root_by_index(index, version=version, keep_etalon=True)
return path.join(root, "{prefix}{index}.opin.txt".format(prefix=prefix, index=index))

@staticmethod
def get_entity_innerpath(index):
def get_entity_innerpath(index, version):
assert(isinstance(index, int))
inner_root = RuSentRelIOUtils.__get_root_by_index(index)
assert(isinstance(version, RuSentRelVersions))
inner_root = RuSentRelIOUtils.__get_root_by_index(doc_id=index, version=version)
return path.join(inner_root, "art{}.ann".format(index))

@staticmethod
def get_news_innerpath(index):
def get_news_innerpath(index, version):
assert(isinstance(index, int))
inner_root = RuSentRelIOUtils.__get_root_by_index(index)
assert(isinstance(version, RuSentRelVersions))
inner_root = RuSentRelIOUtils.__get_root_by_index(doc_id=index, version=version)
return path.join(inner_root, "art{}.txt".format(index))

@staticmethod
Expand All @@ -44,17 +51,18 @@ def get_synonyms_innerpath():
# endregion

@staticmethod
def __get_root_by_index(doc_id, is_opinion=False):
def __get_root_by_index(doc_id, version, keep_etalon=False):
assert(RuSentRelIOUtils.__is_supported(version))
assert(isinstance(version, RuSentRelVersions))
assert(isinstance(doc_id, int))
other_dir = 'etalon' if is_opinion else 'test'
return other_dir if doc_id >= RuSentRelIOUtils.__sep_doc_id else "train"
other_dir = RuSentRelIOUtils.ETALON_FOLDER if keep_etalon else RuSentRelIOUtils.TRAIN_FOLDER
test_indices = set(RuSentRelIOUtils.__iter_indicies_from_dataset(version, RuSentRelIOUtils.TRAIN_FOLDER))
return other_dir if doc_id in test_indices else RuSentRelIOUtils.TRAIN_FOLDER

@staticmethod
def __is_supported(version):
assert(isinstance(version, RuSentRelVersions))
if version != RuSentRelVersions.V11:
raise NotImplementedError("Collection does not supported")
return True
return version == RuSentRelVersions.V11

@staticmethod
def __number_from_string(s):
Expand Down Expand Up @@ -93,13 +101,17 @@ def __iter_indicies_from_dataset(version, folder_name):
@staticmethod
def iter_test_indices(version):
assert(RuSentRelIOUtils.__is_supported(version))
for index in RuSentRelIOUtils.__iter_indicies_from_dataset(version=version, folder_name="test/"):
indices_iter = RuSentRelIOUtils.__iter_indicies_from_dataset(
version=version, folder_name="{}/".format(RuSentRelIOUtils.TEST_FOLDER))
for index in indices_iter:
yield index

@staticmethod
def iter_train_indices(version):
assert(RuSentRelIOUtils.__is_supported(version))
for index in RuSentRelIOUtils.__iter_indicies_from_dataset(version=version, folder_name="train/"):
indices_iter = RuSentRelIOUtils.__iter_indicies_from_dataset(
version=version, folder_name="{}/".format(RuSentRelIOUtils.TRAIN_FOLDER))
for index in indices_iter:
yield index

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/source/rusentrel/news_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ def file_to_doc(input_file):
version=version)

return RuSentRelIOUtils.read_from_zip(
inner_path=RuSentRelIOUtils.get_news_innerpath(doc_id),
inner_path=RuSentRelIOUtils.get_news_innerpath(index=doc_id, version=version),
process_func=file_to_doc,
version=version)
11 changes: 5 additions & 6 deletions arekit/contrib/source/rusentrel/opinions/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,18 @@ class RuSentRelOpinionCollection:
def iter_opinions_from_doc(doc_id,
labels_fmt=RuSentRelLabelsFormatter(),
version=RuSentRelVersions.V11):
"""
doc_id:
synonyms: None or SynonymsCollection
None corresponds to the related synonym collection from RuSentRel collection.
version:
""" doc_id:
synonyms: None or SynonymsCollection
None corresponds to the related synonym collection from RuSentRel collection.
version: RuSentRelVersions enum
"""
assert(isinstance(version, RuSentRelVersions))
assert(isinstance(labels_fmt, StringLabelsFormatter))
assert(labels_fmt.supports_value(POS_LABEL_STR))
assert(labels_fmt.supports_value(NEG_LABEL_STR))

return RuSentRelIOUtils.iter_from_zip(
inner_path=RuSentRelIOUtils.get_sentiment_opin_filepath(doc_id),
inner_path=RuSentRelIOUtils.get_sentiment_opin_filepath(index=doc_id, version=version),
process_func=lambda input_file: RuSentRelOpinionCollectionProvider._iter_opinions_from_file(
input_file=input_file,
labels_formatter=labels_fmt,
Expand Down
2 changes: 2 additions & 0 deletions tests/contrib/source/test_rusentrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def test_iter_train_indices(self):

def test_iter_test_indices(self):
test_indices = list(RuSentRelIOUtils.iter_test_indices(self.rsr_version))
for i in test_indices:
print(i, end=' ')

for i in range(46, 76):
if i in [70]:
Expand Down

0 comments on commit b7edea6

Please sign in to comment.