Skip to content

Commit

Permalink
#239 done. Refactoring #161
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 26, 2021
1 parent b044e0c commit 846c4a3
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 30 deletions.
2 changes: 1 addition & 1 deletion arekit/contrib/experiment_rusentrel/exp_sl/opinions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_result_opinion_collection(self, doc_id, data_type, epoch_index):

# region private provider methods

def __create_collection(self, opinions):
def __create_collection(self, opinions=None):
return OpinionCollection(opinions=[] if opinions is None else opinions,
synonyms=self.__get_synonyms_func(),
error_on_duplicates=True,
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/networks/core/input/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __perform_writing(experiment, data_type, opinion_provider,
# endregion

@staticmethod
def prepare(experiment, terms_per_context, balance):
def prepare(experiment, terms_per_context, balance, value_to_group_id_func=None):
assert(isinstance(experiment, BaseExperiment))
assert(isinstance(terms_per_context, int))
assert(isinstance(balance, bool))
Expand All @@ -135,7 +135,7 @@ def prepare(experiment, terms_per_context, balance):
labels_formatter=experiment.OpinionOperations.LabelsFormatter)

opinion_provider = OpinionProvider.create(
value_to_group_id_func=None, # TODO. Remove this parameter.
value_to_group_id_func=value_to_group_id_func, # TODO. Remove this parameter.
parse_news_func=lambda doc_id: experiment.DocumentOperations.parse_doc(doc_id),
iter_news_opins_for_extraction=lambda doc_id:
experiment.OpinionOperations.iter_opinions_for_extraction(doc_id=doc_id, data_type=data_type),
Expand Down
60 changes: 50 additions & 10 deletions examples/network/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,26 @@ def get_doc(self, doc_id):

class CustomOpinionOperations(OpinionOperations):

def __init__(self, labels_formatter, iter_opins, synonyms):
def __init__(self, labels_formatter, exp_io, synonyms, neutral_labels_fmt):
super(CustomOpinionOperations, self).__init__()
self.__labels_formatter = labels_formatter
self.__iter_opins = iter_opins
self.__exp_io = exp_io
self.__synonyms = synonyms
self.__neutral_labels_fmt = neutral_labels_fmt

@property
def LabelsFormatter(self):
return self.__labels_formatter

def iter_opinions_for_extraction(self, doc_id, data_type):
return self.__iter_opins
# Reading automatically annotated collection of neutral opinions.
return self.__exp_io.read_opinion_collection(
target=self.__exp_io.create_result_opinion_collection_target(
doc_id=doc_id,
data_type=data_type,
epoch_index=0),
labels_formatter=self.__neutral_labels_fmt,
create_collection_func=self.create_opinion_collection)

def get_etalon_opinion_collection(self, doc_id):
return self.create_opinion_collection()
Expand All @@ -72,15 +80,47 @@ def create_opinion_collection(self):
error_on_synonym_end_missed=True)


class CustomIOUtils(NetworkIOUtils):
    """IO helper that keeps every opinion-collection artifact under the
    local ``data`` directory, using a single filename scheme."""

    def __create_target(self, doc_id, data_type, epoch_index):
        # Single formatting point for result file paths,
        # e.g. "data/result_d12_Train_e0.txt".
        # NOTE(review): assumes `data_type` is an enum-like object with a
        # `.name` attribute — confirm against the DataType used by callers.
        return "data/result_d{doc_id}_{data_type}_e{epoch_index}.txt".format(
            doc_id=doc_id,
            data_type=data_type.name,
            epoch_index=epoch_index)

    def get_experiment_sources_dir(self):
        """Return the root directory holding serialized experiment sources."""
        return "data"

    def create_opinion_collection_target(self, doc_id, data_type, check_existance=False):
        """Return the filepath for the (epoch-0) opinion collection target.

        BUG FIX: the original called __create_target but dropped the result,
        so this method always returned None; callers need the filepath.
        """
        return self.__create_target(doc_id=doc_id,
                                    data_type=data_type,
                                    epoch_index=0)

    def create_result_opinion_collection_target(self, doc_id, data_type, epoch_index):
        """Return the filepath for the result opinion collection of the
        given document / data-type / epoch.

        BUG FIX: added the missing `return` (original returned None).
        """
        return self.__create_target(doc_id=doc_id,
                                    data_type=data_type,
                                    epoch_index=epoch_index)


class CustomExperiment(BaseExperiment):

def __init__(self, exp_data, exp_io_type, opin_ops, doc_ops):
super(CustomExperiment, self).__init__(exp_data=exp_data,
experiment_io=exp_io_type(self),
opin_ops=opin_ops,
doc_ops=doc_ops,
name="test",
extra_name_suffix="test")
def __init__(self, exp_data, synonyms, doc_ops, labels_formatter, neutral_labels_fmt):

exp_io = CustomIOUtils(self)

opin_ops = CustomOpinionOperations(
labels_formatter=labels_formatter,
exp_io=exp_io,
synonyms=synonyms,
neutral_labels_fmt=neutral_labels_fmt)

super(CustomExperiment, self).__init__(
exp_data=exp_data,
experiment_io=exp_io,
opin_ops=opin_ops,
doc_ops=doc_ops,
name="test",
extra_name_suffix="test")


class CustomSerializationData(NetworkSerializationData):
Expand Down
27 changes: 10 additions & 17 deletions examples/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,32 +70,25 @@ def pipeline_serialize(sentences_text_list, label_provider):
frame_variants=frame_variants_collection)
])

parsed_news = NewsParser.parse(news=news, text_parser=text_parser)

opins_for_extraction = annot_algo.iter_opinions(parsed_news=parsed_news)

doc_ops = SingleDocOperations(news=news, text_parser=text_parser)

labels_formatter = StringLabelsFormatter(stol={"neu": NoLabel})

opin_ops = CustomOpinionOperations(labels_formatter=labels_formatter,
iter_opins=opins_for_extraction,
synonyms=synonyms)

exp_data = CustomSerializationData(label_scaler=label_provider.LabelScaler,
stemmer=stemmer,
annot=ThreeScaleTaskAnnotator(annot_algo=annot_algo),
frame_variants_collection=frame_variants_collection)

labels_fmt = StringLabelsFormatter(stol={"neu": NoLabel})

# Step 3. Serialize data
experiment = CustomExperiment(exp_data=exp_data,
exp_io_type=CustomNetworkIOUtils,
doc_ops=doc_ops,
opin_ops=opin_ops)
experiment = CustomExperiment(
exp_data=exp_data,
doc_ops=SingleDocOperations(news=news, text_parser=text_parser),
labels_formatter=labels_fmt,
synonyms=synonyms,
neutral_labels_fmt=labels_fmt)

NetworkInputHelper.prepare(experiment=experiment,
terms_per_context=50,
balance=False)
balance=False,
value_to_group_id_func=synonyms.get_synonym_group_index)


if __name__ == '__main__':
Expand Down

0 comments on commit 846c4a3

Please sign in to comment.