In [1]:
import tensorflow as tf

# tf.nn.softmax

Computes softmax activations.

Used for multi-class predictions. The outputs of softmax sum to 1 along the axis it is applied to.

This function performs the equivalent of


`softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)`

# tf.multiply

Returns x * y element-wise.


ref @ https://docs.w3cub.com/tensorflow~python/tf/multiply

# tf.reduce_sum

Computes the sum of elements across dimensions of a tensor.

Reduces `input_tensor` along the dimensions given in `axis`. Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in `axis`. If `keepdims` is true, the reduced dimensions are retained with length 1.

Reference implementation of the squash function (method form):

```python
def squash(self, s):
    with tf.name_scope("SquashFunction") as scope:
        s_norm = tf.norm(s, axis=-1, keepdims=True)
        return tf.square(s_norm)/(1 + tf.square(s_norm)) * s/(s_norm + epsilon)
```

$s_j = \sum_i c_{ij}\,\hat{u}_{j|i}$

ref @ https://docs.w3cub.com/tensorflow~python/tf/reduce_sum

In [2]:
# Small constant added to the norm in `squash` so the final division
# never has a zero denominator.
epsilon = 1e-7


def squash(s):
    """Apply the squash non-linearity along the last axis of `s`.

    The Euclidean norm is taken over the last axis (kept as a size-1 axis so
    it broadcasts against `s`). The result is

        norm**2 / (1 + norm**2) * s / (norm + epsilon)

    Args:
        s: tensor whose last axis holds the vectors to squash.

    Returns:
        A tensor with the same shape as `s`.
    """
    norm = tf.norm(s, axis=-1, keepdims=True)
    norm_sq = tf.square(norm)
    scale = norm_sq / (1 + norm_sq)
    # Same evaluation order as (scale * s) / (norm + epsilon).
    return scale * s / (norm + epsilon)

In [3]:
# ------------ PREDICTION VECTORS (u_hat)

# Dummy all-ones weights and inputs; only the shapes matter for this walkthrough.
w_shape = (1, 1152, 10, 16, 8)
u_shape = (1, 1152, 1, 8, 1)
w = tf.ones(w_shape)
u = tf.ones(u_shape)

# matmul works on the last two axes, (16, 8) @ (8, 1), and broadcasts the rest:
# (1, 1152, 10, 16, 1). Dropping the trailing axis gives (1, 1152, 10, 16).
u_hat = tf.squeeze(tf.matmul(w, u), axis=-1)

# ------------ ROUTING BY AGREEMENT

# Routing logits, b.shape: (1, 1152, 10, 1)
b = tf.zeros((1, 1152, 10, 1))

for i in range(1):

    # Coupling coefficients: softmax across the 10 DigitCaps (axis=-2).
    # c.shape: (1, 1152, 10, 1)
    c = tf.nn.softmax(b, axis=-2)

    # Scale each 16-D prediction vector in u_hat by its coupling coefficient
    # (c broadcasts over the last axis).
    # tmp.shape: (1, 1152, 10, 16)
    tmp = tf.multiply(c, u_hat)

    # Sum over axis 1. In the paper's notation this is the sum over i with j fixed.
    # s.shape: (1, 1, 10, 16)
    s = tf.reduce_sum(tmp, axis=1, keepdims=True)

    # v.shape: (1, 1, 10, 16)
    v = squash(s)

    # Agreement = dot product of u_hat and v along the 16-D axis, written as a
    # batched matmul. Intermediate shapes (not stored in variables):
    #   u_hat[..., newaxis]       -> (1, 1152, 10, 16, 1)
    #   after transpose_a=True    -> (1, 1152, 10, 1, 16)
    #   v[..., newaxis]           -> (1, 1, 10, 16, 1), broadcast over axis 1
    #   matmul result             -> (1, 1152, 10, 1, 1)
    # Squeezing the last axis leaves agreement.shape: (1, 1152, 10, 1)
    agreement = tf.squeeze(
        tf.matmul(
            u_hat[..., tf.newaxis],
            v[..., tf.newaxis],
            transpose_a=True,
        ),
        axis=-1,
    )

    # Update the logits for the next routing iteration.
    b = b + agreement

Metal device set to: Apple M1


2022-06-01 09:21:26.693713: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-06-01 09:21:26.693802: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [4]:
# Numeric sanity check of one weighted-sum step with all-ones inputs.
#
# The shapes are defined BEFORE they are used so this cell does not depend on
# `w_shape` / `u_shape` happening to exist from a previously executed cell.
w_shape = (1, 1152, 10, 16, 8)
u_shape = (1, 1152, 1, 8, 1)

w = tf.fill(w_shape, 1.0)
u = tf.fill(u_shape, 1.0)

# u_hat.shape: (1, 1152, 10, 16); every entry is 8.0, since each one is a
# sum of eight 1.0 * 1.0 products.
u_hat = tf.matmul(w, u)
u_hat = tf.squeeze(u_hat, axis=-1)

b_shape = (1, 1152, 10, 1)
b = tf.fill(b_shape, 1.0)
# b is constant, so the softmax over the 10 entries of axis=-2 is 0.1 everywhere.
c = tf.nn.softmax(b, axis=-2)
# c.shape: (1, 1152, 10, 1)
print(c.shape)
# u_hat.shape: (1, 1152, 10, 16)
print(u_hat.shape)
# tmp.shape: (1, 1152, 10, 16)
tmp = tf.multiply(c, u_hat)

# Expected: 0.1 * 8 = 0.8
print(tmp[0,0,0,0])

# s.shape: (1, 1, 10, 16)
s = tf.reduce_sum(tmp, axis=1, keepdims=True)

# Expected: 1152 * 0.8 = 921.6 (up to float32 rounding)
print(s[0,0,0,0])

(1, 1152, 10, 1)
(1, 1152, 10, 16)
tf.Tensor(0.8, shape=(), dtype=float32)
tf.Tensor(921.59973, shape=(), dtype=float32)
