Skip to content

Commit

Permalink
fix(executors): fix gpu supports
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed May 4, 2020
1 parent e4d44d4 commit 67cf25f
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 35 deletions.
73 changes: 73 additions & 0 deletions jina/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(self, *args, **kwargs):
self._last_snapshot_ts = datetime.now()
self._drivers = {} # type: Dict[str, List['BaseDriver']]
self._attached_pea = None
self._backend = 'tensorflow'

def _post_init_wrapper(self, _metas: Dict = None, _requests: Dict = None, fill_in_metas: bool = True):
with TimeContext('post initiating, this may take some time', self.logger):
Expand Down Expand Up @@ -216,6 +217,27 @@ def post_init(self):
All class members created here will NOT be serialized when calling :func:`save`. Therefore if you
want to store them, please override the :func:`__getstate__`.
"""
self._set_device()

def _set_device(self):
if self._backend == 'tensorflow':
import tensorflow as tf
cpus = tf.config.experimental.list_physical_devices(device_type='CPU')
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
self._device = gpus[0] if self.on_gpu else cpus
# set before loading model
elif self._backend == 'paddlepaddle':
import paddle.fluid as fluid
self._device = fluid.CUDAPlace(0) if self.on_gpu else fluid.CPUPlace()
elif self._backend == 'pytorch':
import torch
self._device = torch.device('cuda:0') if self.on_gpu else torch.device('cpu')
elif self._backend == 'onnx':
self._device = ['CUDAExecutionProvider'] if self.on_gpu else ['CPUExecutionProvider']
else:
pass

def set_device(self):
    """Device-placement hook; this base implementation does nothing.

    The framework-specific executors defined later in this file override it.
    """
    return None

@classmethod
Expand Down Expand Up @@ -536,3 +558,54 @@ def __call__(self, req_type, *args, **kwargs):
raise UnattachedDriver(d)
else:
raise NoDriverForRequest(req_type)


class _BaseFramewordExecutor(BaseExecutor):
    """Common base for executors that wrap a model from a deep-learning framework.

    Subclasses provide :meth:`build_model` and :meth:`set_device`; both are
    invoked, in that order, from :meth:`post_init`.
    """

    def post_init(self):
        """Run the parent ``post_init``, then ``build_model()`` followed by ``set_device()``."""
        super().post_init()
        # NOTE(review): the model is built before the device hook runs. That is
        # what a hook such as ``self.model.to(...)`` needs, but a backend that has
        # to configure devices before loading a model may want the opposite
        # order -- confirm for TensorFlow.
        self.build_model()
        self.set_device()

    def build_model(self):
        """Create the wrapped model; must be implemented by subclasses."""
        raise NotImplementedError()

    def set_device(self):
        """Apply ``self._device`` to the framework/model; must be implemented by subclasses."""
        raise NotImplementedError()


class BaseTorchExecutor(_BaseFramewordExecutor):
    """Framework executor whose backend is ``'pytorch'``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set after the parent constructor has run so that a default assigned
        # there cannot overwrite it.
        self._backend = 'pytorch'

    def set_device(self):
        """Move ``self.model`` onto ``self._device`` via ``model.to``."""
        model = self.model
        model.to(self._device)


class BaseOnnxExecutor(_BaseFramewordExecutor):
    """Framework executor whose backend is ``'onnx'``.

    Derives from :class:`_BaseFramewordExecutor`, like the other framework
    executors in this file, so that ``post_init`` calls ``build_model()`` and
    then ``set_device()``. When deriving directly from ``BaseExecutor`` neither
    hook was ever invoked, leaving subclasses that implement ``build_model``
    without a model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set after the parent constructor has run so that a default assigned
        # there cannot overwrite it.
        self._backend = 'onnx'

    def set_device(self):
        """Hand the execution providers stored in ``self._device`` to ``self.model``."""
        self.model.set_providers(self._device)


class BaseTfExecutor(_BaseFramewordExecutor):
    """Framework executor whose backend is ``'tensorflow'``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set after the parent constructor has run so that a default assigned
        # there cannot overwrite it.
        self._backend = 'tensorflow'

    def set_device(self):
        """Restrict TensorFlow's visible devices to ``self._device``."""
        # Imported lazily so TensorFlow is only needed when this hook runs.
        import tensorflow as tf
        restrict_to = tf.config.experimental.set_visible_devices
        restrict_to(self._device)


