Skip to content

Commit

Permalink
#315 fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Apr 5, 2022
1 parent d7b059e commit a9a6f8a
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions arekit/contrib/bert/handlers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,23 @@

class BertExperimentInputSerializerIterationHandler(ExperimentIterationHandler):

def __init__(self, exp_io, exp_ctx, doc_ops, opin_ops, labels_formatter,
value_to_group_id_func,
def __init__(self, exp_io, exp_ctx, doc_ops, opin_ops,
sample_labels_fmt, annot_labels_fmt, value_to_group_id_func,
sample_provider_type, entity_formatter, balance_train_samples):
assert(isinstance(exp_io, BaseIOUtils))
assert(isinstance(doc_ops, DocumentOperations))
assert(isinstance(opin_ops, OpinionOperations))
assert(isinstance(labels_formatter, StringLabelsFormatter))
assert(isinstance(sample_labels_fmt, StringLabelsFormatter))
assert(isinstance(annot_labels_fmt, StringLabelsFormatter))
assert(callable(value_to_group_id_func))
super(BertExperimentInputSerializerIterationHandler, self).__init__()

self.__value_to_group_id_func = value_to_group_id_func
self.__entity_formatter = entity_formatter
self.__sample_provider_type = sample_provider_type
self.__balance_train_samples = balance_train_samples
self.__labels_formatter = labels_formatter
self.__sample_label_formatter = sample_labels_fmt
self.__annot_label_formatter = annot_labels_fmt
self.__exp_io = exp_io
self.__exp_ctx = exp_ctx
self.__doc_ops = doc_ops
Expand All @@ -43,7 +45,7 @@ def __handle_iteration(self, data_type):

# Create samples formatter.
sample_rows_provider = create_bert_sample_provider(
labels_formatter=self.__labels_formatter,
labels_formatter=self.__sample_label_formatter,
provider_type=self.__sample_provider_type,
label_scaler=self.__exp_ctx.LabelsScaler,
entity_formatter=self.__entity_formatter)
Expand Down Expand Up @@ -112,6 +114,6 @@ def on_before_iteration(self):
self.__exp_io.write_opinion_collection(
collection=collection,
target=target,
labels_formatter=self.__labels_formatter)
labels_formatter=self.__annot_label_formatter)

# endregion

0 comments on commit a9a6f8a

Please sign in to comment.