Skip to content

Commit

Permalink
#262 associated, Simplify DataIO. Switching to handlers.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jan 20, 2022
1 parent 5cb7ac7 commit 32065fa
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 89 deletions.
9 changes: 1 addition & 8 deletions arekit/common/experiment/api/ctx_training.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from arekit.common.experiment.api.ctx_base import DataIO
from arekit.common.experiment.callback import ExperimentCallback


class TrainingData(DataIO):
""" Data, that is necessary for models training stage.
"""

def __init__(self, labels_count, callback):
assert(isinstance(callback, ExperimentCallback))
def __init__(self, labels_count):
super(TrainingData, self).__init__()
self.__labels_count = labels_count
self.__exp_callback = callback

@property
def LabelsCount(self):
Expand All @@ -19,7 +16,3 @@ def LabelsCount(self):
@property
def Evaluator(self):
raise NotImplementedError()

@property
def Callback(self):
return self.__exp_callback
20 changes: 0 additions & 20 deletions arekit/common/experiment/callback.py

This file was deleted.

21 changes: 15 additions & 6 deletions arekit/common/experiment/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self, experiment):
assert(isinstance(experiment, BaseExperiment))
self._experiment = experiment
self._logger = self.__create_logger()
self.__handlers = None

@staticmethod
def __create_logger():
Expand All @@ -24,29 +25,37 @@ def __create_logger():
logger.addHandler(stream_handler)
return logger

def __call_all_handlers(self, call_func):
assert(callable(call_func))

if self.__handlers is None:
return

for handler in self.__handlers:
call_func(handler)

# region protected methods

def _handle_iteration(self, iter_index):
raise NotImplementedError()
self.__call_all_handlers(lambda callback: callback.on_iteration(iter_index))

def _before_running(self):
""" Optional method that allows to implement actions before experiment iterations.
"""
pass
self.__call_all_handlers(lambda callback: callback.on_before_iteration())

def _after_running(self):
""" Optional method that allows to implement actions after experiment iterations.
"""
self.__call_all_handlers(lambda callback: callback.on_after_iteration())

# endregion

def run(self):
def run(self, handlers=None):
""" Running cv_index iteration and calling handler during every iteration.
"""

self.__handlers = handlers
self._before_running()

for iter_index, _ in enumerate(self._experiment.DocumentOperations.DataFolding.iter_states()):
self._handle_iteration(iter_index)

self._after_running()
17 changes: 17 additions & 0 deletions arekit/common/experiment/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from arekit.common.experiment.api.ctx_base import DataIO


class ExperimentEngineHandler(object):
    """ Base lifecycle hook for ExperimentEngine iterations.

        The engine invokes the callbacks below around every iteration;
        the default implementations intentionally do nothing, so
        subclasses override only the events they care about.
    """

    def __init__(self, exp_data):
        assert isinstance(exp_data, DataIO)
        # Experiment-related data kept for subclasses to use.
        self._exp_data = exp_data

    def on_before_iteration(self):
        """ Invoked before an iteration starts. """
        pass

    def on_iteration(self, iter_index):
        """ Invoked for the iteration identified by `iter_index`. """
        pass

    def on_after_iteration(self):
        """ Invoked after an iteration completes. """
        pass
49 changes: 9 additions & 40 deletions arekit/contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
logging.basicConfig(level=logging.INFO)


# TODO. 262. Refactor as handler (weird inheritance, limits capabilities).
class LanguageModelExperimentEvaluator(ExperimentEngine):

def __init__(self, experiment, data_type, eval_helper, max_epochs_count,
label_scaler, labels_formatter, eval_last_only=True, log_dir="./"):
label_scaler, labels_formatter, eval_last_only=True):
assert(isinstance(eval_helper, EvalHelper))
assert(isinstance(max_epochs_count, int))
assert(isinstance(eval_last_only, bool))
assert(isinstance(label_scaler, BaseLabelScaler))
assert(isinstance(labels_formatter, StringLabelsFormatter))
assert(isinstance(log_dir, str))

super(LanguageModelExperimentEvaluator, self).__init__(experiment=experiment)

