Skip to content

Commit

Permalink
#261 related. Removing model_ctx passing into NetworkCallback. #260 r…
Browse files Browse the repository at this point in the history
…elated.
  • Loading branch information
nicolay-r committed Jan 20, 2022
1 parent 4fff52c commit f9ed7d1
Show file tree
Hide file tree
Showing 18 changed files with 38 additions and 65 deletions.
4 changes: 2 additions & 2 deletions arekit/contrib/experiment_rusentrel/model_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.item_handle import HandleIterPipelineItem

from arekit.contrib.networks.core.ctx_predict_log import NetworkInputDependentVariables
from arekit.contrib.networks.core.idhv_collection import NetworkInputDependentVariables
from arekit.contrib.networks.core.model import BaseTensorflowModel
from arekit.contrib.networks.core.pipeline_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.core.pipeline.item_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.core.predict.provider import BasePredictProvider
from arekit.contrib.networks.core.predict.tsv_writer import TsvPredictWriter
from arekit.contrib.source.rusentrel.labels_fmt import RuSentRelLabelsFormatter
Expand Down
11 changes: 0 additions & 11 deletions arekit/contrib/networks/core/callback/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
from arekit.contrib.networks.core.model_ctx import TensorflowModelContext


class NetworkCallback(object):

def __init__(self):
super(NetworkCallback, self).__init__()
self._model_ctx = None

def on_initialized(self, model_ctx):
assert(isinstance(model_ctx, TensorflowModelContext))
self._model_ctx = model_ctx

def on_fit_started(self, operation_cancel):
# Do nothing by default.
pass
Expand Down
8 changes: 6 additions & 2 deletions arekit/contrib/networks/core/callback/hidden.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ def on_epoch_finished(self, pipeline, operation_cancel):

self.__epochs_passed += 1

# TODO. This might be taken from pipeline item.
names, values = self._model_ctx.get_hidden_parameters()
if len(pipeline) == 0:
return

model_ctx = pipeline[0].ModelContext
names, tensors = map(list, zip(*model_ctx.Network.iter_hidden_parameters()))
values = model_ctx.Session.run(tensors)

assert(isinstance(names, list))
assert(isinstance(values, list))
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/networks/core/callback/hidden_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from arekit.common.utils import create_dir_if_not_exists
from arekit.contrib.networks.core.callback.base import NetworkCallback
from arekit.contrib.networks.core.ctx_predict_log import NetworkInputDependentVariables
from arekit.contrib.networks.core.pipeline_keep_hidden import MinibatchHiddenFetcherPipelineItem
from arekit.contrib.networks.core.idhv_collection import NetworkInputDependentVariables
from arekit.contrib.networks.core.pipeline.item_keep_hidden import MinibatchHiddenFetcherPipelineItem
from arekit.contrib.networks.core.utils import get_item_from_pipeline


Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/networks/core/callback/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from arekit.contrib.networks.core.callback.base import NetworkCallback
from arekit.contrib.networks.core.cancellation import OperationCancellation
from arekit.contrib.networks.core.pipeline_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.pipeline.item_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.utils import get_item_from_pipeline

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/networks/core/callback/train_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from arekit.contrib.networks.core.callback.base import NetworkCallback
from arekit.contrib.networks.core.cancellation import OperationCancellation
from arekit.contrib.networks.core.pipeline_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.pipeline.item_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.utils import get_item_from_pipeline

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from collections import OrderedDict

import numpy as np
from collections import OrderedDict


class NetworkInputDependentVariables:
Expand Down Expand Up @@ -77,4 +77,4 @@ def iter_by_parameter_values(self, param_name):
def iter_var_names(self):
return iter(self.__by_param_names.keys())

