Skip to content
This repository has been archived by the owner on Sep 24, 2020. It is now read-only.

Commit

Permalink
♻️ (bert) provide two ways to get embedding: __call__ and embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
imgarylai committed Mar 4, 2019
1 parent dd22a28 commit 1a65330
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion bert_embedding/bert.py
Expand Up @@ -60,6 +60,9 @@ def __init__(self, ctx=mx.cpu(), model='bert_12_768_12',
use_classifier=False)

def __call__(self, sentences, oov_way='avg'):
    """
    Make the instance callable; shorthand for ``self.embedding(...)``.

    Parameters
    ----------
    sentences :
        Passed through unchanged to ``self.embedding``.
    oov_way : str, default 'avg'
        Passed through unchanged to ``self.embedding``
        (presumably the out-of-vocabulary handling strategy -- confirm
        against ``embedding``/``oov``).

    Returns
    -------
    Whatever ``self.embedding(sentences, oov_way=oov_way)`` returns.
    """
    # Forward the caller's choice: passing the literal 'avg' here would
    # silently discard any non-default oov_way given to __call__.
    return self.embedding(sentences, oov_way=oov_way)

def embedding(self, sentences, oov_way='avg'):
"""
Get tokens, tokens embedding
Expand All @@ -85,7 +88,7 @@ def __call__(self, sentences, oov_way='avg'):
sequence_outputs = self.bert(token_ids, token_types,
valid_length.astype('float32'))
for token_id, sequence_output in zip(token_ids.asnumpy(),
sequence_outputs.asnumpy()):
sequence_outputs.asnumpy()):
batches.append((token_id, sequence_output))
return self.oov(batches, oov_way)

Expand Down

0 comments on commit 1a65330

Please sign in to comment.