Skip to content

Commit

Permalink
#258 simplified. Using single pipeline runner implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jan 18, 2022
1 parent 635feed commit 21f3048
Showing 1 changed file with 10 additions and 33 deletions.
43 changes: 10 additions & 33 deletions arekit/contrib/networks/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def Context(self):

# region private methods

def __fit_epoch(self):
data_type = DataType.Train
def __run_epoch_pipeline(self, data_type, pipeline, prefix):
assert(isinstance(pipeline, list))
assert(isinstance(prefix, str))

bags_per_group = self.__context.Config.BagsPerMinibatch
bags_collection = self.__context.get_bags_collection(data_type)
bags_collection.shuffle()

minibatches_count = bags_collection.get_groups_count(bags_per_group)

logger.info("Minibatches passing per epoch count: ~{} "
"(Might be greater or equal, as the last "
"bag is expanded)".format(minibatches_count))
Expand All @@ -58,45 +58,20 @@ def __fit_epoch(self):
total=minibatches_count,
prefix="Training")

self.__run_epoch_pipeline(batches_it=groups_it,
pipeline=self.__fit_pipeline,
data_type=data_type)

def __predict(self, data_type):
assert(isinstance(data_type, DataType))
bags_collection = self.__context.get_bags_collection(data_type)
bags_per_group = self.__context.Config.BagsPerMinibatch
minibatches_count = bags_collection.get_groups_count(bags_per_group)
groups_it = self.__callback.handle_batches_iter(
batches_iter=bags_collection.iter_by_groups(bags_per_group=bags_per_group,
text_opinion_ids_set=None),
total=minibatches_count,
prefix="Predict [{dtype}]".format(dtype=data_type))

self.__run_epoch_pipeline(batches_it=groups_it,
pipeline=self.__predict_pipeline,
data_type=data_type)

def __run_epoch_pipeline(self, batches_it, data_type, pipeline):
assert(isinstance(batches_it, collections.Iterable))
assert(isinstance(pipeline, list))

for item in pipeline:
assert(isinstance(item, EpochHandlingPipelineItem))
item.before_epoch(model_context=self.__context,
data_type=data_type)

for bags_group in batches_it:
for bags_group in groups_it:
assert(isinstance(bags_group, list))

# Composing minibatch from bags group.
minibatch = create_batch_by_bags_group(
bags_coolection_type=self.__context.BagsCollectionType,
bags_group=bags_group)

ctx = PipelineContext({
"src": minibatch
})
ctx = PipelineContext({"src": minibatch})

for item in pipeline:
item.apply(pipeline_ctx=ctx)
Expand Down Expand Up @@ -131,7 +106,7 @@ def __fit(self, epochs_count):

bags_collection.shuffle()

self.__fit_epoch()
self.__run_epoch_pipeline(pipeline=self.__fit_pipeline, data_type=DataType.Train, prefix="Training")

if self.__callback is not None:
self.__callback.on_epoch_finished(epoch_index=epoch_index,
Expand Down Expand Up @@ -165,7 +140,9 @@ def predict(self, data_type=DataType.Test, do_compile=False, graph_seed=0):
self.__context.Network.compile(config=self.__context.Config, reset_graph=True, graph_seed=graph_seed)
self.__context.initialize_session()
self.__try_load_state()
return self.__predict(data_type=data_type)
self.__run_epoch_pipeline(pipeline=self.__predict_pipeline,
data_type=data_type,
prefix="Predict [{dtype}]".format(dtype=data_type))

def from_fitted(self, item_type):
assert(issubclass(item_type, EpochHandlingPipelineItem))
Expand Down

0 comments on commit 21f3048

Please sign in to comment.