feat(encoder): add pytorch transformers support in text encoder
raccoonliukai committed Aug 13, 2019
1 parent 909ec4b commit 732f2e6
Showing 5 changed files with 142 additions and 3 deletions.
92 changes: 92 additions & 0 deletions gnes/encoder/text/torch_transformers.py
@@ -0,0 +1,92 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=low-comment-ratio


from typing import List

import numpy as np

from ..base import BaseTextEncoder
from ...helper import batching


class TorchTransformersEncoder(BaseTextEncoder):
is_trained = True

def __init__(self, model_dir: str,
model_name: str,
tokenizer_name: str,
use_cuda: bool = False,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.model_dir = model_dir
self.model_name = model_name
self.tokenizer_name = tokenizer_name
self.use_cuda = use_cuda

def post_init(self):
import pytorch_transformers as ptt

self.tokenizer = getattr(ptt, self.tokenizer_name).from_pretrained(self.model_dir)
self.model = getattr(ptt, self.model_name).from_pretrained(self.model_dir)

@batching
def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
import torch
batch_size = len(text)

# tokenize text
tokens_ids = []
tokens_lens = []
max_len = 0
        for sentence in text:
            # Convert each sentence to vocabulary indices
            token_ids = self.tokenizer.encode(sentence)
token_len = len(token_ids)

            max_len = max(max_len, token_len)

tokens_ids.append(token_ids)
tokens_lens.append(token_len)

        # Right-pad every sequence with zeros up to the longest one in the batch
        batch_data = np.zeros([batch_size, max_len], dtype=np.int64)
        for i, ids in enumerate(tokens_ids):
            batch_data[i, :len(ids)] = ids

        # Convert inputs to PyTorch tensors
        tokens_tensor = torch.tensor(batch_data)
        tokens_lens = torch.LongTensor(tokens_lens)
        # Compare a position grid against each sequence length to build a
        # (batch_size, max_len) padding mask: 1.0 for real tokens, 0.0 for padding
        mask_tensor = (torch.arange(max_len)[None, :] < tokens_lens[:, None]).float()

if self.use_cuda:
# If you have a GPU, put everything on cuda
tokens_tensor = tokens_tensor.to('cuda')
mask_tensor = mask_tensor.to('cuda')

        with torch.no_grad():
            out_tensor = self.model(tokens_tensor)[0]
            # Zero out the hidden states at padded positions
            out_tensor = torch.mul(out_tensor, mask_tensor.unsqueeze(2))

        if self.use_cuda:
            out_tensor = out_tensor.cpu()

return out_tensor.numpy()
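
A minimal usage sketch of the new encoder (not part of the commit), assuming a pytorch_transformers-compatible checkpoint; encode returns per-token embeddings of shape (batch_size, max_seq_len, hidden_size):

from gnes.encoder.text.torch_transformers import TorchTransformersEncoder

# 'bert-base-uncased' is an assumed pretrained identifier; any local directory
# with weights compatible with the named model/tokenizer classes should also work
encoder = TorchTransformersEncoder(model_dir='bert-base-uncased',
                                   model_name='BertModel',
                                   tokenizer_name='BertTokenizer')
encoder.post_init()  # loads the tokenizer and model via pytorch_transformers

vec = encoder.encode(['hello world', 'semantic search with gnes'])
print(vec.shape)  # e.g. (2, max_seq_len, 768) for a BERT-base checkpoint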
3 changes: 2 additions & 1 deletion setup.py
@@ -53,7 +53,8 @@
'leveldb': ['plyvel>=1.0.5'],
'test': ['pylint', 'memory_profiler>=0.55.0', 'psutil>=5.6.1', 'gputil>=1.4.0'],
'http': ['flask', 'flask-compress', 'flask-cors', 'flask-json', 'aiohttp==3.5.4'],
-    'onnx': ['onnxruntime']
+    'onnx': ['onnxruntime'],
+    'pytorch-transformers': ['pytorch-transformers']
}
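
With this extra declared in setup.py, the optional dependency can presumably be pulled in via pip's extras syntax, e.g. pip install gnes[pytorch-transformers].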


6 changes: 4 additions & 2 deletions tests/__init__.py
@@ -50,7 +50,8 @@ def line2pb_doc(line: str, doc_id: int = 0, deliminator: str = r'[.。!?!?]+
'INCEPTION_MODEL': '/',
'MOBILENET_MODEL': '/',
'FASTERRCNN_MODEL': '/',
-    'GNES_PROFILING': ''
+    'GNES_PROFILING': '',
+    'TORCH_TRANSFORMERS_MODEL': 'bert-base-uncased',
},
'idc-165': {
'BERT_CI_PORT': 7125,
@@ -67,7 +68,8 @@ def line2pb_doc(line: str, doc_id: int = 0, deliminator: str = r'[.。!?!?]+
'INCEPTION_MODEL': '/ext_data/image_encoder',
'MOBILENET_MODEL': '/ext_data/image_encoder',
'FASTERRCNN_MODEL': '/ext_data/image_preprocessor',
-    'GNES_PROFILING': ''
+    'GNES_PROFILING': '',
+    'TORCH_TRANSFORMERS_MODEL': '/ext_data/torch_transformer'
}

}
37 changes: 37 additions & 0 deletions tests/test_pytorch_transformers_encoder.py
@@ -0,0 +1,37 @@
import os
import unittest

from gnes.encoder.text.torch_transformers import TorchTransformersEncoder


class TestTorchTransformersEncoder(unittest.TestCase):

def setUp(self):
dirname = os.path.dirname(__file__)
self.dump_path = os.path.join(dirname, 'model.bin')
self.text_yaml = os.path.join(dirname, 'yaml', 'torch-transformers-encoder.yml')
self.tt_encoder = TorchTransformersEncoder.load_yaml(self.text_yaml)

self.test_str = []
with open(os.path.join(dirname, 'sonnets_small.txt')) as f:
for line in f:
line = line.strip()
if line:
self.test_str.append(line)

def test_encoding(self):
vec = self.tt_encoder.encode(self.test_str)
self.assertEqual(vec.shape[0], len(self.test_str))
self.assertEqual(vec.shape[2], 768)

def test_dump_load(self):
self.tt_encoder.dump(self.dump_path)

tt_encoder2 = TorchTransformersEncoder.load(self.dump_path)

vec = tt_encoder2.encode(self.test_str)
self.assertEqual(vec.shape[0], len(self.test_str))
self.assertEqual(vec.shape[2], 768)

def tearDown(self):
if os.path.exists(self.dump_path):
os.remove(self.dump_path)
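
The suite resolves its model location from the TORCH_TRANSFORMERS_MODEL environment variable (set per machine in tests/__init__.py above) through the YAML below, so it should run under the standard unittest runner, e.g. python -m unittest tests.test_pytorch_transformers_encoder, once that variable points at a valid checkpoint.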
7 changes: 7 additions & 0 deletions tests/yaml/torch-transformers-encoder.yml
@@ -0,0 +1,7 @@
!TorchTransformersEncoder
parameter:
model_dir: $TORCH_TRANSFORMERS_MODEL
model_name: BertModel
tokenizer_name: BertTokenizer
gnes_config:
is_trained: true
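
A sketch of how this YAML is presumably consumed (mirroring the test above): the $TORCH_TRANSFORMERS_MODEL placeholder in model_dir is substituted from the environment, and load_yaml builds the configured encoder:

import os
from gnes.encoder.text.torch_transformers import TorchTransformersEncoder

# Assumption: 'bert-base-uncased' mirrors the CI value in tests/__init__.py;
# the placeholder in model_dir is filled from the environment at load time
os.environ['TORCH_TRANSFORMERS_MODEL'] = 'bert-base-uncased'
encoder = TorchTransformersEncoder.load_yaml('tests/yaml/torch-transformers-encoder.yml')
vec = encoder.encode(['The quick brown fox jumps over the lazy dog.'])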
