Skip to content

Commit

Permalink
Fixed #119, #121
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed May 22, 2021
1 parent 2690344 commit 0230e21
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 59 deletions.
17 changes: 8 additions & 9 deletions common/experiment/data/serializing.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from arekit.common.experiment.data.base import DataIO
from arekit.common.experiment.neutral.annot.factory import create_annotator
from arekit.common.experiment.neutral.annot.base import BaseNeutralAnnotator
from arekit.common.labels.scaler import BaseLabelScaler


class SerializationData(DataIO):
""" Data, that is necessary for models training stage.
"""

def __init__(self, label_scaler, stemmer):
def __init__(self, label_scaler, neutral_annot, stemmer):
assert(isinstance(label_scaler, BaseLabelScaler))
assert(isinstance(neutral_annot, BaseNeutralAnnotator))
super(SerializationData, self).__init__(stemmer=stemmer)

self.__label_scaler = label_scaler
self.__neutral_annot = create_annotator(
labels_count=self.LabelsCount,
dist_in_terms_between_opin_ends=self.DistanceInTermsBetweenOpinionEndsBound)

if self.LabelsCount != neutral_annot.LabelsCount:
raise Exception(u"Label scaler and neutral annotation are incompatible due to differs in labels count!")

self.__neutral_annot = neutral_annot

