Skip to content

Commit

Permalink
Refactoring: now config weights are explicitly initialized. Proceeded…
Browse files Browse the repository at this point in the history
… to #199.
  • Loading branch information
nicolay-r committed Sep 23, 2021
1 parent e7ac731 commit 86b5a35
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
38 changes: 27 additions & 11 deletions arekit/contrib/networks/core/data_handling/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, labeled_collections, bags_collection):
assert(isinstance(bags_collection, dict))
self.__labeled_collections = labeled_collections
self.__bags_collection = bags_collection
self.__train_stat_uint_labeled_sample_row_ids = None

# region Properties

Expand All @@ -35,6 +36,10 @@ def BagsCollections(self):
def SamplesLabelingCollection(self):
return self.__labeled_collections

@property
def HasNormalizedWeights(self):
return self.__train_stat_uint_labeled_sample_row_ids is not None

# endregion

@staticmethod
Expand All @@ -53,7 +58,9 @@ def create_empty(cls):
bags_collection={})

def initialize(self, dtypes, create_samples_reader_func, has_model_predefined_state,
vocab, labels_count, bags_collection_type, config):
vocab, labels_count, bags_collection_type,
# TODO: 199. Remove config.
config):
"""
Perform reading information from the serialized experiment inputs.
Initializing core configuration.
Expand All @@ -63,8 +70,6 @@ def initialize(self, dtypes, create_samples_reader_func, has_model_predefined_st
assert(isinstance(has_model_predefined_state, bool))
assert(isinstance(labels_count, int) and labels_count > 0)

stat_uint_labeled_sample_row_ids = None

# Reading from serialized information
for data_type in dtypes:

Expand All @@ -77,6 +82,7 @@ def initialize(self, dtypes, create_samples_reader_func, has_model_predefined_st
is_external_vocab=has_model_predefined_state,
bags_collection_type=bags_collection_type,
vocab=vocab,
# TODO: 199. Remove config.
config=config,
desc="Filling bags collection [{}]".format(data_type))

Expand All @@ -86,14 +92,19 @@ def initialize(self, dtypes, create_samples_reader_func, has_model_predefined_st
uint_labeled_sample_row_ids=uint_labeled_sample_row_ids)

if data_type == DataType.Train:
stat_uint_labeled_sample_row_ids = uint_labeled_sample_row_ids
self.__train_stat_uint_labeled_sample_row_ids = uint_labeled_sample_row_ids

def calc_normalized_weigts(self, labels_count):
assert(isinstance(labels_count, int) and labels_count > 0)

if self.__train_stat_uint_labeled_sample_row_ids is None:
return

normalized_label_stat, _ = calculate_labels_distribution_stat(
uint_labeled_sample_row_ids=self.__train_stat_uint_labeled_sample_row_ids,
classes_count=labels_count)

# Calculate class weights.
if stat_uint_labeled_sample_row_ids is not None:
normalized_label_stat, _ = calculate_labels_distribution_stat(
uint_labeled_sample_row_ids=stat_uint_labeled_sample_row_ids,
classes_count=labels_count)
config.set_class_weights(normalized_label_stat)
return normalized_label_stat

# region writing methods

Expand Down Expand Up @@ -156,9 +167,12 @@ def __iter_parsed_news_func(doc_ops, data_type):

@staticmethod
def __read_for_data_type(samples_reader, is_external_vocab,
bags_collection_type, vocab, config, desc=""):
bags_collection_type, vocab,
# TODO: 199. Remove config.
config, desc=""):
assert(isinstance(samples_reader, BaseInputSampleReader))

# TODO: 199. Use shapes.
terms_per_context = config.TermsPerContext
frames_per_context = config.FramesPerContext
synonyms_per_context = config.SynonymsPerContext
Expand All @@ -169,6 +183,7 @@ def __read_for_data_type(samples_reader, is_external_vocab,
bag_size=config.BagSize,
shuffle=True,
create_empty_sample_func=lambda: InputSample.create_empty(
# TODO: 199. Use shapes.
terms_per_context=terms_per_context,
frames_per_context=frames_per_context,
synonyms_per_context=synonyms_per_context),
Expand All @@ -184,6 +199,7 @@ def __read_for_data_type(samples_reader, is_external_vocab,
frame_sent_roles=row.TextFrameVariantRoles,
syn_obj_inds=row.SynonymObjectInds,
syn_subj_inds=row.SynonymSubjectInds,
# TODO: 199. Use shapes.
terms_per_context=terms_per_context,
frames_per_context=frames_per_context,
synonyms_per_context=synonyms_per_context,
Expand Down
5 changes: 5 additions & 0 deletions arekit/contrib/networks/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,13 @@ def _handle_iteration(self, it_index):
labels_count=self._experiment.DataIO.LabelsCount,
vocab=vocab,
bags_collection_type=self.__bags_collection_type,
# TODO: 199. Remove config.
config=self.__config)

if handled_data.HasNormalizedWeights:
weights = handled_data.calc_normalized_weigts(labels_count=self._experiment.DataIO.LabelsCount)
self.__config.set_class_weights(weights)

# Update parameters after iteration preparation has been completed.
self.__config.reinit_config_dependent_parameters()

Expand Down
2 changes: 2 additions & 0 deletions arekit/contrib/networks/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def create_empty(cls, terms_per_context, frames_per_context, synonyms_per_contex
frame_indices=blank_frames,
input_sample_id="1")

# TODO. Refactoring #199.
@classmethod
def _generate_test(cls, config):
assert(isinstance(config, DefaultNetworkConfig))
Expand Down Expand Up @@ -164,6 +165,7 @@ def create_from_parameters(cls,
subj_ind,
obj_ind,
words_vocab, # for indexing input (all the vocabulary, obtained from offsets.py)
# TODO: 199. Use shapes.
terms_per_context, # for terms_per_context, frames_per_context.
frames_per_context,
synonyms_per_context,
Expand Down
3 changes: 2 additions & 1 deletion examples/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def extract(text):
vocab=None,
labels_count=3,
bags_collection_type=SingleBagsCollection,
config=None, # TODO. Конфигурация сети.
# TODO: 199. Remove config.
config=None,
)

############################
Expand Down

0 comments on commit 86b5a35

Please sign in to comment.