Skip to content

Commit

Permalink
Merge pull request #434 from zalandoresearch/GH-421-character-embeddings
Browse files Browse the repository at this point in the history
Gh 421 character embeddings
  • Loading branch information
kashif committed Jan 31, 2019
2 parents 78f74c3 + b7547bc commit affec65
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions flair/embeddings.py
Expand Up @@ -17,7 +17,6 @@
from .data import Dictionary, Token, Sentence
from .file_utils import cached_path


log = logging.getLogger('flair')


Expand Down Expand Up @@ -329,13 +328,14 @@ def extra_repr(self):
def __str__(self):
    """Return this embedding's name attribute as its string representation."""
    return self.name


class CharacterEmbeddings(TokenEmbeddings):
"""Character embeddings of words, as proposed in Lample et al., 2016."""

def __init__(self, path_to_char_dict: str = None):
"""Uses the default character dictionary if none provided."""

super(CharacterEmbeddings, self).__init__()
super().__init__()
self.name = 'Char'
self.static_embeddings = False

Expand All @@ -353,6 +353,8 @@ def __init__(self, path_to_char_dict: str = None):

self.__embedding_length = self.char_embedding_dim * 2

self.to(flair.device)

@property
def embedding_length(self) -> int:
    """Dimensionality of the embedding vectors this class produces.

    Set once in ``__init__`` to ``char_embedding_dim * 2`` (presumably the
    concatenation of two directional character-level states — confirm
    against the full ``__init__`` body, which is truncated in this view).
    """
    return self.__embedding_length
Expand Down Expand Up @@ -909,8 +911,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
all_input_masks = all_input_masks.to(flair.device)

# put encoded batch through BERT model to get all hidden states of all encoder layers
if torch.cuda.is_available():
self.model.cuda()
self.model.to(flair.device)
self.model.eval()
all_encoder_layers, _ = self.model(all_input_ids, token_type_ids=None, attention_mask=all_input_masks)

Expand Down

0 comments on commit affec65

Please sign in to comment.