Skip to content

Commit

Permalink
#260 done. (#259 related)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jan 19, 2022
1 parent f8f4fa2 commit f9b7da0
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 70 deletions.
File renamed without changes.
27 changes: 0 additions & 27 deletions arekit/contrib/networks/core/callback/utils_hidden_states.py

This file was deleted.

29 changes: 3 additions & 26 deletions arekit/contrib/networks/core/callback/utils_model_eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from os.path import join

from arekit.common.data import const
from arekit.common.data.storages.base import BaseRowsStorage
Expand All @@ -11,7 +10,6 @@
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.item_handle import HandleIterPipelineItem

from arekit.contrib.networks.core.callback.utils_hidden_states import save_minibatch_all_input_dependent_hidden_values
from arekit.contrib.networks.core.ctx_predict_log import NetworkInputDependentVariables
from arekit.contrib.networks.core.model import BaseTensorflowModel
from arekit.contrib.networks.core.pipeline_predict_labeling import EpochLabelsCollectorPipelineItem
Expand All @@ -22,9 +20,9 @@
logger = logging.getLogger(__name__)


# TODO. split onto callback items.
def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
labels_formatter, save_hidden_params,
label_calc_mode, log_dir):
labels_formatter, label_calc_mode, log_dir):
""" Performs Model Evaluation on a particular state (i.e. epoch),
for a particular data type.
"""
Expand All @@ -33,7 +31,6 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
assert(isinstance(model, BaseTensorflowModel))
assert(isinstance(data_type, DataType))
assert(isinstance(epoch_index, int))
assert(isinstance(save_hidden_params, bool))

# Prediction result is a pair of the following parameters:
# idhp -- input dependent variables that might be saved for additional research.
Expand Down Expand Up @@ -103,30 +100,10 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
for _ in pipeline_ctx.provide("src"):
pass

# Evaluate.
# TODO. Callback evaluator.
result = experiment.evaluate(data_type=data_type,
epoch_index=epoch_index)

# optionally save input-dependent hidden parameters.
if save_hidden_params:
save_minibatch_all_input_dependent_hidden_values(
predict_log=idhp,
path_by_var_name_func=lambda var_name: __path_by_var_name(
var_name=var_name,
log_dir=log_dir,
epoch_index=epoch_index,
data_type=data_type))

return result


def __path_by_var_name(var_name, data_type, epoch_index, log_dir):
filname = 'idparams_{data}_e{epoch_index}'.format(
data='{}-{}'.format(var_name, data_type),
epoch_index=epoch_index)

return join(log_dir, filname)


