Skip to content

Commit

Permalink
Fixed #107.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed May 6, 2021
1 parent a425246 commit b259584
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 8 deletions.
3 changes: 1 addition & 2 deletions common/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ def __init__(self, io):
def IO(self):
return self.__io

# TODO. Remove epochs count, since it is related to NeuralNetworks only.
def run_training(self, epochs_count, seed):
def run_training(self, model_params, seed):
raise NotImplementedError()

def predict(self, data_type=DataType.Test):
Expand Down
4 changes: 4 additions & 0 deletions common/model/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class BaseModelParams(object):

def __init__(self):
pass
7 changes: 5 additions & 2 deletions contrib/networks/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from arekit.contrib.networks.core.feeding.batch.multi import MultiInstanceMiniBatch
from arekit.contrib.networks.core.model_io import NeuralNetworkModelIO
from arekit.contrib.networks.core.nn import NeuralNetwork
from arekit.contrib.networks.core.params import NeuralNetworkModelParams

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,7 +99,9 @@ def __dispose_session(self):
"""
self.__sess.close()

def run_training(self, epochs_count, seed):
def run_training(self, model_params, seed):
assert(isinstance(model_params, NeuralNetworkModelParams))

self.__network.compile(self.Config, reset_graph=True, graph_seed=seed)
self.set_optimiser()
self.__notify_initialized()
Expand All @@ -110,7 +113,7 @@ def run_training(self, epochs_count, seed):
logger.info(u"Loading model: {}".format(saved_model_path))
self.load_model(saved_model_path)

self.fit(epochs_count=epochs_count)
self.fit(epochs_count=model_params.EpochsCount)
self.__dispose_session()

# endregion
Expand Down
11 changes: 11 additions & 0 deletions contrib/networks/core/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from arekit.common.model.params import BaseModelParams


class NeuralNetworkModelParams(BaseModelParams):

def __init__(self, epochs_count):
self.__epochs_count = epochs_count

@property
def EpochsCount(self):
return self.__epochs_count
12 changes: 8 additions & 4 deletions contrib/networks/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from arekit.contrib.networks.core.data_handling.data import HandledData
from arekit.contrib.networks.core.feeding.bags.collection.base import BagsCollection
from arekit.contrib.networks.core.model import BaseTensorflowModel

from arekit.contrib.networks.core.params import NeuralNetworkModelParams

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -81,11 +81,11 @@ def _handle_iteration(self, it_index):
# Update parameters after iteration preparation has been completed.
self.__config.reinit_config_dependent_parameters()

# Setup callback
# Setup callback.
callback = self._experiment.DataIO.Callback
callback.on_experiment_iteration_begin()

# Initialize network and model
# Initialize network and model.
network = self.__create_network_func()
model = BaseTensorflowModel(network=network,
config=self.__config,
Expand All @@ -94,9 +94,13 @@ def _handle_iteration(self, it_index):
callback=callback,
nn_io=self._experiment.DataIO.ModelIO)

# Initialize model params instance.
model_params = NeuralNetworkModelParams(epochs_count=callback.Epochs)

# Run model
with callback:
model.run_training(epochs_count=callback.Epochs, seed=self.__seed)
model.run_training(model_params=model_params,
seed=self.__seed)

del network
del model
Expand Down

0 comments on commit b259584

Please sign in to comment.