Skip to content

Commit

Permalink
Fixed #136. Providing get_no_label_instance at BaseLabelScaler class.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jun 29, 2021
1 parent 5ccb8bd commit 952d822
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 17 deletions.
16 changes: 15 additions & 1 deletion common/labels/scaler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import OrderedDict

from arekit.common.labels.base import Label
from arekit.common.labels.base import Label, NoLabel


class BaseLabelScaler(object):
Expand All @@ -16,6 +16,7 @@ def __init__(self, uint_dict, int_dict):
self.__int_dict = int_dict

self.__ordered_labels = list(uint_dict.iterkeys())
self.__no_label_instance = self.__find_no_label_instance(uint_dict.iterkeys())

@property
def LabelsCount(self):
Expand All @@ -24,8 +25,21 @@ def LabelsCount(self):
def ordered_suppoted_labels(self):
return self.__ordered_labels

def get_no_label_instance(self):
if self.__no_label_instance is None:
raise Exception("NoLabel does no supported by this scaler")

return self.__no_label_instance

# region private methods

@staticmethod
def __find_no_label_instance(iter_labels):
for label in iter_labels:
if isinstance(label, NoLabel):
return label
return None

@staticmethod
def __ltoi(label, d):
assert(isinstance(label, Label))
Expand Down
4 changes: 2 additions & 2 deletions contrib/experiment_rusentrel/annot/two_scale.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging

from arekit.common.experiment.annot.base import BaseAnnotator
from arekit.common.labels.base import NoLabel
from arekit.common.news.parsed.base import ParsedNews
from arekit.common.opinions.base import Opinion
from arekit.common.opinions.collection import OpinionCollection
from arekit.common.experiment.data_type import DataType
from arekit.contrib.experiment_rusentrel.labels.types import ExperimentNeutralLabel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,7 +38,7 @@ def _annot_collection_core(self, parsed_news, data_type, doc_ops, opin_ops):
for opinion in opin_ops.read_etalon_opinion_collection(doc_id):
neut_collection.add_opinion(Opinion(source_value=opinion.SourceValue,
target_value=opinion.TargetValue,
sentiment=NoLabel()))
sentiment=ExperimentNeutralLabel()))

return neut_collection

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from arekit.common.labels.base import NoLabel
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.contrib.experiment_rusentrel.labels.types import ExperimentNegativeLabel, ExperimentPositiveLabel
from arekit.contrib.experiment_rusentrel.labels.types import ExperimentNegativeLabel, ExperimentPositiveLabel, \
ExperimentNeutralLabel


class RussianThreeScaleRussianLabelsFormatter(StringLabelsFormatter):
Expand All @@ -14,6 +14,6 @@ def __init__(self):

stol = {u'негативно': ExperimentNegativeLabel(),
u'позитивно': ExperimentPositiveLabel(),
u'нейтрально': NoLabel()}
u'нейтрально': ExperimentNeutralLabel()}

super(RussianThreeScaleRussianLabelsFormatter, self).__init__(stol=stol)
8 changes: 4 additions & 4 deletions contrib/experiment_rusentrel/labels/scalers/three.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from collections import OrderedDict

from arekit.common.labels.base import NoLabel
from arekit.common.labels.scaler import BaseLabelScaler
from arekit.contrib.experiment_rusentrel.labels.types import ExperimentNegativeLabel, ExperimentPositiveLabel
from arekit.contrib.experiment_rusentrel.labels.types import ExperimentNegativeLabel, ExperimentPositiveLabel, \
ExperimentNeutralLabel


class ThreeLabelScaler(BaseLabelScaler):

def __init__(self):

uint_labels = [(NoLabel(), 0),
uint_labels = [(ExperimentNeutralLabel(), 0),
(ExperimentPositiveLabel(), 1),
(ExperimentNegativeLabel(), 2)]

int_labels = [(NoLabel(), 0),
int_labels = [(ExperimentNeutralLabel(), 0),
(ExperimentPositiveLabel(), 1),
(ExperimentNegativeLabel(), -1)]

Expand Down
2 changes: 2 additions & 0 deletions contrib/experiment_rusentrel/labels/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@


class ExperimentNeutralLabel(NoLabel):
""" RuSentRel Experiment Neutral Label.
"""
pass


Expand Down
2 changes: 1 addition & 1 deletion contrib/networks/features/term_frame_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __extract_uint_frame_variant_sentiment_role(text_frame_variant, frames_colle
polarity = frames_collection.try_get_frame_sentiment_polarity(frame_id)

if polarity is None:
return three_label_scaler.label_to_uint(label=NoLabel())
return three_label_scaler.label_to_uint(label=three_label_scaler.get_no_label_instance())

assert(isinstance(polarity, FramePolarity))

Expand Down
8 changes: 6 additions & 2 deletions tests/contrib/bert/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from arekit.common.labels.scaler import BaseLabelScaler


class TestNeutralLabel(NoLabel):
pass


class TestPositiveLabel(Label):
pass

Expand All @@ -16,11 +20,11 @@ class TestThreeLabelScaler(BaseLabelScaler):

def __init__(self):

uint_labels = [(NoLabel(), 0),
uint_labels = [(TestNeutralLabel(), 0),
(TestPositiveLabel(), 1),
(TestNegativeLabel(), 2)]

int_labels = [(NoLabel(), 0),
int_labels = [(TestNeutralLabel(), 0),
(TestPositiveLabel(), 1),
(TestNegativeLabel(), -1)]

Expand Down
8 changes: 6 additions & 2 deletions tests/contrib/networks/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from arekit.common.labels.scaler import BaseLabelScaler


class TestNeutralLabel(NoLabel):
pass


class TestPositiveLabel(Label):
pass

Expand All @@ -16,11 +20,11 @@ class TestThreeLabelScaler(BaseLabelScaler):

def __init__(self):

uint_labels = [(NoLabel(), 0),
uint_labels = [(TestNeutralLabel(), 0),
(TestPositiveLabel(), 1),
(TestNegativeLabel(), 2)]

int_labels = [(NoLabel(), 0),
int_labels = [(TestNeutralLabel(), 0),
(TestPositiveLabel(), 1),
(TestNegativeLabel(), -1)]

Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/networks/test_tf_mi_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

sys.path.append('../../../')

from arekit.common.labels.base import NoLabel
from arekit.common.labels.scaler import BaseLabelScaler

from arekit.contrib.networks.core.feeding.bags.bag import Bag
Expand All @@ -16,6 +15,7 @@
from arekit.contrib.networks.sample import InputSample
from arekit.contrib.networks.multi.architectures.max_pooling import MaxPoolingOverSentences

from arekit.tests.contrib.networks.labels import TestNeutralLabel
from arekit.tests.contrib.networks.test_tf_ctx_feed import TestContextNetworkFeeding
from arekit.tests.contrib.networks.tf_networks.supported import get_supported

Expand All @@ -27,7 +27,7 @@ def __create_minibatch(config, labels_scaler):
assert(isinstance(config, DefaultNetworkConfig))
assert(isinstance(labels_scaler, BaseLabelScaler))
bags = []
label = NoLabel()
label = TestNeutralLabel()
empty_sample = InputSample.create_empty(terms_per_context=config.TermsPerContext,
frames_per_context=config.FramesPerContext,
synonyms_per_context=config.SynonymsPerContext)
Expand Down

0 comments on commit 952d822

Please sign in to comment.