def __calculate_doc_id_by_sample_id_dict(rows_iter):
""" Iter sample_ids with the related labels (if the latter presented in dataframe)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from os.path import join

from arekit.contrib.networks.core.callback.np_writer import NpzDataWriter
from arekit.contrib.networks.core.base_writer import BaseWriter
from arekit.contrib.networks.core.callback_network import NetworkCallback


class HiddenStatesWriterCallback(NetworkCallback):

def __init__(self, log_dir):
def __init__(self, log_dir, writer):
assert(isinstance(writer, BaseWriter))
super(HiddenStatesWriterCallback, self).__init__()

self.__epochs_passed = 0
self.__log_dir = log_dir
self.__writer = NpzDataWriter()
self.__writer = writer

    def __target_provider(self, name, epoch_index):
        # Composes the output filepath for a named set of hidden parameters at a
        # given epoch, rooted at the callback's log directory.
        return join(self.__log_dir, 'hparams_{name}_e{epoch}'.format(name=name, epoch=epoch_index))
Expand All @@ -22,6 +23,7 @@ def on_epoch_finished(self, pipeline, operation_cancel):

self.__epochs_passed += 1

# TODO. This might be taken from pipeline item.
names, values = self._model_ctx.get_hidden_parameters()

assert(isinstance(names, list))
Expand Down
57 changes: 57 additions & 0 deletions arekit/contrib/networks/core/callback_hidden_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from os.path import join

from arekit.common.utils import create_dir_if_not_exists
from arekit.contrib.networks.core.callback_network import NetworkCallback
from arekit.contrib.networks.core.ctx_predict_log import NetworkInputDependentVariables
from arekit.contrib.networks.core.pipeline_keep_hidden import MinibatchHiddenFetcherPipelineItem
from arekit.contrib.networks.core.utils import get_item_from_pipeline


class InputHiddenStatesWriterCallback(NetworkCallback):
    """ Network callback which, at the end of every epoch, serializes the
        input-dependent hidden parameters gathered by
        MinibatchHiddenFetcherPipelineItem (if the latter is a part of the
        pipeline) into per-variable files under `log_dir` via the given writer.
    """

    def __init__(self, log_dir, writer):
        """ log_dir: str
                Directory into which per-variable parameter files are written.
            writer: data writer instance
                Expected to expose a `write(target, data)` method
                (e.g. NpzDataWriter) — TODO confirm the exact writer contract.
        """
        super(InputHiddenStatesWriterCallback, self).__init__()
        # Epochs handled so far; becomes a part of every target filename.
        self.__epochs_passed = 0
        self.__log_dir = log_dir
        self.__writer = writer

    @staticmethod
    def __path_by_var_name(var_name, data_type, epoch_index, log_dir):
        """ Composes the filepath for a (variable, data_type, epoch) triple. """
        # Fixed local variable typo (was `filname`).
        filename = 'idparams_{data}_e{epoch_index}'.format(
            data='{}-{}'.format(var_name, data_type),
            epoch_index=epoch_index)
        return join(log_dir, filename)

    def __save_minibatch_variable_values(self, target, predict_log, var_name):
        """ Writes all values logged for `var_name` into `target`,
            ordered by their associated identifier.
        """
        assert(isinstance(predict_log, NetworkInputDependentVariables))
        create_dir_if_not_exists(target)
        id_and_value_pairs = list(predict_log.iter_by_parameter_values(var_name))
        # Sort by identifier to guarantee a deterministic output order.
        id_and_value_pairs = sorted(id_and_value_pairs, key=lambda pair: pair[0])
        self.__writer.write(target=target, data=[pair[1] for pair in id_and_value_pairs])

    def on_epoch_finished(self, pipeline, operation_cancel):
        """ Dumps every input-dependent variable collected during the finished
            epoch; silently does nothing when the fetcher pipeline item is absent.
        """
        super(InputHiddenStatesWriterCallback, self).on_epoch_finished(
            pipeline=pipeline,
            operation_cancel=operation_cancel)

        self.__epochs_passed += 1

        pipeline_item = get_item_from_pipeline(pipeline=pipeline,
                                               item_type=MinibatchHiddenFetcherPipelineItem)

        # Nothing to save when hidden-state fetching is not a part of the pipeline.
        if pipeline_item is None:
            return

        predict_log = pipeline_item.InputDependentParams
        data_type = pipeline_item.DataType

        for var_name in predict_log.iter_var_names():
            target = self.__path_by_var_name(var_name=var_name,
                                             data_type=data_type,
                                             epoch_index=self.__epochs_passed,
                                             log_dir=self.__log_dir)

            self.__save_minibatch_variable_values(
                target=target,
                predict_log=predict_log,
                var_name=var_name)
4 changes: 4 additions & 0 deletions arekit/contrib/networks/core/pipeline_epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ def __init__(self):
self._context = None
self._data_type = None

    @property
    def DataType(self):
        # Data type the item is currently operating on; initialized to None
        # and presumably assigned in before_epoch() — TODO confirm against
        # the full method body.
        return self._data_type

def before_epoch(self, model_context, data_type):
assert(isinstance(model_context, TensorflowModelContext))
self._context = model_context
Expand Down
10 changes: 7 additions & 3 deletions arekit/contrib/networks/core/pipeline_keep_hidden.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ def __init__(self):
super(MinibatchHiddenFetcherPipelineItem, self).__init__()
self.__idh_names = None
self.__idh_tensors = None
self.__predict_log = None
self.__input_dependent_params = None

    @property
    def InputDependentParams(self):
        # NetworkInputDependentVariables instance accumulated over the epoch
        # (recreated in before_epoch); exposed for consumption by callbacks.
        return self.__input_dependent_params

def before_epoch(self, model_context, data_type):
super(MinibatchHiddenFetcherPipelineItem, self).before_epoch(
Expand All @@ -22,7 +26,7 @@ def before_epoch(self, model_context, data_type):
self.__idh_names.append(name)
self.__idh_tensors.append(tensor)

self.__predict_log = NetworkInputDependentVariables()
self.__input_dependent_params = NetworkInputDependentVariables()

def apply(self, pipeline_ctx):
assert(isinstance(pipeline_ctx, PipelineContext))
Expand All @@ -36,7 +40,7 @@ def apply(self, pipeline_ctx):
if not (len(self.__idh_names) > 0 and len(idh_values) > 0):
return

self.__predict_log.add_input_dependent_values(
self.__input_dependent_params.add_input_dependent_values(
names_list=self.__idh_names,
tensor_values_list=idh_values,
text_opinion_ids=[sample.ID for sample in
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,15 @@

import numpy as np

from arekit.common.data.input.writers.base import BaseWriter
from arekit.common.utils import create_dir_if_not_exists

from arekit.contrib.networks.core.base_writer import BaseWriter

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class NpzDataWriter(BaseWriter):

def __init__(self):
pass

def write(self, data, target):
assert(isinstance(target, str))
logger.info("Save: {}".format(target))
Expand Down
8 changes: 4 additions & 4 deletions arekit/contrib/networks/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from arekit.common.experiment.engine import ExperimentEngine
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.contrib.networks.core.callback_stat import TrainingStatProviderCallback
from arekit.contrib.networks.core.ctx_inference import InferenceContext
from arekit.contrib.networks.core.feeding.bags.collection.base import BagsCollection
from arekit.contrib.networks.core.model import BaseTensorflowModel
Expand All @@ -16,8 +15,6 @@
from arekit.contrib.networks.core.pipeline_predict_labeling import EpochLabelsCollectorPipelineItem
from arekit.contrib.networks.shapes import NetworkInputShapes
from arekit.contrib.networks.utils import rm_dir_contents
from examples.rusentrel.callback_hidden import HiddenStatesWriterCallback
from examples.rusentrel.callback_training import TrainingLimiterCallback

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -113,7 +110,10 @@ def _handle_iteration(self, it_index):
EpochLabelsCollectorPipelineItem(),
MinibatchHiddenFetcherPipelineItem()
],
fit_pipeline=[MinibatchFittingPipelineItem()])
fit_pipeline=[
MinibatchFittingPipelineItem(),
MinibatchHiddenFetcherPipelineItem()
])

# Initialize model params instance.
model_params = NeuralNetworkModelParams(epochs_count=self.__training_epochs)
Expand Down
9 changes: 7 additions & 2 deletions examples/run_rusentrel_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from arekit.common.folding.types import FoldingType
from arekit.contrib.experiment_rusentrel.factory import create_experiment
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.contrib.networks.core.callback_hidden import HiddenStatesWriterCallback
from arekit.contrib.networks.core.callback_hidden_input import InputHiddenStatesWriterCallback
from arekit.contrib.networks.core.callback_stat import TrainingStatProviderCallback
from arekit.contrib.networks.factory import create_network_and_network_config_funcs
from arekit.contrib.networks.np_utils.writer import NpzDataWriter
from arekit.contrib.networks.run_training import NetworksTrainingEngine
from arekit.contrib.source.ruattitudes.io_utils import RuAttitudesVersions
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions
Expand All @@ -19,7 +22,6 @@
from examples.network.args.train import BagsPerMinibatchArg, DropoutKeepProbArg, EpochsCountArg, LearningRateArg, \
ModelInputTypeArg, ModelNameTagArg
from examples.network.common import create_bags_collection_type, create_network_model_io
from examples.rusentrel.callback_hidden import HiddenStatesWriterCallback
from examples.rusentrel.callback_training import TrainingLimiterCallback
from examples.rusentrel.common import Common
from examples.rusentrel.config_setups import optionally_modify_config_for_experiment, modify_config_for_model
Expand Down Expand Up @@ -146,10 +148,13 @@
model_input_type=model_input_type,
config=config)

data_writer = NpzDataWriter()

nework_callbacks = [
TrainingLimiterCallback(train_acc_limit=0.99),
TrainingStatProviderCallback(),
HiddenStatesWriterCallback(log_dir=model_target_dir),
HiddenStatesWriterCallback(log_dir=model_target_dir, writer=data_writer),
InputHiddenStatesWriterCallback(log_dir=model_target_dir, writer=data_writer)
]

training_engine = NetworksTrainingEngine(load_model=model_load_dir is not None,
Expand Down

0 comments on commit f9b7da0

Please sign in to comment.