Expand All @@ -37,19 +37,9 @@ def __init__(self, experiment, data_type, eval_helper, max_epochs_count,
self.__eval_last_only = eval_last_only
self.__labels_formatter = labels_formatter
self.__label_scaler = label_scaler
self.__log_dir = log_dir

def _log_info(self, message, forced=False):
assert(isinstance(message, str))

if not self._experiment._do_log and not forced:
return

logger.info(message)

def __run_pipeline(self, epoch_index, iter_index):
exp_io = self._experiment.ExperimentIO
exp_data = self._experiment.DataIO
doc_ops = self._experiment.DocumentOperations

cmp_doc_ids_set = set(doc_ops.iter_tagget_doc_ids(BaseDocumentTag.Compare))
Expand Down Expand Up @@ -96,37 +86,16 @@ def __run_pipeline(self, epoch_index, iter_index):
def _handle_iteration(self, iter_index):
exp_data = self._experiment.DataIO
assert(isinstance(exp_data, TrainingData))

# Setup callback.
callback = exp_data.Callback
assert(isinstance(callback, Callback))
callback.set_iter_index(iter_index)
super(LanguageModelExperimentEvaluator, self)._handle_iteration(iter_index)

if not self._experiment.ExperimentIO.try_prepare():
return

if callback.check_log_exists():
self._log_info("Skipping [Log file already exist]")
return

with callback:
for epoch_index in reversed(list(range(self.__max_epochs_count))):

# Perform iteration related actions.
self.__run_pipeline(epoch_index=epoch_index, iter_index=iter_index)

# Evaluate.
result = self._experiment.evaluate(data_type=self.__data_type, epoch_index=epoch_index)
result.calculate()

# Saving results.
callback.write_results(result=result, data_type=self.__data_type, epoch_index=epoch_index)
for epoch_index in reversed(list(range(self.__max_epochs_count))):

if self.__eval_last_only:
self._log_info("Evaluation done [Evaluating last only]")
return
# Perform iteration related actions.
self.__run_pipeline(epoch_index=epoch_index, iter_index=iter_index)

def _before_running(self):
# Providing a root dir for logging.
callback = self._experiment.DataIO.Callback
callback.set_log_dir(self.__log_dir)
# Evaluate.
result = self._experiment.evaluate(data_type=self.__data_type, epoch_index=epoch_index)
result.calculate()
1 change: 1 addition & 0 deletions arekit/contrib/bert/run_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from arekit.contrib.bert.samplers.factory import create_bert_sample_provider


# TODO. 262. Refactor as handler (weird inheritance, limits capabilities).
class BertExperimentInputSerializer(ExperimentEngine):

def __init__(self, experiment,
Expand Down
1 change: 1 addition & 0 deletions arekit/contrib/networks/run_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from arekit.contrib.networks.core.input.helper import NetworkInputHelper


# TODO. 262. Refactor as handler (weird inheritance, limits capabilities).
class NetworksExperimentInputSerializer(ExperimentEngine):

def __init__(self, experiment, force_serialize, value_to_group_id_func, balance, skip_folder_if_exists):
Expand Down
14 changes: 3 additions & 11 deletions arekit/contrib/networks/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
logging.basicConfig(level=logging.INFO)


# TODO. 262. Refactor as handler (weird inheritance, limits capabilities).
class NetworksTrainingEngine(ExperimentEngine):

def __init__(self, bags_collection_type, experiment,
Expand Down Expand Up @@ -91,10 +92,6 @@ def _handle_iteration(self, it_index):
# Update parameters after iteration preparation has been completed.
self.__config.reinit_config_dependent_parameters()

# Setup callback.
callback = self._experiment.DataIO.Callback
callback.on_experiment_iteration_begin()

# Initialize network and model.
network = self.__create_network_func()
model = BaseTensorflowModel(
Expand All @@ -118,10 +115,8 @@ def _handle_iteration(self, it_index):
# Initialize model params instance.
model_params = NeuralNetworkModelParams(epochs_count=self.__training_epochs)

# Run model
with callback:
model.fit(model_params=model_params,
seed=self.__seed)
model.fit(model_params=model_params,
seed=self.__seed)

del network
del model
Expand All @@ -141,7 +136,4 @@ def _before_running(self):
# Notify other subscribers that initialization process has been completed.
self.__config.init_initializers()

def _after_running(self):
self._experiment.DataIO.Callback.on_experiment_finished()

# endregion
9 changes: 5 additions & 4 deletions examples/run_rusentrel_train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse

from arekit.common.experiment.api.ctx_training import TrainingData
from arekit.common.experiment.callback import ExperimentCallback
from arekit.common.experiment.handler import ExperimentEngineHandler
from arekit.common.folding.types import FoldingType
from arekit.contrib.experiment_rusentrel.factory import create_experiment
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
Expand Down Expand Up @@ -87,8 +87,7 @@
labels_scaler = Common.create_labels_scaler(labels_count)

# Creating experiment
experiment_data = TrainingData(labels_count=labels_scaler.LabelsCount,
callback=ExperimentCallback())
experiment_data = TrainingData(labels_count=labels_scaler.LabelsCount)

extra_name_suffix = Common.create_exp_name_suffix(
use_balancing=use_balancing,
Expand Down Expand Up @@ -165,4 +164,6 @@
network_callbacks=nework_callbacks,
training_epochs=epochs_count)

training_engine.run()
training_engine.run(handlers=[
ExperimentEngineHandler(exp_data=experiment_data)
])

0 comments on commit 32065fa

Please sign in to comment.