# endregion
# endregion
3 changes: 1 addition & 2 deletions arekit/contrib/networks/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from arekit.contrib.networks.core.feeding.bags.collection.factory import create_batch_by_bags_group
from arekit.contrib.networks.core.model_ctx import TensorflowModelContext
from arekit.contrib.networks.core.params import NeuralNetworkModelParams
from arekit.contrib.networks.core.pipeline_epoch import EpochHandlingPipelineItem
from arekit.contrib.networks.core.pipeline.item_base import EpochHandlingPipelineItem
from arekit.contrib.networks.core.utils import get_item_from_pipeline
from arekit.contrib.networks.tf_helpers.nn_states import TensorflowNetworkStatesProvider

Expand Down Expand Up @@ -132,7 +132,6 @@ def fit(self, model_params, seed):
assert(isinstance(model_params, NeuralNetworkModelParams))
self.__context.Network.compile(self.__context.Config, reset_graph=True, graph_seed=seed)
self.__context.set_optimiser()
self.__callback_do(lambda callback: callback.on_initialized(self.__context)),
self.__context.initialize_session()
self.__try_load_state()
self.__fit(epochs_count=model_params.EpochsCount)
Expand Down
15 changes: 0 additions & 15 deletions arekit/contrib/networks/core/model_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ def Optimiser(self):
def BagsCollectionType(self):
return self.__bags_collection_type

@property
def InferenceContext(self):
return self.__inference_ctx

# endregion

def __set_optimiser_value(self, value):
Expand All @@ -59,17 +55,6 @@ def get_bags_collection(self, data_type):
def get_sample_id_label_pairs(self, data_type):
return self.__inference_ctx.SampleIdAndLabelPairs[data_type]

def get_hidden_parameters(self):
names = []
tensors = []

for name, tensor in self.__network.iter_hidden_parameters():
names.append(name)
tensors.append(tensor)

result_list = self.__sess.run(tensors)
return names, result_list

def set_optimiser(self):
optimiser = self.Config.Optimiser.minimize(self.__network.Cost)
self.__set_optimiser_value(optimiser)
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ def __init__(self):
def DataType(self):
return self._data_type

@property
def ModelContext(self):
return self._context

def before_epoch(self, model_context, data_type):
assert(isinstance(model_context, TensorflowModelContext))
self._context = model_context
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.networks.core.pipeline_epoch import EpochHandlingPipelineItem
from arekit.contrib.networks.core.pipeline.item_base import EpochHandlingPipelineItem


class MinibatchFittingPipelineItem(EpochHandlingPipelineItem):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.networks.core.ctx_predict_log import NetworkInputDependentVariables
from arekit.contrib.networks.core.pipeline_epoch import EpochHandlingPipelineItem
from arekit.contrib.networks.core.idhv_collection import NetworkInputDependentVariables
from arekit.contrib.networks.core.pipeline.item_base import EpochHandlingPipelineItem


class MinibatchHiddenFetcherPipelineItem(EpochHandlingPipelineItem):
Expand All @@ -19,22 +19,15 @@ def before_epoch(self, model_context, data_type):
super(MinibatchHiddenFetcherPipelineItem, self).before_epoch(
model_context=model_context, data_type=data_type)

self.__idh_names = []
self.__idh_tensors = []

for name, tensor in self._context.Network.iter_input_dependent_hidden_parameters():
self.__idh_names.append(name)
self.__idh_tensors.append(tensor)
self.__idh_names, self.__idh_tensors = map(
list, zip(*self._context.Network.iter_input_dependent_hidden_parameters()))

self.__input_dependent_params = NetworkInputDependentVariables()

def apply(self, pipeline_ctx):
assert(isinstance(pipeline_ctx, PipelineContext))
minibatch = pipeline_ctx.provide("src")

feed_dict = self._context.create_feed_dict(minibatch=minibatch,
data_type=self._data_type)

feed_dict = self._context.create_feed_dict(minibatch=minibatch, data_type=self._data_type)
idh_values = self._context.Session.run(self.__idh_tensors, feed_dict=feed_dict)