class BasePaddleExecutor(_BaseFramewordExecutor):
    """Framework executor whose backend is ``'paddlepaddle'``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set after the parent constructor has run so that a default assigned
        # there cannot overwrite it.
        self._backend = 'paddlepaddle'

    def set_device(self):
        """Create a ``fluid.Executor`` on ``self._device`` and keep it as ``self.exe``."""
        # Imported lazily so PaddlePaddle is only needed when this hook runs.
        import paddle.fluid as fluid
        executor = fluid.Executor(self._device)
        self.exe = executor
5 changes: 3 additions & 2 deletions jina/executors/encoders/image/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from .. import BaseImageEncoder
from ...decorators import batching, as_ndarray
from ... import BaseOnnxExecutor


class OnnxImageEncoder(BaseImageEncoder):
class OnnxImageEncoder(BaseImageEncoder, BaseOnnxExecutor):
"""
:class:`OnnxImageEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
ndarray of `B x D`.
Expand Down Expand Up @@ -42,7 +43,7 @@ def __init__(self,
self.raw_model_path = model_path
self.model_name = ""

def post_init(self):
def build_model(self):
import onnxruntime
self.model_name = self.raw_model_path.split('/')[-1]
self.tmp_model_path = self.get_file_from_workspace(f'{self.model_name}.tmp')
Expand Down
6 changes: 3 additions & 3 deletions jina/executors/encoders/image/tfkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from .. import BaseImageEncoder
from ...decorators import batching, as_ndarray
from ... import BaseTfExecutor


class KerasImageEncoder(BaseImageEncoder):
class KerasImageEncoder(BaseImageEncoder, BaseTfExecutor):
"""
:class:`KerasImageEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
ndarray of `B x D`.
Expand Down Expand Up @@ -42,14 +43,13 @@ def __init__(self, model_name: str = 'MobileNetV2', img_shape: int = 96,
self.img_shape = img_shape
self.channel_axis = channel_axis

def post_init(self):
def build_model(self):
import tensorflow as tf
model = getattr(tf.keras.applications, self.model_name)(
input_shape=(self.img_shape, self.img_shape, 3),
include_top=False,
pooling=self.pool_strategy,
weights='imagenet')

model.trainable = False
self.model = model

Expand Down
6 changes: 3 additions & 3 deletions jina/executors/encoders/nlp/flair.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

from .. import BaseTextEncoder
from ...decorators import batching, as_ndarray
from ... import BaseTorchExecutor


class FlairTextEncoder(BaseTextEncoder):
class FlairTextEncoder(BaseTextEncoder, BaseTorchExecutor):
"""
:class:`FlairTextEncoder` encodes data from an array of string in size `B` into a ndarray in size `B x D`.
Internally, :class:`FlairTextEncoder` wraps the DocumentPoolEmbeddings from Flair.
Expand All @@ -36,10 +37,9 @@ def __init__(self,
self.model = None
self.max_length = -1 # reserved variable for future usages

def post_init(self):
def build_model(self):
from flair.embeddings import WordEmbeddings, FlairEmbeddings, BytePairEmbeddings, PooledFlairEmbeddings, \
DocumentPoolEmbeddings

if self.model is not None:
return
embeddings_list = []
Expand Down
5 changes: 3 additions & 2 deletions jina/executors/encoders/nlp/paddlehub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from .. import BaseTextEncoder
from ...decorators import batching, as_ndarray
from ... import BasePaddleExecutor


class TextPaddlehubEncoder(BaseTextEncoder):
class TextPaddlehubEncoder(BaseTextEncoder, BasePaddleExecutor):
"""
:class:`TextPaddlehubEncoder` encodes data from an array of string in size `B` into a ndarray in size `B x D`.
Internally, :class:`TextPaddlehubEncoder` wraps the Ernie module from paddlehub.
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self,
self.max_length = max_length
self.tokenizer = None

def post_init(self):
def build_model(self):
import paddlehub as hub
self.model = hub.Module(name=self.model_name)
self.model.MAX_SEQ_LEN = self.max_length
Expand Down
28 changes: 14 additions & 14 deletions jina/executors/encoders/nlp/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from .. import BaseTextEncoder
from ..helper import reduce_mean, reduce_max, reduce_min, reduce_cls
from ...decorators import batching, as_ndarray
from ... import _BaseFramewordExecutor, BaseTfExecutor, BaseTorchExecutor


class TransformerEncoder(BaseTextEncoder):
class BaseTransformerEncoder(_BaseFramewordExecutor):
"""
:class:`TransformerTextEncoder` encodes data from an array of string in size `B` into an ndarray in size `B x D`.
"""
Expand All @@ -32,14 +33,13 @@ def __init__(self,
:param model_path: the path of the encoder model. If a valid path is given, the encoder will be loaded from the
given path.
"""

super().__init__(*args, **kwargs)
self.model_name = model_name
self.pooling_strategy = pooling_strategy
self.max_length = max_length
self.raw_model_path = model_path

def post_init(self):
def _build_tokenizer(self):
from transformers import BertTokenizer, OpenAIGPTTokenizer, GPT2Tokenizer, \
XLNetTokenizer, XLMTokenizer, DistilBertTokenizer, RobertaTokenizer, XLMRobertaTokenizer, \
FlaubertTokenizer, CamembertTokenizer, CTRLTokenizer
Expand Down Expand Up @@ -130,15 +130,20 @@ def model_abspath(self) -> str:
"""
return self.get_file_from_workspace(self.raw_model_path)

def build_model(self):
    """Build the tokenizer, then the framework-specific model."""
    # Tokenizer first: the ``_build_model`` implementations in this file read
    # ``self.tokenizer`` (presumably populated by ``_build_tokenizer`` -- confirm).
    self._build_tokenizer()
    self._build_model()

def _build_model(self):
raise NotImplementedError

class TransformerTFEncoder(TransformerEncoder):

class TransformerTFEncoder(BaseTfExecutor, BaseTransformerEncoder):
"""
Internally, TransformerTFEncoder wraps the tensorflow-version of transformers from huggingface.
"""

def post_init(self):
super().post_init()

def _build_model(self):
import tensorflow as tf
from transformers import TFBertModel, TFOpenAIGPTModel, TFGPT2Model, TFXLNetModel, TFXLMModel, \
TFDistilBertModel, TFRobertaModel, TFXLMRobertaModel, TFCamembertModel, TFCTRLModel
Expand All @@ -157,23 +162,19 @@ def post_init(self):
self.model = model_dict[self.model_name].from_pretrained(self._tmp_model_path)
self._tensor_func = tf.constant
self._sess_func = tf.GradientTape

if self.model_name in ('xlnet-base-cased', 'openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'):
self.model.resize_token_embeddings(len(self.tokenizer))


class TransformerTorchEncoder(TransformerEncoder):
class TransformerTorchEncoder(BaseTorchExecutor, BaseTransformerEncoder):
"""
Internally, TransformerTorchEncoder wraps the pytorch-version of transformers from huggingface.
"""

def post_init(self):
super().post_init()

def _build_model(self):
import torch
from transformers import BertModel, OpenAIGPTModel, GPT2Model, XLNetModel, XLMModel, DistilBertModel, \
RobertaModel, XLMRobertaModel, FlaubertModel, CamembertModel, CTRLModel

model_dict = {
'bert-base-uncased': BertModel,
'openai-gpt': OpenAIGPTModel,
Expand All @@ -190,6 +191,5 @@ def post_init(self):
self.model = model_dict[self.model_name].from_pretrained(self._tmp_model_path)
self._tensor_func = torch.tensor
self._sess_func = torch.no_grad

if self.model_name in ('xlnet-base-cased', 'openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'):
self.model.resize_token_embeddings(len(self.tokenizer))
8 changes: 3 additions & 5 deletions jina/executors/encoders/paddlehub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from . import BaseNumericEncoder
from ..decorators import batching, as_ndarray
from .. import BasePaddleExecutor


class PaddlehubEncoder(BaseNumericEncoder):
class PaddlehubEncoder(BaseNumericEncoder, BasePaddleExecutor):
def __init__(self,
model_name: str,
output_feature: str,
Expand All @@ -23,14 +24,11 @@ def __init__(self,
self.channel_axis = channel_axis
self._default_channel_axis = -3

def post_init(self):
def build_model(self):
import paddlehub as hub
import paddle.fluid as fluid
module = hub.Module(name=self.model_name)
inputs, outputs, self.model = module.context(trainable=False)
self.get_inputs_and_outputs_name(inputs, outputs)
place = fluid.CUDAPlace(0) if self.on_gpu else fluid.CPUPlace()
self.exe = fluid.Executor(place)

def get_inputs_and_outputs_name(self, input_dict, output_dict):
raise NotImplementedError
Expand Down
12 changes: 6 additions & 6 deletions jina/executors/encoders/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import numpy as np

from . import BaseNumericEncoder
from .. import BaseTorchExecutor
from ..decorators import batching, as_ndarray


class TorchEncoder(BaseNumericEncoder):
class TorchEncoder(BaseTorchExecutor):
def __init__(self,
model_name: str,
channel_axis: int = 1,
Expand All @@ -17,11 +17,11 @@ def __init__(self,
self.channel_axis = channel_axis
self._default_channel_axis = 1

def post_init(self):
import torch
def build_model(self):
self._build_model()
device = 'cuda:0' if self.on_gpu else 'cpu'
self.model.to(torch.device(device))

def set_device(self):
self.model.to(self._device)

@batching
@as_ndarray
Expand Down

0 comments on commit 67cf25f

Please sign in to comment.