This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

refactor(encoder): replace gpt and elmo with transformer
hanhxiao committed Aug 26, 2019
1 parent a584c7e commit fe35193
Showing 14 changed files with 116 additions and 414 deletions.
3 changes: 2 additions & 1 deletion gnes/base/__init__.py
@@ -65,7 +65,8 @@ class TrainableType(type):
         'is_trained': False,
         'batch_size': None,
         'work_dir': os.environ.get('GNES_VOLUME', os.getcwd()),
-        'name': None
+        'name': None,
+        'on_gpu': False
     }
 
     def __new__(cls, *args, **kwargs):
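
The new `on_gpu` flag joins the defaults that the `TrainableType` metaclass applies to every trainable component, so any encoder can consult it before moving a model to CUDA. Below is a minimal, self-contained sketch of that defaults-injection pattern, not the actual GNES implementation (which hooks `__new__` as shown above); `DefaultsMeta` and `DummyEncoder` are purely illustrative names.

import os

# Hypothetical sketch: a metaclass that merges class-level defaults,
# including the new 'on_gpu' flag, into every instance it creates.
class DefaultsMeta(type):
    default_property = {
        'is_trained': False,
        'batch_size': None,
        'work_dir': os.environ.get('GNES_VOLUME', os.getcwd()),
        'name': None,
        'on_gpu': False
    }

    def __call__(cls, *args, **kwargs):
        obj = super().__call__(*args, **kwargs)
        # only fill in attributes the instance has not set itself
        for k, v in DefaultsMeta.default_property.items():
            if not hasattr(obj, k):
                setattr(obj, k, v)
        return obj


class DummyEncoder(metaclass=DefaultsMeta):
    pass


enc = DummyEncoder()
print(enc.on_gpu)  # False by default; a CUDA-aware encoder could branch on this flag
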
6 changes: 2 additions & 4 deletions gnes/encoder/__init__.py
@@ -21,10 +21,7 @@
     'BertEncoder': 'text.bert',
     'BertEncoderWithServer': 'text.bert',
     'BertEncoderServer': 'text.bert',
-    'ElmoEncoder': 'text.elmo',
     'FlairEncoder': 'text.flair',
-    'GPTEncoder': 'text.gpt',
-    'GPT2Encoder': 'text.gpt',
     'PCALocalEncoder': 'numeric.pca',
     'PQEncoder': 'numeric.pq',
     'TFPQEncoder': 'numeric.tf_pq',
@@ -43,7 +40,8 @@
     'IncepMixtureEncoder': 'video.incep_mixture',
     'VladEncoder': 'numeric.vlad',
     'MfccEncoder': 'audio.mfcc',
-    'PoolingEncoder': 'numeric.pooling'
+    'PoolingEncoder': 'numeric.pooling',
+    'PyTorchTransformers': 'text.transformer'
 }
 
 register_all_class(_cls2file_map, 'encoder')
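
The new `'PyTorchTransformers': 'text.transformer'` entry points the registry at the replacement module. The sketch below shows how a name-to-module map like `_cls2file_map` can be resolved lazily; the real `register_all_class` lives in `gnes.helper` and is not part of this diff, so treat `load_encoder_class` as a hypothetical helper.

import importlib

# Illustrative only: resolve a registered class name to its implementing module.
_cls2file_map = {
    'PyTorchTransformers': 'text.transformer',
    'FlairEncoder': 'text.flair',
}

def load_encoder_class(name: str, package: str = 'gnes.encoder'):
    """Import the module that defines `name` and return the class object."""
    module = importlib.import_module('%s.%s' % (package, _cls2file_map[name]))
    return getattr(module, name)

# Usage (assumes gnes is installed):
# PyTorchTransformers = load_encoder_class('PyTorchTransformers')
# encoder = PyTorchTransformers()
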
4 changes: 2 additions & 2 deletions gnes/encoder/base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 
-from typing import List, Any
+from typing import List, Any, Tuple, Union
 
 import numpy as np
 
@@ -44,7 +44,7 @@ def encode(self, data: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
 
 class BaseTextEncoder(BaseEncoder):
 
-    def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
+    def encode(self, text: List[str], *args, **kwargs) -> Union[Tuple, np.ndarray]:
         pass
 
 
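
Widening the return annotation to `Union[Tuple, np.ndarray]` lets a text encoder hand back either a single dense matrix or a tuple of arrays (for instance, per-token embeddings plus masks). The two toy classes below only illustrate what the annotation now permits; they are hypothetical and not part of the commit.

from typing import List, Tuple, Union

import numpy as np


class ArrayEncoder:
    def encode(self, text: List[str], *args, **kwargs) -> Union[Tuple, np.ndarray]:
        # one fixed-size vector per sentence
        return np.random.random([len(text), 8]).astype(np.float32)


class TupleEncoder:
    def encode(self, text: List[str], *args, **kwargs) -> Union[Tuple, np.ndarray]:
        # variable-length token embeddings plus a per-sentence mask
        embeddings = [np.random.random([len(s.split()), 8]).astype(np.float32) for s in text]
        masks = [np.ones(len(s.split()), dtype=np.float32) for s in text]
        return embeddings, masks
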
66 changes: 0 additions & 66 deletions gnes/encoder/text/elmo.py

This file was deleted.

30 changes: 13 additions & 17 deletions gnes/encoder/text/flair.py
@@ -19,34 +19,30 @@
 import numpy as np
 
 from ..base import BaseTextEncoder
-from ...helper import batching, pooling_np
+from ...helper import batching, as_numpy_array
 
 
 class FlairEncoder(BaseTextEncoder):
     is_trained = True
 
-    def __init__(self, model_name: str = 'multi-forward-fast',
-                 pooling_strategy: str = 'REDUCE_MEAN', *args, **kwargs):
+    def __init__(self, pooling_strategy: str = 'mean', *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        self.model_name = model_name
         self.pooling_strategy = pooling_strategy
 
     def post_init(self):
-        from flair.embeddings import FlairEmbeddings
-        self._flair = FlairEmbeddings(self.model_name)
+        from flair.embeddings import DocumentPoolEmbeddings, WordEmbeddings, FlairEmbeddings
+        self._flair = DocumentPoolEmbeddings(
+            [WordEmbeddings('glove'),
+             FlairEmbeddings('news-forward'),
+             FlairEmbeddings('news-backward')],
+            pooling=self.pooling_strategy)
 
     @batching
+    @as_numpy_array
     def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
         from flair.data import Sentence
+        import torch
-        # tokenize text
-        batch_tokens = [Sentence(sent) for sent in text]
-
-        flair_encodes = self._flair.embed(batch_tokens)
-
-        pooled_data = []
-        for sentence in flair_encodes:
-            _layer_data = np.stack([s.embedding.numpy() for s in sentence])
-            _pooled = pooling_np(_layer_data, self.pooling_strategy)
-            pooled_data.append(_pooled)
-        return np.array(pooled_data, dtype=np.float32)
+        batch_tokens = [Sentence(v) for v in text]
+        self._flair.embed(batch_tokens)
+        return torch.stack([v.embedding for v in batch_tokens]).detach()
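
For reference, the refactored encoder now delegates pooling to flair's `DocumentPoolEmbeddings`, stacking GloVe with forward and backward Flair embeddings and returning one pooled vector per sentence. The standalone snippet below reproduces that behaviour outside GNES (it needs the `flair` package; model names are taken from the diff above).

from flair.data import Sentence
from flair.embeddings import DocumentPoolEmbeddings, WordEmbeddings, FlairEmbeddings
import torch

# same embedding stack as the refactored post_init
embedder = DocumentPoolEmbeddings(
    [WordEmbeddings('glove'),
     FlairEmbeddings('news-forward'),
     FlairEmbeddings('news-backward')],
    pooling='mean')

sentences = [Sentence(s) for s in ['hello world', 'goodbye world']]
embedder.embed(sentences)

# one pooled document vector per input sentence
vectors = torch.stack([s.embedding for s in sentences]).detach().numpy()
print(vectors.shape)  # e.g. (2, 4196), depending on the embedding sizes
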
124 changes: 0 additions & 124 deletions gnes/encoder/text/gpt.py

This file was deleted.

91 changes: 0 additions & 91 deletions gnes/encoder/text/torch_transformers.py

This file was deleted.
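
The deleted `gpt.py`, `elmo.py`, and `torch_transformers.py` are superseded by a new `text.transformer` module registered as `PyTorchTransformers`, presumably one of the changed files not shown in this excerpt. Since that file is not visible here, the following is only a guess at what a minimal pytorch-transformers based encoder could look like; the class name, mean-pooling choice, and default model are all assumptions.

from typing import List

import numpy as np
import torch
from pytorch_transformers import BertModel, BertTokenizer


class SimpleTransformerEncoder:
    """Hypothetical sketch, not the actual gnes/encoder/text/transformer.py."""

    def __init__(self, model_name: str = 'bert-base-uncased', on_gpu: bool = False):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.device = 'cuda' if on_gpu else 'cpu'
        self.model.to(self.device).eval()

    def encode(self, text: List[str]) -> np.ndarray:
        vectors = []
        with torch.no_grad():
            for sent in text:
                ids = torch.tensor([self.tokenizer.encode(sent)]).to(self.device)
                last_hidden = self.model(ids)[0]          # (1, seq_len, hidden)
                vectors.append(last_hidden.mean(dim=1))   # mean-pool over tokens
        return torch.cat(vectors).cpu().numpy()
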
