Skip to content

Commit

Permalink
#304 refactoring. Done.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed May 26, 2022
1 parent 0a7552a commit 1222f7d
Showing 1 changed file with 20 additions and 28 deletions.
48 changes: 20 additions & 28 deletions arekit/contrib/networks/core/input/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import collections
import logging
from collections import OrderedDict

from arekit.common.data.input.providers.columns.opinion import OpinionColumnsProvider
from arekit.common.data.input.providers.columns.sample import SampleColumnsProvider
Expand Down Expand Up @@ -28,31 +27,12 @@ class NetworkInputHelper(object):
# region private methods

@staticmethod
def __create_text_provider(term_embedding_pairs, exp_ctx):
assert(isinstance(exp_ctx, NetworkSerializationContext))
assert(isinstance(term_embedding_pairs, OrderedDict))

terms_with_embeddings_terms_mapper = StringWithEmbeddingNetworkTermMapping(
predefined_embedding=exp_ctx.WordEmbedding,
string_entities_formatter=exp_ctx.StringEntityFormatter,
string_emb_entity_formatter=exp_ctx.StringEntityEmbeddingFormatter)

return NetworkSingleTextProvider(
text_terms_mapper=terms_with_embeddings_terms_mapper,
pair_handling_func=lambda pair: NetworkInputHelper.__add_term_embedding(
dict_data=term_embedding_pairs,
term=pair[0],
emb_vector=pair[1]))

@staticmethod
def __create_samples_repo(exp_ctx, term_embedding_pairs):
def __create_samples_repo(exp_ctx, text_provider):
assert(isinstance(exp_ctx, NetworkSerializationContext))

sample_row_provider = NetworkSampleRowProvider(
label_provider=exp_ctx.LabelProvider,
text_provider=NetworkInputHelper.__create_text_provider(
term_embedding_pairs=term_embedding_pairs,
exp_ctx=exp_ctx),
text_provider=text_provider,
frames_connotation_provider=exp_ctx.FramesConnotationProvider,
frame_role_label_scaler=exp_ctx.FrameRolesLabelScaler,
pos_terms_mapper=PosTermsMapper(exp_ctx.PosTagger))
Expand All @@ -77,17 +57,17 @@ def __add_term_embedding(dict_data, term, emb_vector):

@staticmethod
def __perform_writing(exp_ctx, exp_io, doc_ops, data_type, opinion_provider,
terms_per_context, balance, term_embedding_pairs):
"""
Perform experiment input serialization
terms_per_context, balance, text_provider):
""" Perform experiment input serialization
"""
assert(isinstance(data_type, DataType))
assert(isinstance(terms_per_context, int))
assert(isinstance(balance, bool))

opinions_repo = NetworkInputHelper.__create_opinions_repo()
samples_repo = NetworkInputHelper.__create_samples_repo(exp_ctx=exp_ctx,
term_embedding_pairs=term_embedding_pairs)

samples_repo = NetworkInputHelper.__create_samples_repo(
exp_ctx=exp_ctx, text_provider=text_provider)

# Populate repositories
opinions_repo.populate(opinion_provider=opinion_provider,
Expand Down Expand Up @@ -132,6 +112,18 @@ def prepare(exp_ctx, exp_io, doc_ops, opin_ops, terms_per_context, balance, valu

term_embedding_pairs = collections.OrderedDict()

text_terms_mapper = StringWithEmbeddingNetworkTermMapping(
predefined_embedding=exp_ctx.WordEmbedding,
string_entities_formatter=exp_ctx.StringEntityFormatter,
string_emb_entity_formatter=exp_ctx.StringEntityEmbeddingFormatter)

text_provider = NetworkSingleTextProvider(
text_terms_mapper=text_terms_mapper,
pair_handling_func=lambda pair: NetworkInputHelper.__add_term_embedding(
dict_data=term_embedding_pairs,
term=pair[0],
emb_vector=pair[1]))

for data_type in exp_ctx.DataFolding.iter_supported_data_types():

# Perform annotation
Expand All @@ -158,7 +150,7 @@ def prepare(exp_ctx, exp_io, doc_ops, opin_ops, terms_per_context, balance, valu
opinion_provider=opinion_provider,
terms_per_context=terms_per_context,
balance=balance,
term_embedding_pairs=term_embedding_pairs)
text_provider=text_provider)

# Save embedding information additionally.
term_embedding = Embedding.from_word_embedding_pairs_iter(iter(term_embedding_pairs.items()))
Expand Down

0 comments on commit 1222f7d

Please sign in to comment.