From b7edea6a6e905c21ff1759193dee6e88c291a385 Mon Sep 17 00:00:00 2001 From: Nicolay Rusnachenko Date: Sat, 16 Jul 2022 11:27:39 +0300 Subject: [PATCH] fixed #131 --- arekit/contrib/source/rusentrel/entities.py | 2 +- arekit/contrib/source/rusentrel/io_utils.py | 44 ++++++++++++------- .../contrib/source/rusentrel/news_reader.py | 2 +- .../source/rusentrel/opinions/collection.py | 11 +++-- tests/contrib/source/test_rusentrel.py | 2 + 5 files changed, 37 insertions(+), 24 deletions(-) diff --git a/arekit/contrib/source/rusentrel/entities.py b/arekit/contrib/source/rusentrel/entities.py index bc766803..f667a15d 100644 --- a/arekit/contrib/source/rusentrel/entities.py +++ b/arekit/contrib/source/rusentrel/entities.py @@ -19,7 +19,7 @@ def read_collection(cls, doc_id, synonyms, version=RuSentRelVersions.V11): assert (isinstance(doc_id, int)) return RuSentRelIOUtils.read_from_zip( - inner_path=RuSentRelIOUtils.get_entity_innerpath(doc_id), + inner_path=RuSentRelIOUtils.get_entity_innerpath(index=doc_id, version=version), process_func=lambda input_file: cls( entities=BratAnnotationParser.parse_annotations(input_file)["entities"], value_to_group_id_func=synonyms.get_synonym_group_index), diff --git a/arekit/contrib/source/rusentrel/io_utils.py b/arekit/contrib/source/rusentrel/io_utils.py index 7e71b5a5..10e92bca 100644 --- a/arekit/contrib/source/rusentrel/io_utils.py +++ b/arekit/contrib/source/rusentrel/io_utils.py @@ -6,12 +6,17 @@ class RuSentRelVersions(Enum): + """ Original collection repository: https://github.com/nicolay-r/RuSentRel + Paper: https://arxiv.org/abs/1808.08932 + """ V11 = "v1_1" class RuSentRelIOUtils(ZipArchiveUtils): - __sep_doc_id = 46 + TEST_FOLDER = "test" + TRAIN_FOLDER = "train" + ETALON_FOLDER = "etalon" @staticmethod def get_archive_filepath(version): @@ -21,20 +26,22 @@ def get_archive_filepath(version): # region internal methods @staticmethod - def get_sentiment_opin_filepath(index, prefix='art'): - root = RuSentRelIOUtils.__get_root_by_index(index, is_opinion=True) - return path.join(root, "{}{}.opin.txt".format(prefix, index)) + def get_sentiment_opin_filepath(index, version, prefix='art'): + root = RuSentRelIOUtils.__get_root_by_index(index, version=version, keep_etalon=True) + return path.join(root, "{prefix}{index}.opin.txt".format(prefix=prefix, index=index)) @staticmethod - def get_entity_innerpath(index): + def get_entity_innerpath(index, version): assert(isinstance(index, int)) - inner_root = RuSentRelIOUtils.__get_root_by_index(index) + assert(isinstance(version, RuSentRelVersions)) + inner_root = RuSentRelIOUtils.__get_root_by_index(doc_id=index, version=version) return path.join(inner_root, "art{}.ann".format(index)) @staticmethod - def get_news_innerpath(index): + def get_news_innerpath(index, version): assert(isinstance(index, int)) - inner_root = RuSentRelIOUtils.__get_root_by_index(index) + assert(isinstance(version, RuSentRelVersions)) + inner_root = RuSentRelIOUtils.__get_root_by_index(doc_id=index, version=version) return path.join(inner_root, "art{}.txt".format(index)) @staticmethod @@ -44,17 +51,18 @@ def get_synonyms_innerpath(): # endregion @staticmethod - def __get_root_by_index(doc_id, is_opinion=False): + def __get_root_by_index(doc_id, version, keep_etalon=False): + assert(RuSentRelIOUtils.__is_supported(version)) + assert(isinstance(version, RuSentRelVersions)) assert(isinstance(doc_id, int)) - other_dir = 'etalon' if is_opinion else 'test' - return other_dir if doc_id >= RuSentRelIOUtils.__sep_doc_id else "train" + other_dir = RuSentRelIOUtils.ETALON_FOLDER if keep_etalon else RuSentRelIOUtils.TRAIN_FOLDER + test_indices = set(RuSentRelIOUtils.__iter_indicies_from_dataset(version, RuSentRelIOUtils.TRAIN_FOLDER)) + return other_dir if doc_id in test_indices else RuSentRelIOUtils.TRAIN_FOLDER @staticmethod def __is_supported(version): assert(isinstance(version, RuSentRelVersions)) - if version != RuSentRelVersions.V11: - raise NotImplementedError("Collection does not supported") - return True + return version == RuSentRelVersions.V11 @staticmethod def __number_from_string(s): @@ -93,13 +101,17 @@ def __iter_indicies_from_dataset(version, folder_name): @staticmethod def iter_test_indices(version): assert(RuSentRelIOUtils.__is_supported(version)) - for index in RuSentRelIOUtils.__iter_indicies_from_dataset(version=version, folder_name="test/"): + indices_iter = RuSentRelIOUtils.__iter_indicies_from_dataset( + version=version, folder_name="{}/".format(RuSentRelIOUtils.TEST_FOLDER)) + for index in indices_iter: yield index @staticmethod def iter_train_indices(version): assert(RuSentRelIOUtils.__is_supported(version)) - for index in RuSentRelIOUtils.__iter_indicies_from_dataset(version=version, folder_name="train/"): + indices_iter = RuSentRelIOUtils.__iter_indicies_from_dataset( + version=version, folder_name="{}/".format(RuSentRelIOUtils.TRAIN_FOLDER)) + for index in indices_iter: yield index @staticmethod diff --git a/arekit/contrib/source/rusentrel/news_reader.py b/arekit/contrib/source/rusentrel/news_reader.py index c9fffdb2..618ebcce 100644 --- a/arekit/contrib/source/rusentrel/news_reader.py +++ b/arekit/contrib/source/rusentrel/news_reader.py @@ -45,6 +45,6 @@ def file_to_doc(input_file): version=version) return RuSentRelIOUtils.read_from_zip( - inner_path=RuSentRelIOUtils.get_news_innerpath(doc_id), + inner_path=RuSentRelIOUtils.get_news_innerpath(index=doc_id, version=version), process_func=file_to_doc, version=version) diff --git a/arekit/contrib/source/rusentrel/opinions/collection.py b/arekit/contrib/source/rusentrel/opinions/collection.py index e1a3ab57..586e0040 100644 --- a/arekit/contrib/source/rusentrel/opinions/collection.py +++ b/arekit/contrib/source/rusentrel/opinions/collection.py @@ -14,11 +14,10 @@ class RuSentRelOpinionCollection: def iter_opinions_from_doc(doc_id, labels_fmt=RuSentRelLabelsFormatter(), version=RuSentRelVersions.V11): - """ - doc_id: - synonyms: None or SynonymsCollection - None corresponds to the related synonym collection from RuSentRel collection. - version: + """ doc_id: + synonyms: None or SynonymsCollection + None corresponds to the related synonym collection from RuSentRel collection. + version: RuSentrelVersions enum """ assert(isinstance(version, RuSentRelVersions)) assert(isinstance(labels_fmt, StringLabelsFormatter)) @@ -26,7 +25,7 @@ def iter_opinions_from_doc(doc_id, assert(labels_fmt.supports_value(NEG_LABEL_STR)) return RuSentRelIOUtils.iter_from_zip( - inner_path=RuSentRelIOUtils.get_sentiment_opin_filepath(doc_id), + inner_path=RuSentRelIOUtils.get_sentiment_opin_filepath(index=doc_id, version=version), process_func=lambda input_file: RuSentRelOpinionCollectionProvider._iter_opinions_from_file( input_file=input_file, labels_formatter=labels_fmt, diff --git a/tests/contrib/source/test_rusentrel.py b/tests/contrib/source/test_rusentrel.py index 5af73c3e..cd08d2de 100644 --- a/tests/contrib/source/test_rusentrel.py +++ b/tests/contrib/source/test_rusentrel.py @@ -21,6 +21,8 @@ def test_iter_train_indices(self): def test_iter_test_indices(self): test_indices = list(RuSentRelIOUtils.iter_test_indices(self.rsr_version)) + for i in test_indices: + print(i, end=' ') for i in range(46, 76): if i in [70]: