Skip to content
Permalink
Browse files

refactor(encoder): update the init func for flair

  • Loading branch information...
hanxiao committed Sep 29, 2019
1 parent 1b85375 commit e588c946afc1e82bf4baeadc8fd6cf5d4013eac6
Showing with 11 additions and 5 deletions.
  1. +11 −5 gnes/encoder/text/flair.py
@@ -14,7 +14,7 @@
# limitations under the License.


from typing import List
from typing import List, Tuple

import numpy as np

@@ -25,16 +25,22 @@
class FlairEncoder(BaseTextEncoder):
is_trained = True

def __init__(self, pooling_strategy: str = 'mean', *args, **kwargs):
def __init__(self,
word_embedding: str = 'glove',
flair_embeddings: Tuple[str] = ('news-forward', 'news-backward'),
pooling_strategy: str = 'mean', *args, **kwargs):
super().__init__(*args, **kwargs)

self.word_embedding = word_embedding
self.flair_embeddings = flair_embeddings
self.pooling_strategy = pooling_strategy

def post_init(self):
from flair.embeddings import DocumentPoolEmbeddings, WordEmbeddings, FlairEmbeddings
self._flair = DocumentPoolEmbeddings(
[WordEmbeddings('glove'),
FlairEmbeddings('news-forward'),
FlairEmbeddings('news-backward')],
[WordEmbeddings(self.word_embedding),
FlairEmbeddings(self.flair_embeddings[0]),
FlairEmbeddings(self.flair_embeddings[1])],
pooling=self.pooling_strategy)

@batching

0 comments on commit e588c94

Please sign in to comment.
You can’t perform that action at this time.