Implemented #199
nicolay-r committed Sep 23, 2021
1 parent 2572426 commit 1b5f9b3
Showing 5 changed files with 53 additions and 46 deletions.
29 changes: 7 additions & 22 deletions arekit/contrib/networks/core/data_handling/data.py
@@ -44,9 +44,7 @@ def create_empty(cls):
                    bags_collection={})
 
     def initialize(self, dtypes, create_samples_reader_func, has_model_predefined_state,
-                   vocab, labels_count, bags_collection_type,
-                   # TODO: 199. Remove config.
-                   config):
+                   vocab, labels_count, bags_collection_type, bag_size, input_shapes, config):
         """
         Perform reading information from the serialized experiment inputs.
         Initializing core configuration.
@@ -68,8 +66,8 @@ def initialize(self, dtypes, create_samples_reader_func, has_model_predefined_state,
                 is_external_vocab=has_model_predefined_state,
                 bags_collection_type=bags_collection_type,
                 vocab=vocab,
-                # TODO: 199. Remove config.
-                config=config,
+                bag_size=bag_size,
+                input_shapes=input_shapes,
                 desc="Filling bags collection [{}]".format(data_type))
 
             # Saving into dictionaries.
@@ -97,25 +95,15 @@ def calc_normalized_weigts(self, labels_count):
     @staticmethod
     def __read_for_data_type(samples_reader, is_external_vocab,
                              bags_collection_type, vocab,
-                             # TODO: 199. Remove config.
-                             config, desc=""):
+                             bag_size, input_shapes, desc=""):
         assert(isinstance(samples_reader, BaseInputSampleReader))
 
-        # TODO: 199. Use shapes.
-        terms_per_context = config.TermsPerContext
-        frames_per_context = config.FramesPerContext
-        synonyms_per_context = config.SynonymsPerContext
-
         bags_collection = bags_collection_type.from_formatted_samples(
             formatted_samples_iter=samples_reader.iter_rows_linked_by_text_opinions(),
             desc=desc,
-            bag_size=config.BagSize,
+            bag_size=bag_size,
             shuffle=True,
-            create_empty_sample_func=lambda: InputSample.create_empty(
-                # TODO: 199. Use shapes.
-                terms_per_context=terms_per_context,
-                frames_per_context=frames_per_context,
-                synonyms_per_context=synonyms_per_context),
+            create_empty_sample_func=lambda: InputSample.create_empty(input_shapes),
             create_sample_func=lambda row: InputSample.create_from_parameters(
                 input_sample_id=row.SampleID,
                 terms=row.Terms,
@@ -128,10 +116,7 @@ def __read_for_data_type(samples_reader, is_external_vocab,
                 frame_sent_roles=row.TextFrameVariantRoles,
                 syn_obj_inds=row.SynonymObjectInds,
                 syn_subj_inds=row.SynonymSubjectInds,
-                # TODO: 199. Use shapes.
-                terms_per_context=terms_per_context,
-                frames_per_context=frames_per_context,
-                synonyms_per_context=synonyms_per_context,
+                input_shapes=input_shapes,
                 pos_tags=row.PartOfSpeechTags))
 
         rows_it = NetworkInputSampleReaderHelper.iter_uint_labeled_sample_rows(samples_reader)
9 changes: 7 additions & 2 deletions arekit/contrib/networks/run_training.py
@@ -11,6 +11,7 @@
 from arekit.contrib.networks.core.feeding.bags.collection.base import BagsCollection
 from arekit.contrib.networks.core.model import BaseTensorflowModel
 from arekit.contrib.networks.core.params import NeuralNetworkModelParams
+from arekit.contrib.networks.shapes import NetworkInputShapes
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -83,8 +84,12 @@ def _handle_iteration(self, it_index):
             labels_count=self._experiment.DataIO.LabelsCount,
             vocab=vocab,
             bags_collection_type=self.__bags_collection_type,
-            # TODO: 199. Remove config.
-            config=self.__config)
+            input_shapes=NetworkInputShapes(iter_pairs=[
+                (NetworkInputShapes.FRAMES_PER_CONTEXT, self.__config.FramesPerContext),
+                (NetworkInputShapes.TERMS_PER_CONTEXT, self.__config.TermsPerContext),
+                (NetworkInputShapes.SYNONYMS_PER_CONTEXT, self.__config.SynonymsPerContext),
+            ]),
+            bag_size=self.__config.BagSize)
 
         if handled_data.HasNormalizedWeights:
             weights = handled_data.calc_normalized_weigts(labels_count=self._experiment.DataIO.LabelsCount)
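For reference, a minimal sketch of the call-site pattern this hunk introduces: the training runner derives a NetworkInputShapes container from its network config before handing data over. The _Config stand-in below is hypothetical; in the hunk above the values come from self.__config, and all sizes are illustrative.

from arekit.contrib.networks.shapes import NetworkInputShapes


class _Config(object):
    # Hypothetical stand-in for the network config (self.__config above);
    # the concrete values are illustrative only.
    TermsPerContext = 50
    FramesPerContext = 5
    SynonymsPerContext = 3
    BagSize = 1


config = _Config()
input_shapes = NetworkInputShapes(iter_pairs=[
    (NetworkInputShapes.FRAMES_PER_CONTEXT, config.FramesPerContext),
    (NetworkInputShapes.TERMS_PER_CONTEXT, config.TermsPerContext),
    (NetworkInputShapes.SYNONYMS_PER_CONTEXT, config.SynonymsPerContext),
])
bag_size = config.BagSize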
28 changes: 13 additions & 15 deletions arekit/contrib/networks/sample.py
@@ -11,6 +11,7 @@
 from arekit.contrib.networks.features.term_indices import IndicesFeature
 from arekit.contrib.networks.features.term_types import calculate_term_types
 from arekit.contrib.networks.features.utils import pad_right_or_crop_inplace
+from arekit.contrib.networks.shapes import NetworkInputShapes
 
 
 class InputSample(InputSampleBase):
@@ -97,14 +98,12 @@ def __init__(self,
     # region class methods
 
     @classmethod
-    def create_empty(cls, terms_per_context, frames_per_context, synonyms_per_context):
-        assert(isinstance(terms_per_context, int))
-        assert(isinstance(frames_per_context, int))
-        assert(isinstance(synonyms_per_context, int))
-
-        blank_synonyms = np.zeros(synonyms_per_context)
-        blank_terms = np.zeros(terms_per_context)
-        blank_frames = np.full(shape=frames_per_context,
+    def create_empty(cls, input_shapes):
+        assert(isinstance(input_shapes, NetworkInputShapes))
+
+        blank_synonyms = np.zeros(input_shapes.get_shape(input_shapes.SYNONYMS_PER_CONTEXT))
+        blank_terms = np.zeros(input_shapes.get_shape(input_shapes.TERMS_PER_CONTEXT))
+        blank_frames = np.full(shape=input_shapes.get_shape(input_shapes.FRAMES_PER_CONTEXT),
                                fill_value=cls.FRAMES_PAD_VALUE)
         return cls(X=blank_terms,
                    subj_ind=0,
@@ -165,10 +164,7 @@ def create_from_parameters(cls,
                               subj_ind,
                               obj_ind,
                               words_vocab,  # for indexing input (all the vocabulary, obtained from offsets.py)
-                              # TODO: 199. Use shapes.
-                              terms_per_context,  # for terms_per_context, frames_per_context.
-                              frames_per_context,
-                              synonyms_per_context,
+                              input_shapes,
                               syn_subj_inds,
                               syn_obj_inds,
                               frame_inds,
@@ -185,9 +181,7 @@ def create_from_parameters(cls,
         assert(isinstance(words_vocab, dict))
         assert(isinstance(subj_ind, int) and 0 <= subj_ind < len(terms))
         assert(isinstance(obj_ind, int) and 0 <= obj_ind < len(terms))
-        assert(isinstance(terms_per_context, int))
-        assert(isinstance(frames_per_context, int))
-        assert(isinstance(synonyms_per_context, int))
+        assert(isinstance(input_shapes, NetworkInputShapes))
         assert(isinstance(syn_subj_inds, list))
         assert(isinstance(syn_obj_inds, list))
         assert(isinstance(pos_tags, list))
@@ -208,6 +202,10 @@ def get_end_offset():
         # Composing vectors
         x_indices = np.array([cls.__get_index_by_term(term, words_vocab, is_external_vocab) for term in terms])
 
+        terms_per_context = input_shapes.get_shape(input_shapes.TERMS_PER_CONTEXT)
+        synonyms_per_context = input_shapes.get_shape(input_shapes.SYNONYMS_PER_CONTEXT)
+        frames_per_context = input_shapes.get_shape(input_shapes.FRAMES_PER_CONTEXT)
+
         # Check an ability to create sample by analyzing required window size.
         window_size = terms_per_context
         dist_between_entities = TextOpinionHelper.calc_dist_between_text_opinion_end_indices(
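A hedged sketch of the reworked create_empty contract shown above: a blank sample is now built from the shapes container rather than from three separate integers. The shape values below are illustrative, not taken from any real config.

from arekit.contrib.networks.sample import InputSample
from arekit.contrib.networks.shapes import NetworkInputShapes

input_shapes = NetworkInputShapes(iter_pairs=[
    (NetworkInputShapes.TERMS_PER_CONTEXT, 50),    # illustrative
    (NetworkInputShapes.FRAMES_PER_CONTEXT, 5),    # illustrative
    (NetworkInputShapes.SYNONYMS_PER_CONTEXT, 3),  # illustrative
])

blank_sample = InputSample.create_empty(input_shapes)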
14 changes: 14 additions & 0 deletions arekit/contrib/networks/shapes.py
@@ -0,0 +1,14 @@
+class NetworkInputShapes(object):
+
+    SYNONYMS_PER_CONTEXT = "synonyms_per_context"
+    TERMS_PER_CONTEXT = "terms_per_context"
+    FRAMES_PER_CONTEXT = "frames_per_context"
+
+    def __init__(self, iter_pairs):
+        assert(isinstance(iter_pairs, list))
+        self.__d = dict()
+        for key, value in iter_pairs:
+            self.__d[key] = value
+
+    def get_shape(self, key):
+        return self.__d[key]
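The new container is a thin key-to-size mapping over (key, value) pairs; every call site in this commit constructs it from a list, so the assert checks for a list. A minimal usage sketch (the size value is illustrative):

from arekit.contrib.networks.shapes import NetworkInputShapes

shapes = NetworkInputShapes(iter_pairs=[
    (NetworkInputShapes.TERMS_PER_CONTEXT, 50),  # illustrative value
])
assert shapes.get_shape(NetworkInputShapes.TERMS_PER_CONTEXT) == 50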
19 changes: 12 additions & 7 deletions examples/pipeline.py
@@ -31,6 +31,7 @@
 from arekit.contrib.networks.core.model import BaseTensorflowModel
 from arekit.contrib.networks.core.model_io import NeuralNetworkModelIO
 from arekit.contrib.networks.core.predict.tsv_provider import TsvPredictProvider
+from arekit.contrib.networks.shapes import NetworkInputShapes
 from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection
 from arekit.contrib.source.rusentiframes.types import RuSentiFramesVersions
 from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions
@@ -133,7 +134,9 @@ def extract(text):
 
     handled_data = HandledData.create_empty()
 
-    # TODO. Provide samples reader.
+    network = PiecewiseCNN()
+    config = CNNConfig()
+
     handled_data.initialize(
         dtypes=[DataType.Test],
         create_samples_reader_func=TsvInputSampleReader.from_tsv(
@@ -142,10 +145,12 @@ def extract(text):
         has_model_predefined_state=True,
         vocab=None,
         labels_count=3,
         bags_collection_type=SingleBagsCollection,
-        # TODO: 199. Remove config.
-        config=None,
-    )
+        input_shapes=NetworkInputShapes(iter_pairs=[
+            (NetworkInputShapes.FRAMES_PER_CONTEXT, config.FramesPerContext),
+            (NetworkInputShapes.TERMS_PER_CONTEXT, config.TermsPerContext),
+            (NetworkInputShapes.SYNONYMS_PER_CONTEXT, config.SynonymsPerContext),
+        ]),
+        bag_size=config.BagSize)
 
     ############################
     # Step 5. Model preparation.
@@ -156,8 +161,8 @@ def extract(text):
             target_dir=".model",
             full_model_name="PCNN",
             model_name_tag="_"),
-        network=PiecewiseCNN(),
-        config=CNNConfig(),
+        network=network,
+        config=config,
         handled_data=handled_data,
         bags_collection_type=SingleBagsCollection,  # A single sample as input.
     )
