Skip to content

Commit

Permalink
- #261 related, #259 done.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jan 19, 2022
1 parent 2f2f919 commit f8f4fa2
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 152 deletions.
8 changes: 4 additions & 4 deletions arekit/common/experiment/api/ctx_training.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from arekit.common.experiment.api.ctx_base import DataIO
from arekit.common.experiment.callback import Callback
from arekit.common.experiment.callback import ExperimentCallback


class TrainingData(DataIO):
""" Data, that is necessary for models training stage.
"""

def __init__(self, labels_count, callback):
assert(isinstance(callback, Callback))
assert(isinstance(callback, ExperimentCallback))
super(TrainingData, self).__init__()
self.__labels_count = labels_count
self.__callback = callback
self.__exp_callback = callback

@property
def LabelsCount(self):
Expand All @@ -22,4 +22,4 @@ def Evaluator(self):

@property
def Callback(self):
return self.__callback
return self.__exp_callback
24 changes: 1 addition & 23 deletions arekit/common/experiment/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,12 @@
logger.setLevel(logging.INFO)


class Callback(object):

def __init__(self):
self._experiment = None

def set_experiment(self, experiment):
self._experiment = experiment

def on_initialized(self, model_ctx):
# Do nothing by default.
pass
class ExperimentCallback(object):

def on_experiment_iteration_begin(self):
# Do nothing by default.
pass

def on_fit_started(self, operation_cancel):
# Do nothing by default.
pass

def on_epoch_finished(self, pipeline, operation_cancel):
# Do nothing by default.
pass

def on_fit_finished(self):
# Do nothing by default.
pass

def on_experiment_finished(self):
# Do nothing by default.
pass
Expand Down
20 changes: 20 additions & 0 deletions arekit/contrib/networks/core/callback_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from arekit.contrib.networks.core.model_ctx import TensorflowModelContext


class NetworkCallback(object):
    """ Base class of the callbacks that observe the network training process.

    Every hook except `on_initialized` does nothing by default, so a
    subclass overrides only the notifications it is interested in.
    """

    def __init__(self):
        super(NetworkCallback, self).__init__()
        # Assigned by `on_initialized`; remains None until that hook is called.
        self._model_ctx = None

    def on_initialized(self, model_ctx):
        """ Keeps the model context, so the other hooks may refer to it.

        model_ctx: TensorflowModelContext
        """
        assert(isinstance(model_ctx, TensorflowModelContext))
        self._model_ctx = model_ctx

    def on_fit_started(self, operation_cancel):
        """ Notification on the beginning of the fitting process.
        """
        # Do nothing by default.
        pass

    def on_epoch_finished(self, pipeline, operation_cancel):
        """ Notification on the end of a training epoch.
        """
        # Do nothing by default.
        pass

    def on_fit_finished(self):
        """ Notification on the end of the fitting process.

        The model invokes this hook for every registered callback, so the
        base class has to provide it; otherwise a subclass that does not
        declare the hook fails with AttributeError.
        """
        # Do nothing by default.
        pass
Original file line number Diff line number Diff line change
@@ -1,34 +1,26 @@
import collections
import logging
from datetime import datetime

from arekit.common.experiment.callback import Callback
from arekit.common.utils import progress_bar_defined
from arekit.contrib.networks.core.callback_network import NetworkCallback
from arekit.contrib.networks.core.cancellation import OperationCancellation
from arekit.contrib.networks.core.model_ctx import TensorflowModelContext
from arekit.contrib.networks.core.pipeline_fit import MinibatchFittingPipelineItem
from arekit.contrib.networks.core.utils import get_item_from_pipeline

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class NetworkCallback(Callback):
class TrainingStatProviderCallback(NetworkCallback):
""" Represent network callback which provides
wrappers for batches iterations
and epoch termination notifications.
"""

def __init__(self):
super(NetworkCallback, self).__init__()
super(TrainingStatProviderCallback, self).__init__()
self._training_epochs_passed = 0
self._model_ctx = None

