Skip to content

Commit

Permalink
#290 related; #291 done
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Feb 24, 2022
1 parent c2257cb commit fefa560
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 26 deletions.
10 changes: 10 additions & 0 deletions arekit/common/labels/scaler/single.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from collections import OrderedDict

from arekit.common.labels.scaler.base import BaseLabelScaler


class SingleLabelScaler(BaseLabelScaler):
    """Label scaler for the degenerate single-label case.

    Associates exactly one label with a fixed unsigned integer value
    (0 by default), and registers the same one-entry table for both the
    uint and int conversion directions of the base scaler.
    """

    def __init__(self, label, uint_label=0):
        # A single shared one-entry mapping serves both conversion tables,
        # mirroring how the base class is fed identical uint/int dicts.
        mapping = OrderedDict()
        mapping[label] = uint_label
        super(SingleLabelScaler, self).__init__(uint_dict=mapping, int_dict=mapping)
5 changes: 4 additions & 1 deletion arekit/contrib/networks/core/predict/base_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
class BasePredictWriter(object):

def __init__(self, target):
def __init__(self):
self._target = None

def set_target(self, target):
self._target = target

def write(self, title, contents_it):
Expand Down
5 changes: 2 additions & 3 deletions arekit/contrib/networks/core/predict/tsv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@

class TsvPredictWriter(BasePredictWriter):

def __init__(self, filepath):
assert(isinstance(filepath, str))
super(TsvPredictWriter, self).__init__(target=filepath)
def __init__(self):
super(TsvPredictWriter, self).__init__()
self.__col_separator = '\t'
self.__f = None

Expand Down
5 changes: 3 additions & 2 deletions examples/network/args/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from arekit.processing.lemmatization.mystem import MystemWrapper
from examples.network.args.base import BaseArg
from examples.text.pipeline_entities_bert_ontonotes import BertOntonotesNERPipelineItem
from examples.text.pipeline_entities_default import TextEntitiesParser


class InputTextArg(BaseArg):
Expand Down Expand Up @@ -262,8 +263,8 @@ class EntitiesParserArg(BaseArg):
@staticmethod
def read_argument(args):
arg = args.entities_parser
if arg == "no":
return None
if arg == "default":
return TextEntitiesParser()
elif arg == "bert-ontonotes":
return BertOntonotesNERPipelineItem()

Expand Down
21 changes: 12 additions & 9 deletions examples/pipelines/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from arekit.contrib.networks.core.pipeline.item_keep_hidden import MinibatchHiddenFetcherPipelineItem
from arekit.contrib.networks.core.pipeline.item_predict import EpochLabelsPredictorPipelineItem
from arekit.contrib.networks.core.pipeline.item_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.core.predict.tsv_writer import TsvPredictWriter
from arekit.contrib.networks.core.predict.base_writer import BasePredictWriter
from arekit.contrib.networks.factory import create_network_and_network_config_funcs
from arekit.contrib.networks.shapes import NetworkInputShapes
from arekit.processing.languages.ru.pos_service import PartOfSpeechTypesService
Expand All @@ -25,9 +25,10 @@

class TensorflowNetworkInferencePipelineItem(BasePipelineItem):

