Skip to content

Commit

Permalink
Refactoring #105.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Apr 29, 2021
1 parent 11489ca commit 53ff518
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 25 deletions.
2 changes: 1 addition & 1 deletion common/experiment/extract/opinions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def fill_opinion_collection(collection, linked_data_iter, labels_helper, to_opin

agg_label = labels_helper.aggregate_labels(
labels_list=list(linked.iter_labels()),
label_creation_mode=label_calc_mode)
label_calc_mode=label_calc_mode)

agg_opinion = to_opinion_func(linked.First, agg_label)

Expand Down
2 changes: 1 addition & 1 deletion common/model/labeling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def label_to_uint(self, label):
def get_classes_count(self):
return len(self._label_scaler.ordered_suppoted_labels())

def aggregate_labels(self, labels_list, label_creation_mode):
def aggregate_labels(self, labels_list, label_calc_mode):
raise NotImplementedError()

@staticmethod
Expand Down
19 changes: 5 additions & 14 deletions common/model/labeling/modes.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
# TODO: Use enum instead.
class LabelCalculationMode:
from enum import Enum

FIRST_APPEARED = u'take_first_appeared'
AVERAGE = u'average'

@staticmethod
def supported(value):
for s in LabelCalculationMode.__iter_supported():
if s == value:
return True
return False
class LabelCalculationMode(Enum):

@staticmethod
def __iter_supported():
for var_name in dir(LabelCalculationMode):
yield getattr(LabelCalculationMode, var_name)
FIRST_APPEARED = u'take_first_appeared'

AVERAGE = u'average'
9 changes: 4 additions & 5 deletions common/model/labeling/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@

class SingleLabelsHelper(LabelsHelper):

def aggregate_labels(self, labels_list, label_creation_mode):
def aggregate_labels(self, labels_list, label_calc_mode):
assert(isinstance(labels_list, list))
assert(isinstance(label_creation_mode, unicode))
assert(LabelCalculationMode.supported(label_creation_mode))
assert(label_calc_mode, LabelCalculationMode)

label = None

if label_creation_mode == LabelCalculationMode.FIRST_APPEARED:
if label_calc_mode == LabelCalculationMode.FIRST_APPEARED:
label = labels_list[0]

if label_creation_mode == LabelCalculationMode.AVERAGE:
if label_calc_mode == LabelCalculationMode.AVERAGE:
int_labels = [self._label_scaler.label_to_int(label)
for label in labels_list]
label = self._label_scaler.int_to_label(np.sign(sum(int_labels)))
Expand Down
10 changes: 6 additions & 4 deletions contrib/networks/core/callback/utils_model_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@


def evaluate_model(experiment, data_type, epoch_index, model,
labels_formatter, save_hidden_params, log_dir):
labels_formatter, save_hidden_params,
label_calc_mode, log_dir):
""" Performs Model Evaluation on a particular state (i.e. epoch),
for a particular data type.
"""
Expand Down Expand Up @@ -72,6 +73,7 @@ def evaluate_model(experiment, data_type, epoch_index, model,
data_type=data_type,
epoch_index=epoch_index,
result_filepath=result_filepath,
label_calc_mode=label_calc_mode,
labels_formatter=labels_formatter)

# Evaluate.
Expand All @@ -91,14 +93,15 @@ def evaluate_model(experiment, data_type, epoch_index, model,

def __convert_output_to_opinion_collections(exp_io, opin_ops, doc_ops, labels_scaler, opin_fmt,
result_filepath, data_type, epoch_index,
labels_formatter):
label_calc_mode, labels_formatter):
assert(isinstance(opin_ops, OpinionOperations))
assert(isinstance(doc_ops, DocumentOperations))
assert(isinstance(labels_scaler, BaseLabelScaler))
assert(isinstance(exp_io, NetworkIOUtils))
assert(isinstance(opin_fmt, OpinionCollectionsFormatter))
assert(isinstance(data_type, DataType))
assert(isinstance(epoch_index, int))
assert(isinstance(label_calc_mode, LabelCalculationMode))
assert(isinstance(labels_formatter, StringLabelsFormatter))

opinions_source = exp_io.get_input_opinions_filepath(data_type=data_type)
Expand All @@ -112,8 +115,7 @@ def __convert_output_to_opinion_collections(exp_io, opin_ops, doc_ops, labels_sc
labels_scaler=labels_scaler,
create_opinion_collection_func=opin_ops.create_opinion_collection,
keep_doc_id_func=lambda doc_id: doc_id in cmp_doc_ids_set,
# TODO. bring this onto parameters level.
label_calculation_mode=LabelCalculationMode.AVERAGE,
label_calculation_mode=label_calc_mode,
output=MulticlassOutput(labels_scaler=labels_scaler,
has_output_header=True))

Expand Down

0 comments on commit 53ff518

Please sign in to comment.