Skip to content
Permalink
Browse files

refactor(base): moving is_trained to class attribute

  • Loading branch information...
hanxiao committed Aug 2, 2019
1 parent fc5026d commit 58217d8cd3deaad6dbca6e8683e5baeea370593f
@@ -85,7 +85,8 @@ def __call__(cls, *args, **kwargs):

obj = type.__call__(cls, *args, **kwargs)

# set attribute
# set attribute with priority
# gnes_config in YAML > class attribute > default_gnes_config
for k, v in TrainableType.default_gnes_config.items():
if k in gnes_config:
v = gnes_config[k]
@@ -163,7 +164,6 @@ class TrainableBase(metaclass=TrainableType):
store_args_kwargs = False

def __init__(self, *args, **kwargs):
self.is_trained = False
self.verbose = 'verbose' in kwargs and kwargs['verbose']
self.logger = set_logger(self.__class__.__name__, self.verbose)
self._post_init_vars = set()
@@ -26,10 +26,10 @@

class BertEncoder(BaseTextEncoder):
store_args_kwargs = True
is_trained = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_trained = True
self._bc_encoder_args = args
self._bc_encoder_kwargs = kwargs

@@ -52,6 +52,7 @@ def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:

class BertEncoderServer(BaseTextEncoder):
store_args_kwargs = True
is_trained = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -60,7 +61,6 @@ def __init__(self, *args, **kwargs):
bert_args.append('-%s' % k)
bert_args.append(str(v))
self._bert_args = bert_args
self.is_trained = True

def post_init(self):
from bert_serving.server import BertServer
@@ -25,6 +25,7 @@


class ElmoEncoder(BaseTextEncoder):
is_trained = True

def __init__(self, model_dir: str, batch_size: int = 64, pooling_layer: int = -1,
pooling_strategy: str = 'REDUCE_MEAN', *args, **kwargs):
@@ -38,7 +39,6 @@ def __init__(self, model_dir: str, batch_size: int = 64, pooling_layer: int = -1
pooling_layer)
self.pooling_layer = pooling_layer
self.pooling_strategy = pooling_strategy
self.is_trained = True

def post_init(self):
from elmoformanylangs import Embedder
@@ -25,6 +25,7 @@


class FlairEncoder(BaseTextEncoder):
is_trained = True

def __init__(self, model_name: str = 'multi-forward-fast',
batch_size: int = 64,
@@ -35,7 +36,6 @@ def __init__(self, model_name: str = 'multi-forward-fast',

self.batch_size = batch_size
self.pooling_strategy = pooling_strategy
self.is_trained = True

def post_init(self):
from flair.embeddings import FlairEmbeddings
@@ -25,6 +25,8 @@


class GPTEncoder(BaseTextEncoder):
is_trained = True

def __init__(self,
model_dir: str,
batch_size: int = 64,
@@ -38,7 +40,6 @@ def __init__(self,
self.batch_size = batch_size
self.pooling_strategy = pooling_strategy
self._use_cuda = use_cuda
self.is_trained = True

def post_init(self):
import torch
@@ -25,6 +25,8 @@


class Word2VecEncoder(BaseTextEncoder):
is_trained = True

def __init__(self, model_dir: str,
skiprows: int = 1,
batch_size: int = 64,
@@ -35,7 +37,6 @@ def __init__(self, model_dir: str,
self.skiprows = skiprows
self.batch_size = batch_size
self.pooling_strategy = pooling_strategy
self.is_trained = True
self.dimension = dimension

def post_init(self):
@@ -35,12 +35,7 @@ def _test_topology(self, yaml_path: str, num_layer_before: int, num_layer_after:

@unittest.SkipTest
def test_flask_local(self):
yaml_path = os.path.join(self.dirname, 'yaml', 'topology1.yml')
args = set_composer_flask_parser().parse_args([
'--flask',
'--yaml_path', yaml_path,
'--html_path', self.html_path
])
args = set_composer_flask_parser().parse_args(['--flask'])
YamlComposerFlask(args).run()

def test_flask(self):
@@ -5,9 +5,10 @@


class DummyTFEncoder(BaseEncoder):
is_trained = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_trained = True

def post_init(self):
import tensorflow as tf

0 comments on commit 58217d8

Please sign in to comment.
You can’t perform that action at this time.