Skip to content

Commit

Permalink
#216 refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 11, 2021
1 parent c5794b3 commit 29193d7
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 23 deletions.
2 changes: 1 addition & 1 deletion arekit/common/frames/connotations/provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class FrameConnotationProvider(object):

def try_get_frame_sentiment_polarity(self, frame_id):
def try_provide(self, frame_id):
raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, collection):
assert(isinstance(collection, RuSentiFramesCollection))
self.__collection = collection

def try_get_frame_sentiment_polarity(self, frame_id):
def try_provide(self, frame_id):
return self.__collection.try_get_frame_polarity(frame_id=frame_id,
role_src='a0',
role_dest='a1')
2 changes: 1 addition & 1 deletion arekit/contrib/networks/core/ctx_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __read_for_data_type(samples_view, is_external_vocab,
obj_ind=row.ObjectIndex,
words_vocab=vocab,
frame_inds=row.TextFrameVariantIndices,
frame_sent_roles=row.TextFrameVariantRoles,
frame_sent_roles=row.TextFrameConnotations,
syn_obj_inds=row.SynonymObjectInds,
syn_subj_inds=row.SynonymSubjectInds,
input_shapes=input_shapes,
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/networks/core/input/const.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Additional input columns
FrameVariantIndices = "frames"
FrameRoles = "frame_roles_uint"
FrameConnotations = "frame_connots_uint"
SynonymObject = "syn_objs"
SynonymSubject = "syn_subjs"
Entities = "entities"
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/networks/core/input/providers/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def get_columns_list_with_types(self):

# insert indices
dtypes_list.append((const.FrameVariantIndices, str))
dtypes_list.append((const.FrameRoles, str))
dtypes_list.append((const.FrameConnotations, str))
dtypes_list.append((const.SynonymSubject, str))
dtypes_list.append((const.SynonymObject, str))
dtypes_list.append((const.Entities, str))
Expand Down
12 changes: 6 additions & 6 deletions arekit/contrib/networks/core/input/providers/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from arekit.common.news.parsed.base import ParsedNews
from arekit.contrib.networks.core.input import const
from arekit.contrib.networks.core.input.formatters.pos_mapper import PosTermsMapper
from arekit.contrib.networks.features.term_frame_roles import FrameRoleFeatures
from arekit.contrib.networks.features.term_connotation import FrameConnotationFeatures


class NetworkSampleRowProvider(BaseSampleRowProvider):
Expand Down Expand Up @@ -50,11 +50,11 @@ def _fill_row_core(self, row, linked_wrap, index_in_linked, etalon_label,
# Compose frame indices.
uint_frame_inds = list(self.__iter_indices(terms=terms, filter=lambda t: isinstance(t, TextFrameVariant)))

# Compose frame sentiment.
uint_frame_roles = list(
map(lambda variant: FrameRoleFeatures.extract_uint_frame_variant_sentiment_role(
# Compose frame connotations.
uint_frame_connotations = list(
map(lambda variant: FrameConnotationFeatures.extract_uint_frame_variant_connotation(
text_frame_variant=variant,
frames_connotation_provider=self.__frames_collection,
frames_connotation_provider=self.__frames_connotation_provider,
three_label_scaler=self.__frame_role_label_scaler),
[terms[frame_ind] for frame_ind in uint_frame_inds]))

Expand All @@ -72,7 +72,7 @@ def _fill_row_core(self, row, linked_wrap, index_in_linked, etalon_label,

# Saving.
row[const.FrameVariantIndices] = self.__to_arg(uint_frame_inds)
row[const.FrameRoles] = self.__to_arg(uint_frame_roles)
row[const.FrameConnotations] = self.__to_arg(uint_frame_connotations)
row[const.SynonymSubject] = self.__to_arg(uint_syn_s_inds)
row[const.SynonymObject] = self.__to_arg(uint_syn_t_inds)
row[const.Entities] = self.__to_arg(entity_inds)
Expand Down
6 changes: 3 additions & 3 deletions arekit/contrib/networks/core/input/rows_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __process_int_values_list(value):
const.T_IND: lambda value: value,
network_input_const.FrameVariantIndices: lambda value:
__process_indices_list(value) if isinstance(value, str) else empty_list,
network_input_const.FrameRoles: lambda value:
network_input_const.FrameConnotations: lambda value:
__process_indices_list(value) if isinstance(value, str) else empty_list,
network_input_const.SynonymObject: lambda value: __process_indices_list(value),
network_input_const.SynonymSubject: lambda value: __process_indices_list(value),
Expand Down Expand Up @@ -83,8 +83,8 @@ def TextFrameVariantIndices(self):
return self.__params[network_input_const.FrameVariantIndices]

@property
def TextFrameVariantRoles(self):
return self.__params[network_input_const.FrameRoles]
def TextFrameConnotations(self):
return self.__params[network_input_const.FrameConnotations]

@property
def EntityInds(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from arekit.contrib.networks.features.utils import create_filled_array


class FrameRoleFeatures(object):
class FrameConnotationFeatures(object):

@ staticmethod
def to_input(frame_inds, frame_sent_roles, size, filler):
Expand All @@ -23,21 +23,20 @@ def to_input(frame_inds, frame_sent_roles, size, filler):
return vector

@staticmethod
def extract_uint_frame_variant_sentiment_role(text_frame_variant, frames_connotation_provider, three_label_scaler):
def extract_uint_frame_variant_connotation(text_frame_variant, frames_connotation_provider, three_label_scaler):
assert(isinstance(text_frame_variant, TextFrameVariant))
assert(isinstance(frames_connotation_provider, FrameConnotationProvider))
assert(isinstance(three_label_scaler, BaseLabelScaler))

frame_id = text_frame_variant.Variant.FrameID
polarity = frames_connotation_provider.try_get_frame_sentiment_polarity(frame_id)
connot_descriptor = frames_connotation_provider.try_provide(frame_id)

if polarity is None:
if connot_descriptor is None:
return three_label_scaler.label_to_uint(label=three_label_scaler.get_no_label_instance())

assert(isinstance(polarity, FrameConnotationDescriptor))
assert(isinstance(connot_descriptor, FrameConnotationDescriptor))

if text_frame_variant.IsInverted:
inv_label = three_label_scaler.invert_label(polarity.Label)
return three_label_scaler.label_to_uint(label=inv_label)
target_label = three_label_scaler.invert_label(connot_descriptor.Label) \
if text_frame_variant.IsInverted else connot_descriptor.Label

return three_label_scaler.label_to_uint(polarity.Label)
return three_label_scaler.label_to_uint(target_label)

0 comments on commit 29193d7

Please sign in to comment.