def on_initialized(self, model_ctx):
assert(isinstance(model_ctx, TensorflowModelContext))
super(NetworkCallback, self).on_initialized(model_ctx)
self._model_ctx = model_ctx

@staticmethod
def __create_epoch_stat(epoch_index, avg_fit_cost, avg_fit_acc):
""" Providing epoch training results notification.
Expand All @@ -53,21 +45,9 @@ def on_epoch_finished(self, pipeline, operation_cancel):

assert(isinstance(item, MinibatchFittingPipelineItem))

super(NetworkCallback, self).on_epoch_finished(pipeline=pipeline,
operation_cancel=operation_cancel)

message = self.__create_epoch_stat(epoch_index=self._training_epochs_passed,
avg_fit_cost=item.TotalFitCost,
avg_fit_acc=item.TotalFitAccuracy)

# Providing information into main logger.
logger.info(message)

def handle_batches_iter(self, batches_iter, total, prefix, unit='mbs'):
""" Do wrapping progress notification.
"""
assert(isinstance(batches_iter, collections.Iterable))
assert(isinstance(unit, str))
assert(isinstance(prefix, str))
desc = "{prefix} e={epoch}".format(prefix=prefix, epoch=self._training_epochs_passed)
return progress_bar_defined(iterable=batches_iter, unit=unit, total=total, desc=desc)
46 changes: 31 additions & 15 deletions arekit/contrib/networks/core/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import collections
import logging

from arekit.common.model.base import BaseModel
from arekit.common.experiment.data_type import DataType
from arekit.common.pipeline.context import PipelineContext
from arekit.common.utils import progress_bar_defined

from arekit.contrib.networks.core.cancellation import OperationCancellation
from arekit.contrib.networks.core.feeding.bags.collection.factory import create_batch_by_bags_group
from arekit.contrib.networks.core.model_ctx import TensorflowModelContext
from arekit.contrib.networks.core.network_callback import NetworkCallback
from arekit.contrib.networks.core.params import NeuralNetworkModelParams
from arekit.contrib.networks.core.pipeline_epoch import EpochHandlingPipelineItem
from arekit.contrib.networks.core.utils import get_item_from_pipeline
Expand All @@ -20,17 +21,17 @@ class BaseTensorflowModel(BaseModel):

SaveTensorflowModelStateOnFit = True

def __init__(self, context, callback,
def __init__(self, context, callbacks,
predict_pipeline=None,
fit_pipeline=None):
assert(isinstance(context, TensorflowModelContext))
assert(isinstance(callback, NetworkCallback))
assert(isinstance(callbacks, list))
assert(isinstance(predict_pipeline, list))
assert(isinstance(fit_pipeline, list))
super(BaseTensorflowModel, self).__init__()

self.__context = context
self.__callback = callback
self.__callbacks = callbacks
self.__predict_pipeline = predict_pipeline
self.__fit_pipeline = fit_pipeline
self.__states_provider = TensorflowNetworkStatesProvider()
Expand All @@ -41,6 +42,20 @@ def Context(self):

# region private methods

def __callback_do(self, call_func):
    """ Applies `call_func` to every registered callback,
        one by one, in the order of the callbacks list.
    """
    registered = iter(self.__callbacks)
    for handler in registered:
        call_func(handler)

@staticmethod
def __handle_batches_iter(batches_iter, total, prefix, unit='mbs'):
    """ Do wrapping progress notification.

    batches_iter: iterable of the batches to be wrapped
    total: amount of batches, passed to the progress bar as is
    prefix: str, progress bar description
    unit: str, progress bar unit name
    returns: the result of `progress_bar_defined` for the given iterable
    """
    # The `collections.Iterable` alias was removed in Python 3.10;
    # the abstract base classes are available only in `collections.abc`.
    from collections.abc import Iterable

    assert(isinstance(batches_iter, Iterable))
    assert(isinstance(unit, str))
    assert(isinstance(prefix, str))
    desc = "{prefix}".format(prefix=prefix)
    return progress_bar_defined(iterable=batches_iter, unit=unit, total=total, desc=desc)

def __run_epoch_pipeline(self, data_type, pipeline, prefix):
assert(isinstance(pipeline, list))
assert(isinstance(prefix, str))
Expand All @@ -53,10 +68,10 @@ def __run_epoch_pipeline(self, data_type, pipeline, prefix):
"(Might be greater or equal, as the last "
"bag is expanded)".format(minibatches_count))

groups_it = self.__callback.handle_batches_iter(
groups_it = self.__handle_batches_iter(
batches_iter=bags_collection.iter_by_groups(bags_per_group=bags_per_group),
total=minibatches_count,
prefix="Training")
prefix=prefix)

for item in pipeline:
assert(isinstance(item, EpochHandlingPipelineItem))
Expand Down Expand Up @@ -84,11 +99,11 @@ def __try_load_state(self):
def __fit(self, epochs_count):
assert(isinstance(epochs_count, int))
assert(self.__context.Session is not None)
assert(isinstance(self.__callback, NetworkCallback))

operation_cancel = OperationCancellation()
bags_collection = self.__context.get_bags_collection(DataType.Train)
self.__callback.on_fit_started(operation_cancel)

self.__callback_do(lambda callback: callback.on_fit_started(operation_cancel))

for epoch_index in range(epochs_count):

Expand All @@ -97,26 +112,27 @@ def __fit(self, epochs_count):

bags_collection.shuffle()

self.__run_epoch_pipeline(pipeline=self.__fit_pipeline, data_type=DataType.Train, prefix="Training")
self.__run_epoch_pipeline(pipeline=self.__fit_pipeline,
data_type=DataType.Train,
prefix="Training")

if self.__callback is not None:
self.__callback.on_epoch_finished(pipeline=self.__fit_pipeline,
operation_cancel=operation_cancel)
self.__callback_do(lambda callback: callback.on_epoch_finished(
pipeline=self.__fit_pipeline,
operation_cancel=operation_cancel))

if BaseTensorflowModel.SaveTensorflowModelStateOnFit:
self.__states_provider.save_model(sess=self.__context.Session,
path_tf_prefix=self.__context.IO.get_model_target_path_tf_prefix())

if self.__callback is not None:
self.__callback.on_fit_finished()
self.__callback_do(lambda callback: callback.on_fit_finished())

# endregion

def fit(self, model_params, seed):
assert(isinstance(model_params, NeuralNetworkModelParams))
self.__context.Network.compile(self.__context.Config, reset_graph=True, graph_seed=seed)
self.__context.set_optimiser()
self.__callback.on_initialized(self.__context)
self.__callback_do(lambda callback: callback.on_initialized(self.__context)),
self.__context.initialize_session()
self.__try_load_state()
self.__fit(epochs_count=model_params.EpochsCount)
Expand Down
12 changes: 7 additions & 5 deletions arekit/contrib/networks/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from arekit.common.experiment.engine import ExperimentEngine
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.contrib.networks.core.callback_stat import TrainingStatProviderCallback
from arekit.contrib.networks.core.ctx_inference import InferenceContext
from arekit.contrib.networks.core.feeding.bags.collection.base import BagsCollection
from arekit.contrib.networks.core.model import BaseTensorflowModel
Expand All @@ -15,6 +16,8 @@
from arekit.contrib.networks.core.pipeline_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.shapes import NetworkInputShapes
from arekit.contrib.networks.utils import rm_dir_contents
from examples.rusentrel.callback_hidden import HiddenStatesWriterCallback
from examples.rusentrel.callback_training import TrainingLimiterCallback

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand All @@ -26,6 +29,7 @@ def __init__(self, bags_collection_type, experiment,
load_model, config,
create_network_func,
training_epochs,
network_callbacks,
prepare_model_root=True,
seed=None):
assert(callable(create_network_func))
Expand All @@ -34,13 +38,15 @@ def __init__(self, bags_collection_type, experiment,
assert(isinstance(load_model, bool))
assert(isinstance(seed, int) or seed is None)
assert(isinstance(training_epochs, int))
assert(isinstance(network_callbacks, list))

super(NetworksTrainingEngine, self).__init__(experiment)

self.__clear_model_root_before_experiment = prepare_model_root
self.__config = config
self.__create_network_func = create_network_func
self.__bags_collection_type = bags_collection_type
self.__network_callbacks = network_callbacks
self.__load_model = load_model
self.__training_epochs = training_epochs
self.__seed = seed
Expand Down Expand Up @@ -101,7 +107,7 @@ def _handle_iteration(self, it_index):
inference_ctx=inference_ctx,
bags_collection_type=self.__bags_collection_type,
nn_io=self._experiment.DataIO.ModelIO),
callback=callback,
callbacks=self.__network_callbacks,
predict_pipeline=[
EpochLabelsPredictorPipelineItem(),
EpochLabelsCollectorPipelineItem(),
Expand Down Expand Up @@ -129,10 +135,6 @@ def _before_running(self):
rm_dir_contents(dir_path=self.__get_model_dir(),
logger=self._logger)

# Setup callback
callback = self._experiment.DataIO.Callback
callback.set_experiment(self._experiment)

# Disable tensorflow logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

Expand Down
18 changes: 12 additions & 6 deletions examples/run_rusentrel_train.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import argparse

from arekit.common.experiment.api.ctx_training import TrainingData
from arekit.common.experiment.callback import ExperimentCallback
from arekit.common.folding.types import FoldingType
from arekit.contrib.experiment_rusentrel.factory import create_experiment
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.contrib.networks.core.callback_stat import TrainingStatProviderCallback
from arekit.contrib.networks.factory import create_network_and_network_config_funcs
from arekit.contrib.networks.run_training import NetworksTrainingEngine
from arekit.contrib.source.ruattitudes.io_utils import RuAttitudesVersions
Expand All @@ -17,7 +19,8 @@
from examples.network.args.train import BagsPerMinibatchArg, DropoutKeepProbArg, EpochsCountArg, LearningRateArg, \
ModelInputTypeArg, ModelNameTagArg
from examples.network.common import create_bags_collection_type, create_network_model_io
from examples.rusentrel.callback import TrainingCallback
from examples.rusentrel.callback_hidden import HiddenStatesWriterCallback
from examples.rusentrel.callback_training import TrainingLimiterCallback
from examples.rusentrel.common import Common
from examples.rusentrel.config_setups import optionally_modify_config_for_experiment, modify_config_for_model
from examples.rusentrel.exp_io import CustomRuSentRelNetworkExperimentIO
Expand Down Expand Up @@ -81,13 +84,9 @@

labels_scaler = Common.create_labels_scaler(labels_count)

# Initialize callback.
callback = TrainingCallback(train_acc_limit=0.99,
log_dir=model_target_dir)

# Creating experiment
experiment_data = TrainingData(labels_count=labels_scaler.LabelsCount,
callback=callback)
callback=ExperimentCallback())

extra_name_suffix = Common.create_exp_name_suffix(
use_balancing=use_balancing,
Expand Down Expand Up @@ -147,11 +146,18 @@
model_input_type=model_input_type,
config=config)

# Callbacks passed to the training engine; the list order is kept as is.
network_callbacks = [
    TrainingLimiterCallback(train_acc_limit=0.99),
    TrainingStatProviderCallback(),
    HiddenStatesWriterCallback(log_dir=model_target_dir),
]

training_engine = NetworksTrainingEngine(load_model=model_load_dir is not None,
                                         experiment=experiment,
                                         create_network_func=network_func,
                                         config=config,
                                         bags_collection_type=bags_collection_type,
                                         network_callbacks=network_callbacks,
                                         training_epochs=epochs_count)

training_engine.run()
Loading

0 comments on commit f8f4fa2

Please sign in to comment.