Skip to content

Commit

Permalink
#156 related. Fixed bug: embedding won't be initialized from the empt…
Browse files Browse the repository at this point in the history
…y sequence.
  • Loading branch information
nicolay-r committed Mar 24, 2022
1 parent 59140d5 commit b1a9ef4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion arekit/contrib/networks/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def from_word_embedding_pairs_iter(cls, word_embedding_pairs):
matrix.append(vector)
words.append(word)

return cls(matrix=np.array(matrix),
return cls(matrix=np.array(matrix) if len(matrix) > 0 else np.empty(shape=(0, 0)),
words=words)

@classmethod
Expand Down
21 changes: 21 additions & 0 deletions tests/contrib/networks/test_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest

from arekit.contrib.networks.embeddings.base import Embedding


class TestModelNamesService(unittest.TestCase):

def test(self):
single_element = [("a", [1])]
e = Embedding.from_word_embedding_pairs_iter(iter(single_element))
print(e.VocabularySize)
print(e.VectorSize)

def test_empty(self):
e = Embedding.from_word_embedding_pairs_iter(iter([]))
print(e.VocabularySize)
print(e.VectorSize)


if __name__ == '__main__':
unittest.main()

0 comments on commit b1a9ef4

Please sign in to comment.