Skip to content
Permalink
Browse files

feat(helper): add as_numpy_array decorator

  • Loading branch information...
hanxiao committed Aug 26, 2019
1 parent 543d561 commit a584c7e5c5e59590d62da6b09c00044bdbf63476
Showing with 81 additions and 29 deletions.
  1. +2 −1 gnes/encoder/__init__.py
  2. +31 −28 gnes/encoder/numeric/pooling.py
  3. +13 −0 gnes/helper.py
  4. +35 −0 tests/test_pooling_encoder.py
@@ -42,7 +42,8 @@
'CVAEEncoder': 'image.cvae',
'IncepMixtureEncoder': 'video.incep_mixture',
'VladEncoder': 'numeric.vlad',
'MfccEncoder': 'audio.mfcc'
'MfccEncoder': 'audio.mfcc',
'PoolingEncoder': 'numeric.pooling'
}

register_all_class(_cls2file_map, 'encoder')
@@ -3,6 +3,7 @@
import numpy as np

from ..base import BaseNumericEncoder
from ...helper import as_numpy_array


class PoolingEncoder(BaseNumericEncoder):
@@ -12,7 +13,7 @@ def __init__(self, pooling_strategy: str = 'REDUCE_MEAN',
super().__init__(*args, **kwargs)

valid_poolings = {'REDUCE_MEAN', 'REDUCE_MAX', 'REDUCE_MEAN_MAX'}
valid_backends = {'tensorflow', 'numpy', 'pytorch'}
valid_backends = {'tensorflow', 'numpy', 'pytorch', 'torch'}

if pooling_strategy not in valid_poolings:
raise ValueError('"pooling_strategy" must be one of %s' % valid_poolings)
@@ -21,46 +22,50 @@ def __init__(self, pooling_strategy: str = 'REDUCE_MEAN',
self.pooling_strategy = pooling_strategy
self.backend = backend

def mul_mask(self, x, m):
if self.backend == 'pytorch':
def post_init(self):
if self.backend in {'pytorch', 'torch'}:
import torch
return torch.mul(x, m.unsqueeze(2))
self.torch = torch
elif self.backend == 'tensorflow':
import tensorflow as tf
return x * tf.expand_dims(m, axis=-1)
tf.enable_eager_execution()
self.tf = tf

def mul_mask(self, x, m):
if self.backend in {'pytorch', 'torch'}:
return self.torch.mul(x, m.unsqueeze(2))
elif self.backend == 'tensorflow':
return x * self.tf.expand_dims(m, axis=-1)
elif self.backend == 'numpy':
return 0
return x * np.expand_dims(m, axis=-1)

def minus_mask(self, x, m, offset: int = 1e30):
if self.backend == 'pytorch':
if self.backend in {'pytorch', 'torch'}:
return x - (1.0 - m).unsqueeze(2) * offset
elif self.backend == 'tensorflow':
import tensorflow as tf
return x - tf.expand_dims(1.0 - m, axis=-1) * offset
return x - self.tf.expand_dims(1.0 - m, axis=-1) * offset
elif self.backend == 'numpy':
return 0
return x - np.expand_dims(1.0 - m, axis=-1) * offset

def masked_reduce_mean(self, x, m, jitter: float = 1e-10):
if self.backend == 'pytorch':
import torch
return torch.div(torch.sum(self.mul_mask(x, m), dim=1),
torch.sum(m.unsqueeze(2), dim=1) + jitter)
if self.backend in {'pytorch', 'torch'}:
return self.torch.div(self.torch.sum(self.mul_mask(x, m), dim=1),
self.torch.sum(m.unsqueeze(2), dim=1) + jitter)
elif self.backend == 'tensorflow':
import tensorflow as tf
return tf.reduce_sum(self.mul_mask(x, m), axis=1) / (tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
return self.tf.reduce_sum(self.mul_mask(x, m), axis=1) / (
self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
elif self.backend == 'numpy':
return np.sum(self.mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + jitter)

def masked_reduce_max(self, x, m):
if self.backend == 'pytorch':
import torch
return torch.max(self.minus_mask(x, m), 1)[0]
if self.backend in {'pytorch', 'torch'}:
return self.torch.max(self.minus_mask(x, m), 1)[0]
elif self.backend == 'tensorflow':
import tensorflow as tf
return tf.reduce_max(self.minus_mask(x, m), axis=1)
return self.tf.reduce_max(self.minus_mask(x, m), axis=1)
elif self.backend == 'numpy':
return np.max(self.minus_mask(x, m), axis=1)

@as_numpy_array
def encode(self, data: Tuple, *args, **kwargs):
seq_tensor, mask_tensor = data

@@ -69,14 +74,12 @@ def encode(self, data: Tuple, *args, **kwargs):
elif self.pooling_strategy == 'REDUCE_MAX':
return self.masked_reduce_max(seq_tensor, mask_tensor)
elif self.pooling_strategy == 'REDUCE_MEAN_MAX':
if self.backend == 'torch':
import torch
return torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
if self.backend in {'pytorch', 'torch'}:
return self.torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
elif self.backend == 'tensorflow':
import tensorflow as tf
return tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
return self.tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
elif self.backend == 'numpy':
return np.concatenate([self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
@@ -481,6 +481,19 @@ def countdown(t: int, logger=None, reason: str = 'I am blocking this thread'):
sys.stdout.flush()


def as_numpy_array(func, dtype=np.float32):
@wraps(func)
def arg_wrapper(self, *args, **kwargs):
r = func(self, *args, **kwargs)
r_type = type(r).__name__
if r_type in {'ndarray', 'EagerTensor', 'Tensor', 'list'}:
return np.array(r, dtype)
else:
raise TypeError('unrecognized type %s: %s' % (r_type, type(r)))

return arg_wrapper


def train_required(func):
@wraps(func)
def arg_wrapper(self, *args, **kwargs):
@@ -0,0 +1,35 @@
import unittest

import numpy as np
import torch
from numpy.testing import assert_allclose

from gnes.encoder.numeric.pooling import PoolingEncoder


class TestEncoder(unittest.TestCase):
def setUp(self):
self.seq_data = np.random.random([5, 10])
self.seq_embed_data = np.random.random([5, 10, 32])
self.mask_data = np.array(self.seq_data > 0.5, np.float32)
self.data = [
(torch.tensor(self.seq_embed_data, dtype=torch.float32), torch.tensor(self.mask_data, dtype=torch.float32)),
(self.seq_embed_data, self.mask_data),
(self.seq_embed_data, self.mask_data)]

def _test_strategy(self, strategy):
pe_to = PoolingEncoder(strategy, 'torch')
pe_tf = PoolingEncoder(strategy, 'tensorflow')
pe_np = PoolingEncoder(strategy, 'numpy')
return [pe.encode(self.data[idx]) for idx, pe in enumerate([pe_to, pe_tf, pe_np])]

def test_all(self):
for s in {'REDUCE_MEAN', 'REDUCE_MAX', 'REDUCE_MEAN_MAX'}:
with self.subTest(strategy=s):
r = self._test_strategy(s)
for rr in r:
print(type(rr))
print(rr)
print('---')
assert_allclose(r[0], r[1], rtol=1e-5)
assert_allclose(r[1], r[2], rtol=1e-5)

0 comments on commit a584c7e

Please sign in to comment.
You can’t perform that action at this time.