Skip to content

Commit

Permalink
#335 fixed and #308 . Considering a function which calcuales the vect…
Browse files Browse the repository at this point in the history
…or from parts rather than seeking such in an Embedding
  • Loading branch information
nicolay-r committed Jun 22, 2022
1 parent c4ee4bb commit 4e799d6
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 37 deletions.
4 changes: 0 additions & 4 deletions arekit/contrib/networks/core/input/ctx_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,4 @@ def WordEmbedding(self):

@property
def PosTagger(self):
raise NotImplementedError()

@property
def StringEntityEmbeddingFormatter(self):
raise NotImplementedError()
36 changes: 5 additions & 31 deletions arekit/contrib/networks/core/input/terms_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from arekit.common.data.input.terms_mapper import OpinionContainingTextTermsMapper
from arekit.common.entities.base import Entity
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.entities.types import OpinionEntityType
from arekit.common.frames.text_variant import TextFrameVariant
from arekit.contrib.networks.core.input.embedding.custom import create_term_embedding
from arekit.contrib.networks.embeddings.base import Embedding
Expand All @@ -17,20 +15,18 @@ class StringWithEmbeddingNetworkTermMapping(OpinionContainingTextTermsMapper):
MAX_PART_CUSTOM_EMBEDDING_SIZE = 3
TOKEN_RANDOM_SEED_OFFSET = 12345

def __init__(self, predefined_embedding, string_entities_formatter, string_emb_entity_formatter):
def __init__(self, predefined_embedding, string_entities_formatter):
"""
predefined_embedding:
string_emb_entity_formatter:
Utilized in order to obtain embedding value from predefined_embeding for enties
"""
assert(isinstance(predefined_embedding, Embedding))
assert(isinstance(string_emb_entity_formatter, StringEntitiesFormatter))

super(StringWithEmbeddingNetworkTermMapping, self).__init__(
entity_formatter=string_entities_formatter)

self.__predefined_embedding = predefined_embedding
self.__string_emb_entity_formatter = string_emb_entity_formatter

def map_word(self, w_ind, word):
value, vector = create_term_embedding(term=word,
Expand Down Expand Up @@ -67,23 +63,15 @@ def map_entity(self, e_ind, entity):
assert(isinstance(entity, Entity))

# Value extraction
str_entity_mask = super(StringWithEmbeddingNetworkTermMapping, self).map_entity(
str_formatted_entity = super(StringWithEmbeddingNetworkTermMapping, self).map_entity(
e_ind=e_ind,
entity=entity)

# TODO. Use synonyms. (Or use this functionality in a base class).
empty_set = set()
e_type = self.__get_entity_type(e_ind=e_ind,
subj_ind_set=empty_set,
obj_ind_set=empty_set)

# Vector extraction
entity_word = self.__string_emb_entity_formatter.to_string(original_value=None,
entity_type=e_type)
m_ind = self.__predefined_embedding.try_find_index_by_plain_word(entity_word)
vector = self.__predefined_embedding.get_vector_by_index(m_ind)
emb_word, vector = create_term_embedding(term=str_formatted_entity,
embedding=self.__predefined_embedding)

return str_entity_mask, vector
return emb_word, vector

# region private methods

Expand All @@ -94,18 +82,4 @@ def __get_random_normal_distribution(vector_size, seed, loc, scale):
np.random.seed(seed)
return np.random.normal(loc=loc, scale=scale, size=vector_size)

@staticmethod
def __get_entity_type(e_ind, subj_ind_set, obj_ind_set):
assert(isinstance(e_ind, int))
assert(isinstance(subj_ind_set, set))
assert(isinstance(obj_ind_set, set))

result = OpinionEntityType.Other
if e_ind in obj_ind_set:
result = OpinionEntityType.Object
elif e_ind in subj_ind_set:
result = OpinionEntityType.Subject

return result

# endregion
3 changes: 1 addition & 2 deletions arekit/contrib/networks/handlers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def on_iteration(self, iter_index):

text_terms_mapper = StringWithEmbeddingNetworkTermMapping(
predefined_embedding=self.__exp_ctx.WordEmbedding,
string_entities_formatter=self.__exp_ctx.StringEntityFormatter,
string_emb_entity_formatter=self.__exp_ctx.StringEntityEmbeddingFormatter)
string_entities_formatter=self.__exp_ctx.StringEntityFormatter)

text_provider = NetworkSingleTextProvider(
text_terms_mapper=text_terms_mapper,
Expand Down

0 comments on commit 4e799d6

Please sign in to comment.