@property
def LabelsScaler(self):
Expand All @@ -35,10 +38,6 @@ def NeutralAnnotator(self):
"""
return self.__neutral_annot

@property
def DistanceInTermsBetweenOpinionEndsBound(self):
raise NotImplementedError()

@property
def StringEntityFormatter(self):
raise NotImplementedError()
Expand Down
14 changes: 6 additions & 8 deletions common/experiment/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.input.formatters.opinion import BaseOpinionsFormatter
from arekit.common.experiment.input.formatters.sample import BaseSampleFormatter
from arekit.common.experiment.neutral.annot.factory import get_annotator_type
from arekit.common.utils import join_dir_with_subfolder_name


Expand Down Expand Up @@ -51,6 +50,12 @@ def _get_filepath(out_dir, template, prefix):
assert(isinstance(prefix, unicode))
return join(out_dir, BaseIOUtils.__generate_tsv_archive_filename(template=template, prefix=prefix))

def _get_neutral_annot_name(self):
""" We use custom implementation as it allows to
be independent of NeutralAnnotator instance.
"""
return u"neut_annot_{labels_count}l".format(labels_count=self._experiment.DataIO.LabelsCount)

# endregion

# region public methods
Expand Down Expand Up @@ -84,13 +89,6 @@ def create_neutral_opinion_collection_filepath(self, doc_id, data_type):
def create_result_opinion_collection_filepath(self, data_type, doc_id, epoch_index):
raise NotImplementedError()

def _get_neutral_annot_name(self):
""" We use custom implementation as it allows to
be independent from NeutralAnnotator instance.
"""
annot_type = get_annotator_type(labels_count=self._experiment.DataIO.LabelsCount)
return annot_type.name

# endregion

# region private methods
Expand Down
6 changes: 3 additions & 3 deletions common/experiment/neutral/annot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ def __init__(self):
self.__opin_ops = None
self.__doc_ops = None

# region Properties

@property
def Name(self):
def LabelsCount(self):
raise NotImplementedError()

# region Properties

@property
def _OpinOps(self):
assert(isinstance(self.__opin_ops, OpinionOperations))
Expand Down
23 changes: 0 additions & 23 deletions common/experiment/neutral/annot/factory.py

This file was deleted.

Empty file.
12 changes: 12 additions & 0 deletions contrib/experiment_rusentrel/annot/algo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from arekit.common.experiment.neutral.algo.default import DefaultNeutralAnnotationAlgorithm


class RuSentRelDefaultNeutralAnnotationAlgorithm(DefaultNeutralAnnotationAlgorithm):
    """ Default neutral annotation algorithm, preconfigured for RuSentRel:
        the sentence distance and the list of ignored entity values are fixed,
        so only the distance bound in terms remains adjustable.
    """

    # Entity values passed to the base algorithm as `ignored_entity_values`.
    IGNORED_ENTITY_VALUES = [u"author", u"unknown"]

    def __init__(self, dist_in_terms_bound):
        """ dist_in_terms_bound:
                forwarded as is to the base algorithm.
        """
        base = super(RuSentRelDefaultNeutralAnnotationAlgorithm, self)
        # NOTE(review): dist_in_sents=0 presumably limits pairs to a single
        # sentence -- confirm against DefaultNeutralAnnotationAlgorithm.
        base.__init__(dist_in_sents=0,
                      dist_in_terms_bound=dist_in_terms_bound,
                      ignored_entity_values=self.IGNORED_ENTITY_VALUES)
17 changes: 17 additions & 0 deletions contrib/experiment_rusentrel/annot/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from arekit.contrib.experiment_rusentrel.annot.three_scale import ThreeScaleNeutralAnnotator
from arekit.contrib.experiment_rusentrel.annot.two_scale import TwoScaleNeutralAnnotator


class ExperimentNeutralAnnotatorFactory:
    """ Creates a neutral annotator that corresponds to the
        amount of labels utilized in experiment.
    """

    @staticmethod
    def create(labels_count, create_algo):
        """ Instantiate the annotator for the given labels count.

            labels_count: int
                2 results in TwoScaleNeutralAnnotator,
                3 results in ThreeScaleNeutralAnnotator.
            create_algo: callable
                parameterless factory of the annotation algorithm; it is
                invoked only for the three-scale annotator, so no algorithm
                is constructed when it is not needed.

            returns: TwoScaleNeutralAnnotator or ThreeScaleNeutralAnnotator

            raises: NotImplementedError
                when there is no annotator for the given `labels_count`.
        """
        assert(isinstance(labels_count, int))
        assert(callable(create_algo))

        if labels_count == 2:
            return TwoScaleNeutralAnnotator()

        if labels_count == 3:
            return ThreeScaleNeutralAnnotator(create_algo())

        # Name the unsupported value: a bare NotImplementedError gives
        # no hint about which experiment setting caused the failure.
        raise NotImplementedError(
            u"Neutral annotator is not implemented for labels_count={}".format(labels_count))
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from arekit.common.experiment.neutral.algo.default import DefaultNeutralAnnotationAlgorithm
from arekit.common.experiment.neutral.algo.base import BaseNeutralAnnotationAlgorithm
from arekit.common.experiment.neutral.annot.base import BaseNeutralAnnotator
from arekit.common.experiment.data_type import DataType
from arekit.common.news.parsed.base import ParsedNews
Expand All @@ -14,18 +14,14 @@ class ThreeScaleNeutralAnnotator(BaseNeutralAnnotator):
""" For three scale classification task.
"""

name = u"annot-3-scale"
IGNORED_ENTITY_VALUES = [u"author", u"unknown"]

def __init__(self, distance_in_terms_between_bounds):
def __init__(self, algo):
super(ThreeScaleNeutralAnnotator, self).__init__()
self.__algo = DefaultNeutralAnnotationAlgorithm(
dist_in_terms_bound=distance_in_terms_between_bounds,
ignored_entity_values=self.IGNORED_ENTITY_VALUES)
assert(isinstance(algo, BaseNeutralAnnotationAlgorithm))
self.__algo = algo

@property
def Name(self):
return ThreeScaleNeutralAnnotator.name
def LabelsCount(self):
return 3

# region private methods

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@ class TwoScaleNeutralAnnotator(BaseNeutralAnnotator):
""" For two scale classification task.
"""

name = u"annot-2-scale"

def __init__(self):
super(TwoScaleNeutralAnnotator, self).__init__()

@property
def Name(self):
return TwoScaleNeutralAnnotator.name
def LabelsCount(self):
return 2

# region static methods

Expand Down
7 changes: 5 additions & 2 deletions contrib/networks/core/data/serializing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

class NetworkSerializationData(SerializationData):

def __init__(self, labels_scaler, stemmer):
super(NetworkSerializationData, self).__init__(label_scaler=labels_scaler, stemmer=stemmer)
def __init__(self, labels_scaler, neutral_annot, stemmer):
super(NetworkSerializationData, self).__init__(
label_scaler=labels_scaler,
neutral_annot=neutral_annot,
stemmer=stemmer)
self.__label_provider = MultipleLabelProvider(labels_scaler)

@property
Expand Down

0 comments on commit 0230e21

Please sign in to comment.