Skip to content

Commit

Permalink
Fixed bug: couple times utitilized iter (doc_ids_iter) works only once (
Browse files Browse the repository at this point in the history
#195 related)
  • Loading branch information
nicolay-r committed Dec 27, 2021
1 parent 431ccb5 commit 150535c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
8 changes: 3 additions & 5 deletions arekit/common/data/input/repositories/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import collections

from arekit.common.data.input.providers.columns.base import BaseColumnsProvider
from arekit.common.data.input.providers.opinions import OpinionProvider
from arekit.common.data.input.providers.rows.base import BaseRowProvider
Expand Down Expand Up @@ -31,15 +29,15 @@ def _setup_rows_provider(self):

# endregion

def populate(self, opinion_provider, doc_ids_iter, desc=""):
def populate(self, opinion_provider, doc_ids, desc=""):
assert(isinstance(opinion_provider, OpinionProvider))
assert(isinstance(self._storage, BaseRowsStorage))
assert(isinstance(doc_ids_iter, collections.Iterable))
assert(isinstance(doc_ids, list))

def iter_rows(idle_mode):
return self._rows_provider.iter_by_rows(
opinion_provider=opinion_provider,
doc_ids_iter=doc_ids_iter,
doc_ids_iter=doc_ids,
idle_mode=idle_mode)

self._storage.init_empty(columns_provider=self._columns_provider)
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/bert/run_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def __handle_iteration(self, data_type):

# Populate repositories
opinions_repo.populate(opinion_provider=opinion_provider,
doc_ids_iter=self._experiment.DocumentOperations.iter_doc_ids(data_type),
doc_ids=list(self._experiment.DocumentOperations.iter_doc_ids(data_type)),
desc="opinion")

samples_repo.populate(opinion_provider=opinion_provider,
doc_ids_iter=self._experiment.DocumentOperations.iter_doc_ids(data_type),
doc_ids=list(self._experiment.DocumentOperations.iter_doc_ids(data_type)),
desc="sample")

if self._experiment.ExperimentIO.balance_samples(data_type=data_type, balance=self.__balance_train_samples):
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/networks/core/input/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ def __perform_writing(experiment, data_type, opinion_provider,

# Populate repositories
opinions_repo.populate(opinion_provider=opinion_provider,
doc_ids_iter=experiment.DocumentOperations.iter_doc_ids(data_type),
doc_ids=list(experiment.DocumentOperations.iter_doc_ids(data_type)),
desc="opinion")

samples_repo.populate(opinion_provider=opinion_provider,
doc_ids_iter=experiment.DocumentOperations.iter_doc_ids(data_type),
doc_ids=list(experiment.DocumentOperations.iter_doc_ids(data_type)),
desc="sample")

if experiment.ExperimentIO.balance_samples(data_type=data_type, balance=balance):
Expand Down

0 comments on commit 150535c

Please sign in to comment.