if not (len(self.__idh_names) > 0 and len(idh_values) > 0):
Expand All @@ -43,7 +36,6 @@ def apply(self, pipeline_ctx):
self.__input_dependent_params.add_input_dependent_values(
names_list=self.__idh_names,
tensor_values_list=idh_values,
text_opinion_ids=[sample.ID for sample in
minibatch.iter_by_samples()],
text_opinion_ids=[sample.ID for sample in minibatch.iter_by_samples()],
bags_per_minibatch=self._context.Config.BagsPerMinibatch,
bag_size=self._context.Config.BagSize)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.networks.core.pipeline_epoch import EpochHandlingPipelineItem
from arekit.contrib.networks.core.pipeline.item_base import EpochHandlingPipelineItem


class EpochLabelsPredictorPipelineItem(EpochHandlingPipelineItem):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from arekit.common.experiment.labeling import LabeledCollection
from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.networks.core.pipeline_epoch import EpochHandlingPipelineItem
from arekit.contrib.networks.core.pipeline.item_base import EpochHandlingPipelineItem


class EpochLabelsCollectorPipelineItem(EpochHandlingPipelineItem):
Expand All @@ -16,7 +16,7 @@ def LabeledSamples(self):
def before_epoch(self, model_context, data_type):
super(EpochLabelsCollectorPipelineItem, self).before_epoch(model_context=model_context,
data_type=data_type)
pairs = self._context.InferenceContext.SampleIdAndLabelPairs[data_type]
pairs = self._context.get_sample_id_label_pairs(data_type)
self.__labeled_samples = LabeledCollection(uint_labeled_ids=pairs)

def apply(self, pipeline_ctx):
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/networks/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from arekit.contrib.networks.core.pipeline_epoch import EpochHandlingPipelineItem
from arekit.contrib.networks.core.pipeline.item_base import EpochHandlingPipelineItem


def get_item_from_pipeline(pipeline, item_type):
Expand Down
8 changes: 4 additions & 4 deletions arekit/contrib/networks/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from arekit.contrib.networks.core.model import BaseTensorflowModel
from arekit.contrib.networks.core.model_ctx import TensorflowModelContext
from arekit.contrib.networks.core.params import NeuralNetworkModelParams
from arekit.contrib.networks.core.pipeline_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.pipeline_keep_hidden import MinibatchHiddenFetcherPipelineItem
from arekit.contrib.networks.core.pipeline_predict import EpochLabelsPredictorPipelineItem
from arekit.contrib.networks.core.pipeline_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.core.pipeline.item_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.pipeline.item_keep_hidden import MinibatchHiddenFetcherPipelineItem
from arekit.contrib.networks.core.pipeline.item_predict import EpochLabelsPredictorPipelineItem
from arekit.contrib.networks.core.pipeline.item_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.shapes import NetworkInputShapes
from arekit.contrib.networks.utils import rm_dir_contents

Expand Down
8 changes: 4 additions & 4 deletions examples/run_text_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from arekit.contrib.networks.core.ctx_inference import InferenceContext
from arekit.contrib.networks.core.model import BaseTensorflowModel
from arekit.contrib.networks.core.model_ctx import TensorflowModelContext
from arekit.contrib.networks.core.pipeline_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.pipeline_keep_hidden import MinibatchHiddenFetcherPipelineItem
from arekit.contrib.networks.core.pipeline_predict import EpochLabelsPredictorPipelineItem
from arekit.contrib.networks.core.pipeline_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.core.pipeline.item_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.pipeline.item_keep_hidden import MinibatchHiddenFetcherPipelineItem
from arekit.contrib.networks.core.pipeline.item_predict import EpochLabelsPredictorPipelineItem
from arekit.contrib.networks.core.pipeline.item_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.core.predict.provider import BasePredictProvider
from arekit.contrib.networks.core.predict.tsv_writer import TsvPredictWriter
from arekit.contrib.networks.factory import create_network_and_network_config_funcs
Expand Down

0 comments on commit f9ed7d1

Please sign in to comment.