diff --git a/flair/embeddings.py b/flair/embeddings.py
index a255879de..a23d48d46 100644
--- a/flair/embeddings.py
+++ b/flair/embeddings.py
@@ -17,7 +17,6 @@
 from .data import Dictionary, Token, Sentence
 from .file_utils import cached_path

-
 log = logging.getLogger('flair')


@@ -329,13 +328,14 @@ def extra_repr(self):
     def __str__(self):
         return self.name

+
 class CharacterEmbeddings(TokenEmbeddings):
     """Character embeddings of words, as proposed in Lample et al., 2016."""

     def __init__(self, path_to_char_dict: str = None):
         """Uses the default character dictionary if none provided."""
-        super(CharacterEmbeddings, self).__init__()
+        super().__init__()
         self.name = 'Char'
         self.static_embeddings = False


@@ -353,6 +353,8 @@ def __init__(self, path_to_char_dict: str = None):

         self.__embedding_length = self.char_embedding_dim * 2

+        self.to(flair.device)
+
     @property
     def embedding_length(self) -> int:
         return self.__embedding_length
@@ -909,8 +911,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
         all_input_masks = all_input_masks.to(flair.device)

         # put encoded batch through BERT model to get all hidden states of all encoder layers
-        if torch.cuda.is_available():
-            self.model.cuda()
+        self.model.to(flair.device)
         self.model.eval()
         all_encoder_layers, _ = self.model(all_input_ids, token_type_ids=None, attention_mask=all_input_masks)
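
The recurring change in this patch is device handling: rather than calling .cuda() only when a GPU happens to be available, modules are moved to flair.device, the library-wide device setting. Below is a minimal usage sketch, not part of the patch, assuming a flair 0.4.x-style API; the sentence text and the explicit device override are illustrative only.

import torch

import flair
from flair.data import Sentence
from flair.embeddings import CharacterEmbeddings

# flair.device is set once, library-wide; after this patch the embedding classes
# place their parameters on it (self.to(flair.device) / self.model.to(flair.device))
# instead of unconditionally calling .cuda().
flair.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

embeddings = CharacterEmbeddings()   # character BiLSTM now lands on flair.device
sentence = Sentence('The grass is green .')
embeddings.embed(sentence)

for token in sentence:
    # each token embedding is produced on the same device as the module
    print(token.text, token.embedding.size(), token.embedding.device)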