Skip to content

Commit

Permalink
Using external labels formatter (related to #100)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Apr 29, 2021
1 parent 53ff518 commit 6bc8d6a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,33 @@
from arekit.common.experiment.input.readers.sample import InputSampleReader
from arekit.common.experiment.output.opinions.converter import OutputToOpinionCollectionsConverter
from arekit.common.experiment.output.opinions.writer import save_opinion_collections
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.model.labeling.modes import LabelCalculationMode
from arekit.common.utils import join_dir_with_subfolder_name
from arekit.contrib.bert.callback import Callback
from arekit.contrib.bert.output.eval_helper import EvalHelper
from arekit.contrib.bert.output.google_bert import GoogleBertMulticlassOutput
from arekit.contrib.source.rusentrel.labels_fmt import RuSentRelLabelsFormatter

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class LanguageModelExperimentEvaluator(ExperimentEngine):

def __init__(self, experiment, data_type, eval_helper, max_epochs_count, eval_last_only=True):
def __init__(self, experiment, data_type, eval_helper, max_epochs_count,
labels_formatter, eval_last_only=True):
assert(isinstance(eval_helper, EvalHelper))
assert(isinstance(max_epochs_count, int))
assert(isinstance(eval_last_only, bool))
assert(isinstance(labels_formatter, StringLabelsFormatter))

super(LanguageModelExperimentEvaluator, self).__init__(experiment=experiment)

self.__data_type = data_type
self.__eval_helper = eval_helper
self.__max_epochs_count = max_epochs_count
self.__eval_last_only = eval_last_only
self.__labels_formatter = labels_formatter

def _log_info(self, message, forced=False):
assert(isinstance(message, unicode))
Expand Down Expand Up @@ -81,7 +84,6 @@ def _handle_iteration(self, iter_index):
row_id_provider = MultipleIDProvider()
# TODO. This should be removed as this is a part of the particular
# experiment, not source!.
labels_formatter = RuSentRelLabelsFormatter()
cmp_doc_ids_set = set(self._experiment.DocumentOperations.iter_doc_ids_to_compare())

if callback.check_log_exists():
Expand Down Expand Up @@ -134,7 +136,7 @@ def _handle_iteration(self, iter_index):
self._experiment.DataIO.OpinionFormatter.save_to_file(
collection=collection,
filepath=filepath,
labels_formatter=labels_formatter))
labels_formatter=self.__labels_formatter))

# evaluate
result = self._experiment.evaluate(data_type=self.__data_type,
Expand Down

0 comments on commit 6bc8d6a

Please sign in to comment.