Skip to content

Commit

Permalink
PosTags now optional
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jan 7, 2023
1 parent 85008db commit bca6dc1
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 24 deletions.
6 changes: 4 additions & 2 deletions arekit/contrib/networks/input/ctx_serialization.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from arekit.common.data.input.providers.label.multiple import MultipleLabelProvider
from arekit.contrib.utils.processing.pos.base import POSTagger


class NetworkSerializationContext(object):

def __init__(self, labels_scaler, pos_tagger, frame_roles_label_scaler, frames_connotation_provider):
def __init__(self, labels_scaler, frame_roles_label_scaler, frames_connotation_provider, pos_tagger=None):
assert(isinstance(pos_tagger, POSTagger) or pos_tagger is None)
self.__label_provider = MultipleLabelProvider(labels_scaler)
self.__pos_tagger = pos_tagger
self.__frame_roles_label_scaler = frame_roles_label_scaler
self.__frames_connotation_provider = frames_connotation_provider
self.__pos_tagger = pos_tagger

@property
def LabelProvider(self):
Expand Down
17 changes: 0 additions & 17 deletions arekit/contrib/networks/input/providers/columns.py

This file was deleted.

10 changes: 6 additions & 4 deletions arekit/contrib/networks/input/providers/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ def __init__(self,
text_provider,
frames_connotation_provider,
frame_role_label_scaler,
pos_terms_mapper):
pos_terms_mapper=None):
assert(isinstance(label_provider, LabelProvider))
assert(isinstance(pos_terms_mapper, PosTermsMapper))
assert(isinstance(frame_role_label_scaler, SentimentLabelScaler))
assert(isinstance(pos_terms_mapper, PosTermsMapper) or pos_terms_mapper is None)

super(NetworkSampleRowProvider, self).__init__(label_provider=label_provider,
text_provider=text_provider)
Expand Down Expand Up @@ -64,14 +64,16 @@ def _fill_row_core(self, row, text_opinion_linkage, index_in_linked, etalon_labe
uint_syn_t_inds = self.__create_synonyms_set(terms=terms, term_ind=actual_t_ind)

# Part of speech tags
pos_int_tags = [int(pos_tag) for pos_tag in self.__pos_terms_mapper.iter_mapped(terms)]
pos_int_tags = None if self.__pos_terms_mapper is None \
else [int(pos_tag) for pos_tag in self.__pos_terms_mapper.iter_mapped(terms)]

# Saving.
row[const.FrameVariantIndices] = self.__to_arg(uint_frame_inds)
row[const.FrameConnotations] = self.__to_arg(uint_frame_connotations)
row[const.SynonymSubject] = self.__to_arg(uint_syn_s_inds)
row[const.SynonymObject] = self.__to_arg(uint_syn_t_inds)
row[const.PosTags] = self.__to_arg(pos_int_tags)
if pos_int_tags is not None:
row[const.PosTags] = self.__to_arg(pos_int_tags)

# region private methods

Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/utils/pipelines/items/sampling/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, vectorizers, save_labels_func, str_entity_fmt, ctx,
text_provider=text_provider,
frames_connotation_provider=ctx.FramesConnotationProvider,
frame_role_label_scaler=ctx.FrameRolesLabelScaler,
pos_terms_mapper=PosTermsMapper(ctx.PosTagger))
pos_terms_mapper=PosTermsMapper(ctx.PosTagger) if ctx.PosTagger is not None else None)

@staticmethod
def __add_term_embedding(dict_data, term, emb_vector):
Expand Down

0 comments on commit bca6dc1

Please sign in to comment.