Skip to content

Commit

Permalink
Pass the previously missing `max_part_size` parameter through to `create_term_embedding`.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jun 27, 2022
1 parent 05e3148 commit 2c84512
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
5 changes: 4 additions & 1 deletion arekit/contrib/networks/core/input/embedding/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
logger = logging.getLogger(__name__)


def create_term_embedding(term, embedding, word_separator=' '):
# TODO. #348 related. Move it into `utils` contrib.

def create_term_embedding(term, embedding, max_part_size, word_separator=' '):
"""
Embedding algorithm based on parts (trigrams originally)
"""
Expand All @@ -17,6 +19,7 @@ def create_term_embedding(term, embedding, word_separator=' '):
else:
word, word_embedding = __compose_from_parts(term=term,
embedding=embedding,
max_part_size=max_part_size,
word_separator=word_separator)

# In order to prevent a problem of the further separations during reading process.
Expand Down
10 changes: 7 additions & 3 deletions arekit/contrib/networks/core/input/terms_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ def __init__(self, predefined_embedding, string_entities_formatter):

def map_word(self, w_ind, word):
    """Map a raw word to its (value, embedding-vector) pair.

    The diff scrape left both the removed and the added `embedding=`
    keyword lines in place, producing a duplicate keyword argument;
    this is the reconstructed post-commit call, which forwards
    `max_part_size` so part-based (trigram) composition is bounded.

    :param w_ind: index of the word in the sequence (unused here;
        kept for interface symmetry with the other map_* methods).
    :param word: the word (str) to embed.
    :return: tuple of (normalized term value, embedding vector).
    """
    value, vector = create_term_embedding(term=word,
                                          embedding=self.__predefined_embedding,
                                          max_part_size=self.MAX_PART_CUSTOM_EMBEDDING_SIZE)
    return value, vector

def map_text_frame_variant(self, fv_ind, text_frame_variant):
    """Map a text frame variant to its (value, embedding-vector) pair.

    The diff scrape left both the removed and the added `embedding=`
    keyword lines in place, producing a duplicate keyword argument;
    this is the reconstructed post-commit call, which forwards
    `max_part_size` so part-based (trigram) composition is bounded.

    :param fv_ind: index of the frame variant in the sequence (unused
        here; kept for interface symmetry with the other map_* methods).
    :param text_frame_variant: TextFrameVariant instance whose variant
        value is embedded.
    :return: tuple of (normalized term value, embedding vector).
    """
    assert(isinstance(text_frame_variant, TextFrameVariant))
    value, embedding = create_term_embedding(term=text_frame_variant.Variant.get_value(),
                                             embedding=self.__predefined_embedding,
                                             max_part_size=self.MAX_PART_CUSTOM_EMBEDDING_SIZE)

    return value, embedding

Expand All @@ -48,6 +50,7 @@ def map_token(self, t_ind, token):

seed_token_offset = self.TOKEN_RANDOM_SEED_OFFSET

# TODO. #348 related. Move it into `utils` contrib.
vector = self.__get_random_normal_distribution(
vector_size=self.__predefined_embedding.VectorSize,
seed=t_ind + seed_token_offset,
Expand All @@ -69,6 +72,7 @@ def map_entity(self, e_ind, entity):

# Vector extraction
emb_word, vector = create_term_embedding(term=str_formatted_entity,
max_part_size=self.MAX_PART_CUSTOM_EMBEDDING_SIZE,
embedding=self.__predefined_embedding)

return emb_word, vector
Expand All @@ -82,4 +86,4 @@ def __get_random_normal_distribution(vector_size, seed, loc, scale):
np.random.seed(seed)
return np.random.normal(loc=loc, scale=scale, size=vector_size)

# endregion
# endregion

0 comments on commit 2c84512

Please sign in to comment.