Skip to content

Commit

Permalink
Fixed #124
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed May 27, 2021
1 parent 7a094c3 commit 07ee6aa
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 17 deletions.
18 changes: 16 additions & 2 deletions common/evaluation/evaluators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,17 @@ def _create_eval_result(self):

# region protected methods

def _calc_diff(self, etalon_opins, test_opins):
def _check_is_supported(self, label, is_label_supported):
    """Ensure that `label` is accepted by the `is_label_supported` predicate.

    A missing label (None) is always accepted, and the predicate is not
    consulted for it.

    :param label: label to verify, or None when there is no label.
    :param is_label_supported: callable that takes a label and returns a
        truthy value when the label is supported.
    :return: True when the label is None or is supported.
    :raises Exception: when the predicate rejects the label; the message
        contains the string form of the label and the evaluator class name.
    """
    if label is None:
        return True

    if not is_label_supported(label):
        raise Exception(u"Label \"{label}\" is not supported by {e}".format(
            label=label_to_str(label),
            e=type(self).__name__))

    # Return the same value on every accepting path; previously a supported
    # label fell through and yielded None, unlike the None-label case.
    return True

def _calc_diff(self, etalon_opins, test_opins, is_label_supported):
assert(callable(is_label_supported))

it = self.__iter_diff_core(etalon_opins=etalon_opins,
test_opins=test_opins)
Expand All @@ -91,6 +101,9 @@ def _calc_diff(self, etalon_opins, test_opins):
for args in it:
opin, etalon_label, result_label = args

self._check_is_supported(label=etalon_label, is_label_supported=is_label_supported)
self._check_is_supported(label=result_label, is_label_supported=is_label_supported)

row = [opin.SourceValue.encode('utf-8'),
opin.TargetValue.encode('utf-8'),
None if etalon_label is None else label_to_str(etalon_label),
Expand Down Expand Up @@ -119,7 +132,8 @@ def evaluate(self, cmp_pairs):
for cmp_pair in cmp_pairs:
assert(isinstance(cmp_pair, OpinionCollectionsToCompare))
cmp_table = self._calc_diff(etalon_opins=cmp_pair.EtalonOpinionCollection,
test_opins=cmp_pair.TestOpinionCollection)
test_opins=cmp_pair.TestOpinionCollection,
is_label_supported=result.is_label_supported)

result.reg_doc(cmp_pair=cmp_pair, cmp_table=cmp_table)

Expand Down
7 changes: 6 additions & 1 deletion common/evaluation/results/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

class BaseEvalResult(object):

def __init__(self):
def __init__(self, supported_labels):
assert(isinstance(supported_labels, set))
self._cmp_tables = {}
self._total_result = OrderedDict()
self.__supported_labels = supported_labels

# region properties

Expand All @@ -25,6 +27,9 @@ def calculate(self):

# endregion

def is_label_supported(self, label):
    """Tell whether `label` is one of the labels this result was created with.

    :param label: label to look up.
    :return: True if `label` is contained in the supported labels
        collection, False otherwise.
    """
    supported = self.__supported_labels
    return label in supported

def get_result_by_metric(self, metric_name):
assert(isinstance(metric_name, unicode))
return self._total_result[metric_name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, data_type):
super(ThreeClassEvaluator, self).__init__(eval_mode=EvaluationModes.Extraction)
self.__data_type = data_type

def _calc_diff(self, etalon_opins, test_opins):
def _calc_diff(self, etalon_opins, test_opins, is_label_supported):
assert(isinstance(etalon_opins, OpinionCollection))
assert(isinstance(test_opins, OpinionCollection))

Expand All @@ -30,11 +30,15 @@ def _calc_diff(self, etalon_opins, test_opins):
for opinion in etalon_opins:
# We keep only those opinions that were not
# presented in test and has neutral label

self._check_is_supported(label=opinion.Sentiment, is_label_supported=is_label_supported)

if not test_opins_expanded.has_synonymous_opinion(opinion) and opinion.Sentiment == neut_label:
test_opins_expanded.add_opinion(opinion)

return super(ThreeClassEvaluator, self)._calc_diff(etalon_opins=etalon_opins,
test_opins=test_opins_expanded)
test_opins=test_opins_expanded,
is_label_supported=is_label_supported)

def _create_eval_result(self):
return ThreeClassEvalResult()
14 changes: 9 additions & 5 deletions contrib/experiment_rusentrel/evaluation/results/three_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from arekit.common.evaluation.results.base import BaseEvalResult
from arekit.common.evaluation.results.utils import calc_f1_3c_macro, calc_f1_single_class
from arekit.common.labels.base import NeutralLabel, Label
from arekit.common.labels.base import Label
from arekit.common.opinions.collection import OpinionCollection
from arekit.contrib.experiment_rusentrel.evaluation.results import metrics
from arekit.contrib.experiment_rusentrel.evaluation.results.metrics import calc_precision_micro, calc_recall_micro
from arekit.contrib.experiment_rusentrel.labels.types import ExperimentPositiveLabel, ExperimentNegativeLabel
from arekit.contrib.experiment_rusentrel.labels.types import ExperimentPositiveLabel, ExperimentNegativeLabel, \
ExperimentNeutralLabel


