Skip to content

Commit

Permalink
#257 done.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jan 17, 2022
1 parent b29307d commit 61e35fc
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 47 deletions.
8 changes: 0 additions & 8 deletions arekit/common/experiment/labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ def __init__(self, uint_labeled_ids):
self.__original_uint_labels = collections.OrderedDict(uint_labeled_ids)
self.__assigned_uint_labels = {}

# TODO. #257. Remove.
def is_empty(self):
return len(self.__assigned_uint_labels) == 0

def assign_uint_label(self, uint_label, sample_row_id):
"""
Optionally applies the label.
Expand All @@ -32,10 +28,6 @@ def assign_uint_label(self, uint_label, sample_row_id):
if sample_row_id not in self.__assigned_uint_labels:
self.__assigned_uint_labels[sample_row_id] = uint_label

# TODO. #257. Remove.
def reset_labels(self):
self.__assigned_uint_labels.clear()

def iter_non_duplicated_labeled_sample_row_ids(self):
for sample_id, _ in self.__original_uint_labels.items():
yield sample_id, self.__assigned_uint_labels[sample_id]
15 changes: 7 additions & 8 deletions arekit/contrib/networks/core/ctx_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

class InferenceContext(object):

def __init__(self, labeled_samples_dict, bags_collections_dict):
assert(isinstance(labeled_samples_dict, dict))
def __init__(self, sample_label_pairs_dict, bags_collections_dict):
assert(isinstance(sample_label_pairs_dict, dict))
assert(isinstance(bags_collections_dict, dict))
self.__labeled_samples_dict = labeled_samples_dict
self.__sample_label_pairs_dict = sample_label_pairs_dict
self.__bags_collections_dict = bags_collections_dict
self.__train_stat_uint_labeled_sample_row_ids = None

Expand All @@ -29,8 +29,8 @@ def BagsCollections(self):
return self.__bags_collections_dict

@property
def LabeledSamplesCollections(self):
return self.__labeled_samples_dict
def SampleIdAndLabelPairs(self):
return self.__sample_label_pairs_dict

@property
def HasNormalizedWeights(self):
Expand All @@ -40,7 +40,7 @@ def HasNormalizedWeights(self):

@classmethod
def create_empty(cls):
return cls(labeled_samples_dict={}, bags_collections_dict={})
return cls(sample_label_pairs_dict={}, bags_collections_dict={})

def initialize(self, dtypes, create_samples_view_func, has_model_predefined_state,
vocab, labels_count, bags_collection_type, bag_size, input_shapes):
Expand Down Expand Up @@ -71,8 +71,7 @@ def initialize(self, dtypes, create_samples_view_func, has_model_predefined_stat

# Saving into dictionaries.
self.__bags_collections_dict[data_type] = bags_collection
self.__labeled_samples_dict[data_type] = LabeledCollection(
uint_labeled_ids=uint_labeled_sample_row_ids)
self.__sample_label_pairs_dict[data_type] = list(uint_labeled_sample_row_ids)

if data_type == DataType.Train:
self.__train_stat_uint_labeled_sample_row_ids = uint_labeled_sample_row_ids
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/networks/core/model_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __set_optimiser_value(self, value):
def get_bags_collection(self, data_type):
return self.__inference_ctx.BagsCollections[data_type]

def get_labeled_samples_collection(self, data_type):
return self.__inference_ctx.LabeledSamplesCollections[data_type]
def get_sample_id_label_pairs(self, data_type):
return self.__inference_ctx.SampleIdAndLabelPairs[data_type]

def get_hidden_parameters(self):
names = []
Expand Down
29 changes: 1 addition & 28 deletions arekit/contrib/networks/core/pipeline_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,6 @@


class EpochLabelsPredictorPipelineItem(EpochHandlingPipelineItem):
""" Considering to treat feed dictionary.
"""

def __init__(self):
super(EpochLabelsPredictorPipelineItem, self).__init__()
self.__labeled_samples = None

@property
def LabeledSamples(self):
return self.__labeled_samples

def before_epoch(self, model_context, data_type):
super(EpochLabelsPredictorPipelineItem, self).before_epoch(model_context=model_context,
data_type=data_type)

# Select the appropriate labels collection.
self.__labeled_samples = self._context.InferenceContext.LabeledSamplesCollections[data_type]

# Clear and assert the correctness.
# TODO. #257. Remove. Create new instance instead.
self.__labeled_samples.reset_labels()
# TODO. #257. Remove. Create new instance instead.
assert(self.__labeled_samples.is_empty())

def apply(self, pipeline_ctx):
assert(isinstance(pipeline_ctx, PipelineContext))
Expand All @@ -34,8 +11,4 @@ def apply(self, pipeline_ctx):
feed_dict = self._context.create_feed_dict(minibatch=minibatch, data_type=self._data_type)
uint_labels = self._context.Session.run(self._context.Network.Labels, feed_dict=feed_dict)

# Apply labeling.
for bag_index, bag in enumerate(minibatch.iter_by_bags()):
uint_label = int(uint_labels[bag_index])
for sample in bag:
self.__labeled_samples.assign_uint_label(uint_label, sample.ID)
pipeline_ctx.update("uint_labels", uint_labels)
33 changes: 33 additions & 0 deletions arekit/contrib/networks/core/pipeline_predict_labeling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from arekit.common.experiment.labeling import LabeledCollection
from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.networks.core.pipeline_epoch import EpochHandlingPipelineItem


class EpochLabelsCollectorPipelineItem(EpochHandlingPipelineItem):
    """ Pipeline item that gathers the predicted uint labels of an epoch.

        Reads the "uint_labels" entry of the pipeline context and stores
        every label in a LabeledCollection which is created anew in
        `before_epoch`.
    """

    def __init__(self):
        super(EpochLabelsCollectorPipelineItem, self).__init__()
        # No collection exists until `before_epoch` is called.
        self.__labeled_samples = None

    @property
    def LabeledSamples(self):
        """ Collection created by the latest `before_epoch` call
            (None when `before_epoch` has not been called yet).
        """
        return self.__labeled_samples

    def before_epoch(self, model_context, data_type):
        """ Creates a new LabeledCollection for the given data type.

            model_context: forwarded to the base class implementation.
            data_type: key used to select the (sample id, label) pairs
                stored in the inference context.
        """
        super(EpochLabelsCollectorPipelineItem, self).before_epoch(
            model_context=model_context,
            data_type=data_type)

        # A fresh collection is built on every call, so labels assigned
        # during a previous epoch are not carried over.
        inference_ctx = self._context.InferenceContext
        self.__labeled_samples = LabeledCollection(
            uint_labeled_ids=inference_ctx.SampleIdAndLabelPairs[data_type])

    def apply(self, pipeline_ctx):
        """ Assigns the label predicted for a bag to each of its samples.

            pipeline_ctx: PipelineContext which provides the minibatch
                under "src" and the predicted labels under "uint_labels".
        """
        assert(isinstance(pipeline_ctx, PipelineContext))
        assert("uint_labels" in pipeline_ctx)

        batch = pipeline_ctx.provide("src")
        predicted = pipeline_ctx.provide("uint_labels")

        # Labels are addressed by the position of the bag in the minibatch.
        for index, samples_bag in enumerate(batch.iter_by_bags()):
            label = int(predicted[index])
            for bag_sample in samples_bag:
                self.__labeled_samples.assign_uint_label(label, bag_sample.ID)
2 changes: 2 additions & 0 deletions arekit/contrib/networks/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from arekit.contrib.networks.core.pipeline_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.pipeline_keep_hidden import MinibatchHiddenFetcherPipelineItem
from arekit.contrib.networks.core.pipeline_predict import EpochLabelsPredictorPipelineItem
from arekit.contrib.networks.core.pipeline_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.shapes import NetworkInputShapes
from arekit.contrib.networks.utils import rm_dir_contents

Expand Down Expand Up @@ -103,6 +104,7 @@ def _handle_iteration(self, it_index):
callback=callback,
predict_pipeline=[
EpochLabelsPredictorPipelineItem(),
EpochLabelsCollectorPipelineItem(),
MinibatchHiddenFetcherPipelineItem()
],
fit_pipeline=[MinibatchFittingPipelineItem()])
Expand Down
4 changes: 3 additions & 1 deletion examples/run_text_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from arekit.contrib.networks.core.pipeline_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.pipeline_keep_hidden import MinibatchHiddenFetcherPipelineItem
from arekit.contrib.networks.core.pipeline_predict import EpochLabelsPredictorPipelineItem
from arekit.contrib.networks.core.pipeline_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.core.predict.provider import BasePredictProvider
from arekit.contrib.networks.core.predict.tsv_writer import TsvPredictWriter
from arekit.contrib.networks.factory import create_network_and_network_config_funcs
Expand Down Expand Up @@ -153,14 +154,15 @@
callback=NetworkCallback(),
predict_pipeline=[
EpochLabelsPredictorPipelineItem(),
EpochLabelsCollectorPipelineItem(),
MinibatchHiddenFetcherPipelineItem()
],
fit_pipeline=[MinibatchFittingPipelineItem()])

model.predict(do_compile=True)

# Gather annotated contexts onto document level.
item = model.from_predicted(EpochLabelsPredictorPipelineItem)
item = model.from_predicted(EpochLabelsCollectorPipelineItem)
labeled_samples = item.LabeledSamples

predict_provider = BasePredictProvider()
Expand Down

0 comments on commit 61e35fc

Please sign in to comment.