
Commit

Merge pull request #16 from mrazekv/master
Batch dot reimplementation
brjathu committed Apr 8, 2020
2 parents 00dc21e + 4af8638 commit e273cfd
Showing 2 changed files with 98 additions and 3 deletions.
92 changes: 92 additions & 0 deletions batchdot.py
@@ -0,0 +1,92 @@
# NOTE: the star import supplies `ndim` and `expand_dims` used below.
from keras.backend import *
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


#own_batch_dot = batch_dot  # force standard implementation

# batch_dot operation ported from TF 1.13:
# https://github.com/tensorflow/tensorflow/blob/v1.13.1/tensorflow/python/keras/backend.py

def own_batch_dot(x, y, axes=None):
"""Batchwise dot product.
`batch_dot` is used to compute dot product of `x` and `y` when
`x` and `y` are data in batch, i.e. in a shape of
`(batch_size, :)`.
`batch_dot` results in a tensor or variable with less dimensions
than the input. If the number of dimensions is reduced to 1,
we use `expand_dims` to make sure that ndim is at least 2.
Arguments:
x: Keras tensor or variable with `ndim >= 2`.
y: Keras tensor or variable with `ndim >= 2`.
axes: list of (or single) int with target dimensions.
The lengths of `axes[0]` and `axes[1]` should be the same.
Returns:
A tensor with shape equal to the concatenation of `x`'s shape
(less the dimension that was summed over) and `y`'s shape
(less the batch dimension and the dimension that was summed over).
If the final rank is 1, we reshape it to `(batch_size, 1)`.
Examples:
Assume `x = [[1, 2], [3, 4]]` and `y = [[5, 6], [7, 8]]`
`batch_dot(x, y, axes=1) = [[17, 53]]` which is the main diagonal
of `x.dot(y.T)`, although we never have to calculate the off-diagonal
elements.
Shape inference:
Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
If `axes` is (1, 2), to find the output shape of resultant tensor,
loop through each dimension in `x`'s shape and `y`'s shape:
* `x.shape[0]` : 100 : append to output shape
* `x.shape[1]` : 20 : do not append to output shape,
dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
* `y.shape[0]` : 100 : do not append to output shape,
always ignore first dimension of `y`
* `y.shape[1]` : 30 : append to output shape
* `y.shape[2]` : 20 : do not append to output shape,
dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
`output_shape` = `(100, 30)`
```python
>>> x_batch = K.ones(shape=(32, 20, 1))
>>> y_batch = K.ones(shape=(32, 30, 20))
>>> xy_batch_dot = K.batch_dot(x_batch, y_batch, axes=[1, 2])
>>> K.int_shape(xy_batch_dot)
(32, 1, 30)
```
"""
    if isinstance(axes, int):
        axes = (axes, axes)
    x_ndim = ndim(x)
    y_ndim = ndim(y)
    if axes is None:
        # behaves like tf.batch_matmul as default
        axes = [x_ndim - 1, y_ndim - 2]
    # Pad the lower-rank operand with trailing singleton dimensions so both
    # operands have the same rank; the padding is squeezed out again below.
    if x_ndim > y_ndim:
        diff = x_ndim - y_ndim
        y = array_ops.reshape(
            y, array_ops.concat([array_ops.shape(y), [1] * diff], axis=0))
    elif y_ndim > x_ndim:
        diff = y_ndim - x_ndim
        x = array_ops.reshape(
            x, array_ops.concat([array_ops.shape(x), [1] * diff], axis=0))
    else:
        diff = 0
    if ndim(x) == 2 and ndim(y) == 2:
        if axes[0] == axes[1]:
            out = math_ops.reduce_sum(math_ops.multiply(x, y), axes[0])
        else:
            out = math_ops.reduce_sum(
                math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
    else:
        # Use matmul with adjoints so that the requested contraction axes
        # line up: transpose `x` unless its axis is already the last one,
        # and transpose `y` when its axis is the last one.
        adj_x = None if axes[0] == ndim(x) - 1 else True
        adj_y = True if axes[1] == ndim(y) - 1 else None
        out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
    if diff:
        # Squeeze out the singleton dimensions added by the rank padding.
        if x_ndim > y_ndim:
            idx = x_ndim + y_ndim - 3
        else:
            idx = x_ndim - 1
        out = array_ops.squeeze(out, list(range(idx, idx + diff)))
    if ndim(out) == 1:
        out = expand_dims(out, 1)
    return out
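
For reference, a minimal usage sketch of the ported function (a sketch assuming a TF 1.x environment with standalone Keras, as this repo uses; shapes mirror the docstring examples above):

```python
# Minimal usage sketch (assumes TF 1.x with standalone Keras).
from keras import backend as K

from batchdot import own_batch_dot

# Docstring example: contract axis 1 of x (size 20) with axis 2 of y (size 20).
x_batch = K.ones(shape=(32, 20, 1))
y_batch = K.ones(shape=(32, 30, 20))
xy = own_batch_dot(x_batch, y_batch, axes=[1, 2])
print(K.int_shape(xy))  # (32, 1, 30)

# 2-D case: row-wise dot products, reshaped to (batch_size, 1).
x = K.constant([[1., 2.], [3., 4.]])
y = K.constant([[5., 6.], [7., 8.]])
print(K.eval(own_batch_dot(x, y, axes=1)))  # [[17.], [53.]]
```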
9 changes: 6 additions & 3 deletions capslayers.py
@@ -5,6 +5,8 @@
from keras.utils import conv_utils
from keras.layers import InputSpec
from keras.utils.conv_utils import conv_output_length
+ from batchdot import own_batch_dot


cf = K.image_data_format() == 'channels_first'
useGPU = True
@@ -96,6 +98,7 @@ def __init__(self, ch_j, n_j,
        self.r_num = r_num
        self.b_alphas = b_alphas
        self.padding = conv_utils.normalize_padding(padding)
+       #self.data_format = conv_utils.normalize_data_format(data_format)
        self.data_format = K.normalize_data_format(data_format)
        self.dilation_rate = (1, 1)
        self.kernel_initializer = initializers.get(kernel_initializer)
@@ -543,7 +546,7 @@ def call(self, inputs, training=None):
        # Regard the first two dimensions as `batch` dimension,
        # then matmul: [input_dim_capsule] x [dim_capsule, input_dim_capsule]^T -> [dim_capsule].
        # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
-       inputs_hat = K.map_fn(lambda x: K.batch_dot(x, W2, [2, 3]), elems=inputs_tiled)
+       inputs_hat = K.map_fn(lambda x: own_batch_dot(x, W2, [2, 3]), elems=inputs_tiled)

        # Begin: Routing algorithm ---------------------------------------------------------------------#
        # The prior for coupling coefficient, initialized as zeros.
@@ -560,15 +563,15 @@ def call(self, inputs, training=None):
            # The first two dimensions as `batch` dimension,
            # then matmul: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
            # outputs.shape = [None, num_capsule, dim_capsule]
-           outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]) + self.B)  # [None, 10, 16]
+           outputs = squash(own_batch_dot(c, inputs_hat, [2, 2]) + self.B)  # [None, 10, 16]

            if i < self.routings - 1:
                # outputs.shape = [None, num_capsule, dim_capsule]
                # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
                # The first two dimensions as `batch` dimension,
                # then matmul: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule].
                # b.shape = [batch_size, num_capsule, input_num_capsule]
-               b += K.batch_dot(outputs, inputs_hat, [2, 3])
+               b += own_batch_dot(outputs, inputs_hat, [2, 3])
        # End: Routing algorithm -----------------------------------------------------------------------#

        return outputs
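
For context, the three replaced call sites above live inside the standard dynamic-routing loop of the capsule layer. Below is a hypothetical condensed reconstruction of that loop: only the `own_batch_dot` calls and the shape comments are confirmed by this diff; the softmax step, the zero initialization of `b`, and the `route` wrapper itself are assumptions based on the usual CapsNet routing algorithm.

```python
# Hypothetical sketch of the routing loop; only the own_batch_dot calls and
# the shape comments are confirmed by the diff above.
import tensorflow as tf

from batchdot import own_batch_dot

def route(inputs_hat, B, routings, squash):
    # inputs_hat: [None, num_capsule, input_num_capsule, dim_capsule]
    # Prior coupling logits b: [None, num_capsule, input_num_capsule] (assumed init).
    b = tf.zeros_like(inputs_hat[:, :, :, 0])
    outputs = None
    for i in range(routings):
        c = tf.nn.softmax(b, axis=1)  # coupling coefficients (assumed step)
        # [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule]
        outputs = squash(own_batch_dot(c, inputs_hat, [2, 2]) + B)
        if i < routings - 1:
            # agreement: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule]
            b += own_batch_dot(outputs, inputs_hat, [2, 3])
    return outputs
```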