class ThreeClassEvalResult(BaseEvalResult):
Expand All @@ -28,15 +29,18 @@ class ThreeClassEvalResult(BaseEvalResult):
C_F1_MICRO = u'f1_micro'

def __init__(self):
super(ThreeClassEvalResult, self).__init__()
self.__doc_results = OrderedDict()
self.__pos_label = ExperimentPositiveLabel()
self.__neg_label = ExperimentNegativeLabel()
self.__neu_label = self.create_neutral_label()

super(ThreeClassEvalResult, self).__init__(
supported_labels={self.__pos_label, self.__neg_label, self.__neu_label})

self.__doc_results = OrderedDict()

@staticmethod
def create_neutral_label():
return NeutralLabel()
return ExperimentNeutralLabel()

@staticmethod
def __has_opinions_with_label(opinions, label):
Expand Down
7 changes: 6 additions & 1 deletion contrib/experiment_rusentrel/evaluation/results/two_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@ class TwoClassEvalResult(BaseEvalResult):
C_F1_NEG = u'f1_neg'

def __init__(self):
super(TwoClassEvalResult, self).__init__()
self.__doc_results = OrderedDict()

self.__pos_label = ExperimentPositiveLabel()
self.__neg_label = ExperimentNegativeLabel()

super(TwoClassEvalResult, self).__init__(
supported_labels={self.__pos_label, self.__neg_label})

self.__using_labels = {self.__pos_label, self.__neg_label}

@staticmethod
def __has_opinions_with_label(opinions, label):
assert(isinstance(label, Label))
Expand Down
22 changes: 16 additions & 6 deletions contrib/experiment_rusentrel/tests/test_rusentrel_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@

from enum import Enum

from arekit.common.evaluation.evaluators.cmp_table import DocumentCompareTable
from arekit.common.evaluation.evaluators.modes import EvaluationModes
from arekit.common.evaluation.utils import OpinionCollectionsToCompareUtils
from arekit.common.opinions.collection import OpinionCollection
from arekit.common.synonyms import SynonymsCollection
from arekit.common.utils import progress_bar_iter
from arekit.contrib.experiment_rusentrel.evaluation.evaluators.two_class import TwoClassEvaluator
from arekit.contrib.experiment_rusentrel.evaluation.results.two_class import TwoClassEvalResult
from arekit.contrib.experiment_rusentrel.labels.formatters.rusentiframes import ExperimentRuSentiFramesLabelsFormatter
from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions
from arekit.contrib.source.rusentrel.labels_fmt import RuSentRelLabelsFormatter
from arekit.contrib.source.rusentrel.opinions.collection import RuSentRelOpinionCollection
from arekit.contrib.source.rusentrel.opinions.formatter import RuSentRelOpinionCollectionFormatter
from arekit.contrib.source.zip_utils import ZipArchiveUtils
Expand Down Expand Up @@ -74,12 +75,12 @@ def iter_doc_ids(result_version):
yield int(doc_id_str)

@staticmethod
def iter_doc_opinions(doc_id, result_version):
def iter_doc_opinions(doc_id, result_version, labels_formatter):
return ZippedResultsIOUtils.iter_from_zip(
inner_path=path.join(u"{}.opin.txt".format(doc_id)),
process_func=lambda input_file: RuSentRelOpinionCollectionFormatter._iter_opinions_from_file(
input_file=input_file,
labels_formatter=RuSentRelLabelsFormatter()),
labels_formatter=labels_formatter),
version=result_version)


Expand Down Expand Up @@ -114,16 +115,24 @@ def __test_core(self, res_version, synonyms=None,
else:
actual_synonyms = synonyms

# Setup an experiment labels formatter.
labels_formatter = ExperimentRuSentiFramesLabelsFormatter()

# Iter cmp opinions.
cmp_pairs_iter = OpinionCollectionsToCompareUtils.iter_comparable_collections(
doc_ids=ZippedResultsIOUtils.iter_doc_ids(res_version),
read_etalon_collection_func=lambda doc_id: OpinionCollection(
opinions=RuSentRelOpinionCollection.iter_opinions_from_doc(doc_id=doc_id),
opinions=RuSentRelOpinionCollection.iter_opinions_from_doc(
doc_id=doc_id,
labels_fmt=labels_formatter),
synonyms=actual_synonyms,
error_on_duplicates=False,
error_on_synonym_end_missed=True),
read_result_collection_func=lambda doc_id: OpinionCollection(
opinions=ZippedResultsIOUtils.iter_doc_opinions(doc_id=doc_id, result_version=res_version),
opinions=ZippedResultsIOUtils.iter_doc_opinions(
doc_id=doc_id,
result_version=res_version,
labels_formatter=labels_formatter),
synonyms=actual_synonyms,
error_on_duplicates=False,
error_on_synonym_end_missed=False))
Expand All @@ -150,7 +159,8 @@ def __test_core(self, res_version, synonyms=None,
if self.__display_cmp_table:
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
for doc_id, df_cmp_table in result.iter_dataframe_cmp_tables():
print u"{}:\t{}\n".format(doc_id, df_cmp_table)
assert(isinstance(df_cmp_table, DocumentCompareTable))
print u"{}:\t{}\n".format(doc_id, df_cmp_table.DataframeTable)
print "------------------------"

if check_results:
Expand Down

0 comments on commit 07ee6aa

Please sign in to comment.