
feat(helper): batching decorator supports tuple

hanxiao committed Aug 26, 2019
1 parent 928574c commit ce0e65aebcc4f779972fb59593542826467d27e5
Showing with 26 additions and 9 deletions.
  1. +5 −4 gnes/encoder/text/transformer.py
  2. +21 −5 gnes/helper.py
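
The gist of the change: a method decorated with `@batching` from gnes/helper.py may now return a tuple of arrays per mini-batch, and the decorator concatenates each tuple position across mini-batches instead of requiring a single np.ndarray. A minimal usage sketch, assuming the decorator is used bare as in the diff below; the `batch_size` attribute and the two returned arrays are illustrative assumptions, not taken from this commit:

    import numpy as np
    from gnes.helper import batching

    class ToyEncoder:
        batch_size = 2  # assumed attribute; how batch size is configured is not shown in this diff

        @batching
        def encode(self, data, *args, **kwargs):
            # per mini-batch, return two aligned arrays instead of one
            vecs = np.zeros((len(data), 8))   # e.g. sentence vectors
            masks = np.ones((len(data), 4))   # e.g. token masks
            return vecs, masks

Each element of the returned tuple then comes back as the concatenation of that position across all mini-batches, as implemented in the gnes/helper.py hunk below.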
@@ -17,9 +17,9 @@
from typing import List, Tuple

import torch
from pytorch_transformers import *

-from gnes.encoder.base import BaseTextEncoder
+from ..base import BaseTextEncoder
+from ...helper import batching


class PyTorchTransformers(BaseTextEncoder):
@@ -58,16 +58,17 @@ def load_model_tokenizer(x):
            self.logger.warning('cannot deserialize model/tokenizer from %s, will download from web' % self.work_dir)
            self.model, self.tokenizer = load_model_tokenizer(pretrained_weights)

    @batching
    def encode(self, text: List[str], *args, **kwargs) -> Tuple:
        # encoding and padding
        ids = [self.tokenizer.encode(t) for t in text]
        max_len = max(len(t) for t in ids)
        m_ids = [[1] * len(t) + [0] * (max_len - len(t)) for t in ids]
        ids = [t + [0] * (max_len - len(t)) for t in ids]
        seq_ids = torch.tensor(ids)
-        mask_ids = torch.tensor(m_ids)
+        mask_ids = torch.tensor(m_ids, dtype=torch.float32)

-        if self.use_cuda:
+        if self.on_gpu:
            seq_ids = seq_ids.cuda()

        with torch.no_grad():
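
The hunk is truncated at the `torch.no_grad()` block. A hedged sketch of how the body could continue so that `encode` actually returns the tuple promised by its signature; this continuation is an assumption, not part of the commit:

        # hypothetical continuation of encode(), not from this diff:
        with torch.no_grad():
            # pytorch_transformers models return a tuple; element 0 holds the
            # token-level hidden states with shape (batch, seq_len, hidden_dim)
            output = self.model(seq_ids)[0]
        # return the features together with the float mask so a downstream
        # pooler can ignore padded positions; @batching now concatenates each
        # tuple element across mini-batches
        return output.cpu().numpy(), mask_ids.numpy()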
@@ -403,11 +403,27 @@ def arg_wrapper(self, data, label=None, *args, **kwargs):
            if r is not None:
                final_result.append(r)

-        if len(final_result) and concat_axis is not None and isinstance(final_result[0], np.ndarray):
-            final_result = np.concatenate(final_result, concat_axis)
-
-        if chunk_dim != -1:
-            final_result = final_result.reshape((-1, chunk_dim, final_result.shape[1]))
+        if len(final_result) == 1:
+            # the only result of one batch
+            return final_result[0]
+
+        if len(final_result) and concat_axis is not None:
+            if isinstance(final_result[0], np.ndarray):
+                final_result = np.concatenate(final_result, concat_axis)
+                if chunk_dim != -1:
+                    final_result = final_result.reshape((-1, chunk_dim, final_result.shape[1]))
+            elif isinstance(final_result[0], tuple):
+                reduced_result = []
+                num_cols = len(final_result[0])
+                for col in range(num_cols):
+                    reduced_result.append(np.concatenate([row[col] for row in final_result], concat_axis))
+                if chunk_dim != -1:
+                    for col in range(num_cols):
+                        reduced_result[col] = reduced_result[col].reshape(
+                            (-1, chunk_dim, reduced_result[col].shape[1]))
+                final_result = tuple(reduced_result)
+            else:
+                raise TypeError('dont know how to reduce %s' % type(final_result[0]))

        if len(final_result):
            return final_result
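
In isolation, the new tuple branch is a column-wise reduction: the per-batch results are rows, each tuple position is a column, and every column is concatenated independently. A minimal self-contained sketch of that behaviour (function name and shapes are illustrative, not from gnes/helper.py):

    import numpy as np

    def reduce_batches(final_result, concat_axis=0):
        # final_result holds one entry per mini-batch; each entry is a tuple of
        # aligned arrays, e.g. (hidden_states, mask) for that batch
        num_cols = len(final_result[0])
        return tuple(
            np.concatenate([row[col] for row in final_result], concat_axis)
            for col in range(num_cols))

    # two mini-batches of sizes 3 and 2 yield 5 rows in every column
    batches = [(np.zeros((3, 8)), np.ones((3, 4))), (np.zeros((2, 8)), np.ones((2, 4)))]
    out, mask = reduce_batches(batches)
    assert out.shape == (5, 8) and mask.shape == (5, 4)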
