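# NOTE: this file was recovered from a web-page scrape; the page header lines
# ("Skip to content", "Cannot retrieve contributors at this time") are kept
# only as this comment so the module remains valid Python.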
from keras.engine import Layer, InputSpec
import tensorflow as tf
class KMaxPooling(Layer):
    """K-max pooling layer (TensorFlow backend).

    Keeps the ``k`` highest activations along ``axis`` of a rank-3 input and
    drops the rest. ``tf.nn.top_k`` is called with ``sorted=True``, so the
    kept values are returned in descending order, not in their original
    positional order.

    Args:
        k: Number of activations to keep along ``axis``; must be >= 1.
        axis: Axis to pool over; must be 1 or 2. Axis 0 is rejected
            (presumably the samples/batch axis).
        **kwargs: Forwarded to ``Layer.__init__``.

    Raises:
        ValueError: If ``k`` < 1 or ``axis`` is not 1 or 2.
    """

    def __init__(self, k=1, axis=1, **kwargs):
        super(KMaxPooling, self).__init__(**kwargs)
        # Explicit checks rather than ``assert``: asserts are stripped under
        # ``python -O`` and must not be used for argument validation.
        if k < 1:
            raise ValueError('k must be >= 1, got %r' % (k,))
        if axis not in (1, 2):
            raise ValueError(
                'expected dimensions (samples, filters, convolved_values), '
                'cannot fold along samples dimension or axis not in list [1,2]')
        self.input_spec = InputSpec(ndim=3)
        self.k = k
        self.axis = axis
        # tf.nn.top_k only operates on the last axis, so build a permutation
        # that swaps ``axis`` with axis 2. A single swap is its own inverse,
        # so the same permutation also restores the original layout.
        # (For axis == 2 this is the identity [0, 1, 2].)
        self.transpose_perm = [0, 1, 2]
        self.transpose_perm[self.axis] = 2
        self.transpose_perm[2] = self.axis

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with the pooled axis replaced by ``k``."""
        input_shape_list = list(input_shape)
        input_shape_list[self.axis] = self.k
        return tuple(input_shape_list)

    def call(self, x):
        """Return the top-``k`` values of ``x`` along ``self.axis``."""
        # Move the pooled axis to the last position for top_k.
        transposed_for_topk = tf.transpose(x, perm=self.transpose_perm)
        # top_k returns (values, indices); only the values are kept.
        top_k_values = tf.nn.top_k(transposed_for_topk, k=self.k, sorted=True)[0]
        # Swap back: same layout as ``x``, but the pooled axis now has size k.
        return tf.transpose(top_k_values, perm=self.transpose_perm)

    def get_config(self):
        """Return the layer config including ``k`` and ``axis``.

        Without this, saving and reloading a model would re-create the layer
        with the default ``k``/``axis``.
        """
        config = super(KMaxPooling, self).get_config()
        config.update({'k': self.k, 'axis': self.axis})
        return config
class Folding(Layer):
    """Fold a rank-3 input by summing adjacent pairs along axis 2.

    Output element ``[..., j]`` equals ``x[..., 2*j] + x[..., 2*j + 1]``, so
    the size of axis 2 is halved.

    Raises:
        ValueError: From ``call``, if the static size of axis 2 is known and
            odd (adjacent pairs cannot be formed).
    """

    def __init__(self, **kwargs):
        super(Folding, self).__init__(**kwargs)
        self.input_spec = InputSpec(ndim=3)

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with axis 2 halved (an unknown size stays unknown)."""
        folded = input_shape[2]
        # Propagate an unknown (None) size instead of crashing on ``None / 2``.
        if folded is not None:
            folded = int(folded) // 2
        return (input_shape[0], input_shape[1], folded)

    def call(self, x):
        """Sum each adjacent pair (0+1, 2+3, ...) along axis 2 of ``x``."""
        size = x.get_shape().as_list()[2]
        if size is not None and size % 2:
            raise ValueError(
                'Folding requires an even size on axis 2, got %d' % size)
        # Even-indexed slice + odd-indexed slice pairs up neighbours in a
        # single op, replacing size/2 splits, size/2 reduce_sums and a stack.
        return x[:, :, 0::2] + x[:, :, 1::2]