Skip to content

Commit

Permalink
#258 related. Callback now depends on model context and receive pipel…
Browse files Browse the repository at this point in the history
…ines (epoch-related) via args.
  • Loading branch information
nicolay-r committed Jan 18, 2022
1 parent 21f3048 commit 2dc5c10
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 39 deletions.
4 changes: 2 additions & 2 deletions arekit/common/experiment/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self):
def set_experiment(self, experiment):
self._experiment = experiment

def on_initialized(self, model):
def on_initialized(self, model_ctx):
# Do nothing by default.
pass

Expand All @@ -23,7 +23,7 @@ def on_fit_started(self, operation_cancel):
# Do nothing by default.
pass

def on_epoch_finished(self, epoch_index, operation_cancel):
def on_epoch_finished(self, pipeline, operation_cancel):
# Do nothing by default.
pass

Expand Down
8 changes: 4 additions & 4 deletions arekit/contrib/networks/core/callback/utils_hidden_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from arekit.common.utils import create_dir_if_not_exists
from arekit.contrib.networks.core.ctx_predict_log import NetworkInputDependentVariables
from arekit.contrib.networks.core.model import BaseTensorflowModel
from arekit.contrib.networks.core.model_ctx import TensorflowModelContext

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand All @@ -18,11 +18,11 @@
__input_dependent_params_template = 'idparams_{data}_e{epoch_index}'


def save_model_hidden_values(log_dir, model, epoch_index):
def save_model_hidden_values(log_dir, model_ctx, epoch_index):
assert(isinstance(log_dir, str))
assert(isinstance(model, BaseTensorflowModel))
assert(isinstance(model_ctx, TensorflowModelContext))

names, values = model.Context.get_hidden_parameters()
names, values = model_ctx.get_hidden_parameters()

assert (isinstance(names, list))
assert (isinstance(values, list))
Expand Down
23 changes: 7 additions & 16 deletions arekit/contrib/networks/core/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import collections
import logging

from arekit.common.model.base import BaseModel
Expand All @@ -11,6 +10,7 @@
from arekit.contrib.networks.core.network_callback import NetworkCallback
from arekit.contrib.networks.core.params import NeuralNetworkModelParams
from arekit.contrib.networks.core.pipeline_epoch import EpochHandlingPipelineItem
from arekit.contrib.networks.core.utils import get_item_from_pipeline
from arekit.contrib.networks.tf_helpers.nn_states import TensorflowNetworkStatesProvider

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -81,15 +81,6 @@ def __try_load_state(self):
self.__states_provider.load_model(sess=self.__context.Session,
path_tf_prefix=self.__context.IO.get_model_source_path_tf_prefix())

@staticmethod
def __get_item_from_pipeline(pipeline, item_type):
assert(isinstance(pipeline, list))
assert(issubclass(item_type, EpochHandlingPipelineItem))

for item in pipeline:
if isinstance(item, item_type):
return item

def __fit(self, epochs_count):
assert(isinstance(epochs_count, int))
assert(self.__context.Session is not None)
Expand All @@ -109,7 +100,7 @@ def __fit(self, epochs_count):
self.__run_epoch_pipeline(pipeline=self.__fit_pipeline, data_type=DataType.Train, prefix="Training")

