
feat(encoder): separate pooling as an indep. encoder

hanxiao committed Aug 26, 2019
1 parent 3956100 commit b4444cc07cb7945155a2b01ea796d1d18fb43629
Showing with 29 additions and 17 deletions.
+29 −17 gnes/encoder/numeric/pooling.py
@@ -1,3 +1,4 @@
+import os
 from typing import Tuple
 
 import numpy as np
@@ -27,26 +28,30 @@ def post_init(self):
             import torch
             self.torch = torch
         elif self.backend == 'tensorflow':
+            os.environ['CUDA_VISIBLE_DEVICES'] = '0' if self.on_gpu else '-1'
             import tensorflow as tf
-            try:
-                tf.enable_eager_execution()
-            except ValueError:
-                pass
+            self._tf_graph = tf.Graph()
+            config = tf.ConfigProto(device_count={'GPU': 1 if self.on_gpu else 0})
+            config.gpu_options.allow_growth = True
+            config.log_device_placement = False
+            self._sess = tf.Session(graph=self._tf_graph, config=config)
             self.tf = tf
 
     def mul_mask(self, x, m):
         if self.backend in {'pytorch', 'torch'}:
             return self.torch.mul(x, m.unsqueeze(2))
         elif self.backend == 'tensorflow':
-            return x * self.tf.expand_dims(m, axis=-1)
+            with self._tf_graph.as_default():
+                return x * self.tf.expand_dims(m, axis=-1)
         elif self.backend == 'numpy':
             return x * np.expand_dims(m, axis=-1)
 
     def minus_mask(self, x, m, offset: int = 1e30):
         if self.backend in {'pytorch', 'torch'}:
             return x - (1.0 - m).unsqueeze(2) * offset
         elif self.backend == 'tensorflow':
-            return x - self.tf.expand_dims(1.0 - m, axis=-1) * offset
+            with self._tf_graph.as_default():
+                return x - self.tf.expand_dims(1.0 - m, axis=-1) * offset
         elif self.backend == 'numpy':
             return x - np.expand_dims(1.0 - m, axis=-1) * offset
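
Note for readers skimming the diff: mul_mask and minus_mask implement the two standard masked-pooling tricks, and the new with self._tf_graph.as_default(): wrappers pin every TensorFlow op to the private graph created in post_init so it can later be evaluated by self._sess. A minimal numpy sketch of the two helpers (shapes and values are invented for illustration, not taken from the repo):

    import numpy as np

    x = np.array([[[1., 2.], [3., 4.], [5., 6.]],
                  [[7., 8.], [9., 10.], [0., 0.]]])   # (batch=2, seq_len=3, dim=2)
    m = np.array([[1., 1., 1.],
                  [1., 1., 0.]])                      # 1.0 = real token, 0.0 = padding

    # mul_mask: zero out padded positions so they add nothing to a sum
    masked = x * np.expand_dims(m, axis=-1)

    # minus_mask: shift padded positions by -1e30 so they can never win a max
    shifted = x - np.expand_dims(1.0 - m, axis=-1) * 1e30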

@@ -55,16 +60,18 @@ def masked_reduce_mean(self, x, m, jitter: float = 1e-10):
             return self.torch.div(self.torch.sum(self.mul_mask(x, m), dim=1),
                                   self.torch.sum(m.unsqueeze(2), dim=1) + jitter)
         elif self.backend == 'tensorflow':
-            return self.tf.reduce_sum(self.mul_mask(x, m), axis=1) / (
-                    self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
+            with self._tf_graph.as_default():
+                return self.tf.reduce_sum(self.mul_mask(x, m), axis=1) / (
+                        self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
         elif self.backend == 'numpy':
             return np.sum(self.mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + jitter)
 
     def masked_reduce_max(self, x, m):
         if self.backend in {'pytorch', 'torch'}:
             return self.torch.max(self.minus_mask(x, m), 1)[0]
         elif self.backend == 'tensorflow':
-            return self.tf.reduce_max(self.minus_mask(x, m), axis=1)
+            with self._tf_graph.as_default():
+                return self.tf.reduce_max(self.minus_mask(x, m), axis=1)
         elif self.backend == 'numpy':
             return np.max(self.minus_mask(x, m), axis=1)
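
The jitter term keeps masked_reduce_mean from dividing by zero when a mask row is entirely padding, and masked_reduce_max is only correct because minus_mask has already pushed padded slots far below any real activation. A hedged numpy check of both reductions (toy data, not from the repo):

    import numpy as np

    x = np.random.rand(2, 3, 4)            # (batch, seq_len, dim)
    m = np.array([[1., 1., 0.],            # last token of row 0 is padding
                  [1., 1., 1.]])

    mean_pooled = np.sum(x * np.expand_dims(m, -1), axis=1) / (
            np.sum(m, axis=1, keepdims=True) + 1e-10)
    max_pooled = np.max(x - np.expand_dims(1.0 - m, -1) * 1e30, axis=1)

    assert mean_pooled.shape == (2, 4) and max_pooled.shape == (2, 4)
    # row 0 statistics come from its first two tokens only; padding is ignored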

@@ -73,16 +80,21 @@ def encode(self, data: Tuple, *args, **kwargs):
         seq_tensor, mask_tensor = data
 
         if self.pooling_strategy == 'REDUCE_MEAN':
-            return self.masked_reduce_mean(seq_tensor, mask_tensor)
+            r = self.masked_reduce_mean(seq_tensor, mask_tensor)
         elif self.pooling_strategy == 'REDUCE_MAX':
-            return self.masked_reduce_max(seq_tensor, mask_tensor)
+            r = self.masked_reduce_max(seq_tensor, mask_tensor)
         elif self.pooling_strategy == 'REDUCE_MEAN_MAX':
             if self.backend in {'pytorch', 'torch'}:
-                return self.torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
-                                       self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
+                r = self.torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
+                                    self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
             elif self.backend == 'tensorflow':
-                return self.tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
-                                       self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
+                with self._tf_graph.as_default():
+                    r = self.tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
+                                        self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
             elif self.backend == 'numpy':
-                return np.concatenate([self.masked_reduce_mean(seq_tensor, mask_tensor),
-                                       self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
+                r = np.concatenate([self.masked_reduce_mean(seq_tensor, mask_tensor),
+                                    self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
+
+        if self.backend == 'tensorflow':
+            r = self._sess.run(r)
+        return r
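
End to end, encode now builds the pooled tensor r in whichever backend is configured and, for TensorFlow only, materializes it with self._sess.run(r) before returning. A hypothetical usage sketch with the numpy backend; the class name PoolingEncoder and its constructor arguments are assumed from the file and attribute names in this diff, not verified against the repo:

    import numpy as np
    from gnes.encoder.numeric.pooling import PoolingEncoder  # assumed class name

    enc = PoolingEncoder(pooling_strategy='REDUCE_MEAN_MAX', backend='numpy')
    enc.post_init()

    seq = np.random.rand(8, 20, 768)       # 8 sequences, 20 tokens, 768-dim states
    mask = np.ones((8, 20))                # all-real tokens in this toy batch
    vec = enc.encode((seq, mask))          # REDUCE_MEAN_MAX concatenates mean and max
    assert vec.shape == (8, 1536)          # 768 (mean) + 768 (max)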