def __init__(self, model_name, bags_collection_type, model_input_type,
def __init__(self, model_name, bags_collection_type, model_input_type, predict_writer,
data_type, bags_per_minibatch, nn_io, labels_scaler, callbacks):
assert(isinstance(callbacks, list))
assert(isinstance(predict_writer, BasePredictWriter))
assert(isinstance(data_type, DataType))

# Create network and configuration.
Expand All @@ -41,6 +42,7 @@ def __init__(self, model_name, bags_collection_type, model_input_type,
self.__config.modify_bag_size(BAG_SIZE)
self.__config.modify_bags_per_minibatch(bags_per_minibatch)
self.__config.set_class_weights([1, 1, 1])
self.__config.set_pos_count(PartOfSpeechTypesService.get_mystem_pos_count())

# initialize model context.
self.__create_model_ctx = lambda inference_ctx: TensorflowModelContext(
Expand All @@ -50,8 +52,11 @@ def __init__(self, model_name, bags_collection_type, model_input_type,
inference_ctx=inference_ctx,
bags_collection_type=bags_collection_type)

self.__callbacks = callbacks
self.__labels_scaler = labels_scaler
self.__callbacks = callbacks + [
PredictResultWriterCallback(labels_scaler=labels_scaler, writer=predict_writer)
]

self.__writer = predict_writer
self.__bags_collection_type = bags_collection_type
self.__data_type = data_type

Expand All @@ -76,7 +81,6 @@ def apply_core(self, input_data, pipeline_ctx):

# Setup config parameters.
self.__config.set_term_embedding(embedding)
self.__config.set_pos_count(PartOfSpeechTypesService.get_mystem_pos_count())

inference_ctx = InferenceContext.create_empty()
inference_ctx.initialize(
Expand All @@ -95,18 +99,17 @@ def apply_core(self, input_data, pipeline_ctx):
]),
bag_size=self.__config.BagSize)

predict_callback = PredictResultWriterCallback(labels_scaler=self.__labels_scaler,
writer=TsvPredictWriter(tgt))

# Model preparation.
model = BaseTensorflowModel(
context=self.__create_model_ctx(inference_ctx),
callbacks=self.__callbacks + [predict_callback],
callbacks=self.__callbacks,
predict_pipeline=[
EpochLabelsPredictorPipelineItem(),
EpochLabelsCollectorPipelineItem(),
MinibatchHiddenFetcherPipelineItem()
],
fit_pipeline=[MinibatchFittingPipelineItem()])

self.__writer.set_target(tgt)

model.predict(do_compile=True)
14 changes: 4 additions & 10 deletions examples/pipelines/serialize.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from collections import OrderedDict

from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.folding.base import BaseDataFolding
from arekit.common.labels.base import NoLabel
from arekit.common.labels.scaler.base import BaseLabelScaler
from arekit.common.labels.scaler.single import SingleLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.news.base import News
from arekit.common.news.entities_grouping import EntitiesGroupingPipelineItem
Expand All @@ -27,14 +25,13 @@
from examples.network.infer.exp import CustomExperiment
from examples.network.infer.exp_io import InferIOUtils
from examples.network.serialization_data import CustomSerializationContext
from examples.text.pipeline_entities_default import TextEntitiesParser


class TextSerializationPipelineItem(BasePipelineItem):

def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot,
embedding_path, entity_fmt, stemmer, data_folding):
assert(isinstance(entities_parser, BasePipelineItem) or entities_parser is None)
assert(isinstance(entities_parser, BasePipelineItem))
assert(isinstance(entity_fmt, StringEntitiesFormatter))
assert(isinstance(synonyms, SynonymsCollection))
assert(isinstance(terms_per_context, int))
Expand All @@ -51,17 +48,14 @@ def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot,
pos_tagger = POSMystemWrapper(MystemWrapper().MystemInstance)

# Label provider setup.
labels_scaler = BaseLabelScaler(uint_dict=OrderedDict([(NoLabel(), 0)]),
int_dict=OrderedDict([(NoLabel(), 0)]))

self.__labels_fmt = StringLabelsFormatter(stol={"neu": NoLabel})

# Initialize text parser with the related dependencies.
frames_collection = create_frames_collection()
frame_variants_collection = create_and_fill_variant_collection(frames_collection)
self.__text_parser = BaseTextParser(pipeline=[
TermsSplitterParser(),
TextEntitiesParser() if entities_parser is None else entities_parser,
entities_parser,
EntitiesGroupingPipelineItem(self.__synonyms.get_synonym_group_index),
DefaultTextTokenizer(keep_tokens=True),
FrameVariantsParser(frame_variants=frame_variants_collection),
Expand All @@ -72,7 +66,7 @@ def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot,

# initialize experiment related data.
self.__exp_ctx = CustomSerializationContext(
labels_scaler=labels_scaler,
labels_scaler=SingleLabelScaler(NoLabel()),
stemmer=stemmer,
embedding=embedding,
annotator=opin_annot,
Expand Down
4 changes: 3 additions & 1 deletion examples/run_text_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from arekit.contrib.experiment_rusentrel.labels.types import ExperimentPositiveLabel, ExperimentNegativeLabel
from arekit.contrib.networks.core.callback.stat import TrainingStatProviderCallback
from arekit.contrib.networks.core.callback.train_limiter import TrainingLimiterCallback
from arekit.contrib.networks.core.predict.tsv_writer import TsvPredictWriter
from arekit.contrib.networks.enum_input_types import ModelInputType
from arekit.contrib.networks.enum_name_types import ModelNames
from examples.input import EXAMPLES
Expand Down Expand Up @@ -102,9 +103,10 @@
bags_collection_type=create_bags_collection_type(model_input_type=model_input_type),
model_input_type=model_input_type,
labels_scaler=labels_scaler,
predict_writer=TsvPredictWriter(),
callbacks=[
TrainingLimiterCallback(train_acc_limit=0.99),
TrainingStatProviderCallback()
TrainingStatProviderCallback(),
]),
BratBackendPipelineItem(label_to_rel={
str(labels_scaler.label_to_uint(ExperimentPositiveLabel())): "POS",
Expand Down

0 comments on commit fefa560

Please sign in to comment.