if self.__callback is not None:
self.__callback.on_epoch_finished(epoch_index=epoch_index,
self.__callback.on_epoch_finished(pipeline=self.__fit_pipeline,
operation_cancel=operation_cancel)

if BaseTensorflowModel.SaveTensorflowModelStateOnFit:
Expand All @@ -125,19 +116,19 @@ def fit(self, model_params, seed):
assert(isinstance(model_params, NeuralNetworkModelParams))
self.__context.Network.compile(self.__context.Config, reset_graph=True, graph_seed=seed)
self.__context.set_optimiser()
self.__callback.on_initialized(self)
self.__callback.on_initialized(self.__context)
self.__context.initialize_session()
self.__try_load_state()
self.__fit(epochs_count=model_params.EpochsCount)

self.__context.dispose_session()

def predict(self, data_type=DataType.Test, do_compile=False, graph_seed=0):
""" Fills the related labeling collection.
"""

# Optionally perform network compilation
if do_compile:
self.__context.Network.compile(config=self.__context.Config, reset_graph=True, graph_seed=graph_seed)

self.__context.initialize_session()
self.__try_load_state()
self.__run_epoch_pipeline(pipeline=self.__predict_pipeline,
Expand All @@ -146,8 +137,8 @@ def predict(self, data_type=DataType.Test, do_compile=False, graph_seed=0):

def from_fitted(self, item_type):
assert(issubclass(item_type, EpochHandlingPipelineItem))
return self.__get_item_from_pipeline(pipeline=self.__fit_pipeline, item_type=item_type)
return get_item_from_pipeline(pipeline=self.__fit_pipeline, item_type=item_type)

def from_predicted(self, item_type):
assert (issubclass(item_type, EpochHandlingPipelineItem))
return self.__get_item_from_pipeline(pipeline=self.__predict_pipeline, item_type=item_type)
return get_item_from_pipeline(pipeline=self.__predict_pipeline, item_type=item_type)
27 changes: 15 additions & 12 deletions arekit/contrib/networks/core/network_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from arekit.common.experiment.callback import Callback
from arekit.common.utils import progress_bar_defined
from arekit.contrib.networks.core.cancellation import OperationCancellation
from arekit.contrib.networks.core.model_ctx import TensorflowModelContext
from arekit.contrib.networks.core.pipeline_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.utils import get_item_from_pipeline

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand All @@ -19,12 +21,13 @@ class NetworkCallback(Callback):

def __init__(self):
super(NetworkCallback, self).__init__()
self.__training_epochs_passed = 0
self._model = None
self._training_epochs_passed = 0
self._model_ctx = None

def on_initialized(self, model):
super(NetworkCallback, self).on_initialized(model)
self._model = model
def on_initialized(self, model_ctx):
assert(isinstance(model_ctx, TensorflowModelContext))
super(NetworkCallback, self).on_initialized(model_ctx)
self._model_ctx = model_ctx

@staticmethod
def __create_epoch_stat(epoch_index, avg_fit_cost, avg_fit_acc):
Expand All @@ -38,22 +41,22 @@ def __create_epoch_stat(epoch_index, avg_fit_cost, avg_fit_acc):
return u"{time}: {epochs}: {avg_fc}, {avg_ac}".format(
time=time, epochs=epochs, avg_fc=avg_fc, avg_ac=avg_ac)

def on_epoch_finished(self, epoch_index, operation_cancel):
assert(isinstance(epoch_index, int))
def on_epoch_finished(self, pipeline, operation_cancel):
assert(isinstance(operation_cancel, OperationCancellation))

item = self._model.from_fitted(MinibatchFittingPipelineItem)
self.__training_epochs_passed += 1
item = get_item_from_pipeline(pipeline=pipeline, item_type=MinibatchFittingPipelineItem)

self._training_epochs_passed += 1

if item is None:
return

assert(isinstance(item, MinibatchFittingPipelineItem))

super(NetworkCallback, self).on_epoch_finished(epoch_index=epoch_index,
super(NetworkCallback, self).on_epoch_finished(pipeline=pipeline,
operation_cancel=operation_cancel)

message = self.__create_epoch_stat(epoch_index=epoch_index,
message = self.__create_epoch_stat(epoch_index=self._training_epochs_passed,
avg_fit_cost=item.TotalFitCost,
avg_fit_acc=item.TotalFitAccuracy)

Expand All @@ -66,5 +69,5 @@ def handle_batches_iter(self, batches_iter, total, prefix, unit='mbs'):
assert(isinstance(batches_iter, collections.Iterable))
assert(isinstance(unit, str))
assert(isinstance(prefix, str))
desc = "{prefix} e={epoch}".format(prefix=prefix, epoch=self.__training_epochs_passed)
desc = "{prefix} e={epoch}".format(prefix=prefix, epoch=self._training_epochs_passed)
return progress_bar_defined(iterable=batches_iter, unit=unit, total=total, desc=desc)
10 changes: 10 additions & 0 deletions arekit/contrib/networks/core/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from arekit.contrib.networks.core.pipeline_epoch import EpochHandlingPipelineItem


def get_item_from_pipeline(pipeline, item_type):
assert (isinstance(pipeline, list))
assert (issubclass(item_type, EpochHandlingPipelineItem))

for item in pipeline:
if isinstance(item, item_type):
return item
12 changes: 7 additions & 5 deletions examples/rusentrel/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from arekit.contrib.networks.core.cancellation import OperationCancellation
from arekit.contrib.networks.core.network_callback import NetworkCallback
from arekit.contrib.networks.core.pipeline_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.utils import get_item_from_pipeline

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand All @@ -29,13 +30,12 @@ def __is_cancel_needed(self, avg_fit_acc):

# endregion

def on_epoch_finished(self, epoch_index, operation_cancel):
assert(isinstance(epoch_index, int))
def on_epoch_finished(self, pipeline, operation_cancel):
assert(isinstance(operation_cancel, OperationCancellation))
super(TrainingCallback, self).on_epoch_finished(epoch_index=epoch_index,
super(TrainingCallback, self).on_epoch_finished(pipeline=pipeline,
operation_cancel=operation_cancel)

item = self._model.from_predicted(MinibatchFittingPipelineItem)
item = get_item_from_pipeline(pipeline=pipeline, item_type=MinibatchFittingPipelineItem)

if item is None:
return
Expand All @@ -49,4 +49,6 @@ def on_epoch_finished(self, epoch_index, operation_cancel):
return

# Saving model hidden values using the related numpy utils.
save_model_hidden_values(log_dir=self.__log_dir, epoch_index=epoch_index, model=self._model)
save_model_hidden_values(log_dir=self.__log_dir,
epoch_index=self._training_epochs_passed,
model_ctx=self._model_ctx)

0 comments on commit 2dc5c10

Please sign in to comment.