In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [2]:
import tensorflow_probability as tfp

In [3]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

In [4]:
from scipy import random

In [5]:
@tf.function
def squash(x, axis=-1):
    """Capsule "squash" non-linearity.

    Multiplies `x` by sqrt(q) / (1 + q), where q is the sum of squares of `x`
    along `axis` (kept as a size-1 axis for broadcasting) plus the Keras
    epsilon, which keeps the square root away from an exact zero.

    Args:
        x: tensor to rescale.
        axis: axis along which the squared norm is taken.

    Returns:
        A tensor with the same shape as `x`.
    """
    sq_norm = tf.math.reduce_sum(tf.math.square(x), axis, keepdims=True)
    sq_norm = sq_norm + keras.backend.epsilon()
    factor = tf.math.sqrt(sq_norm) / (1 + sq_norm)
    return factor * x

@tf.function
def margin_loss(y_true, y_pred):
    """Margin loss summed over the last axis.

    Where `y_true` is 1 the penalty is relu(1 - margin - y_pred)^2; where it is
    0 the penalty is lamb * relu(y_pred - margin)^2, with margin = 0.1 and
    lamb = 0.5.

    Args:
        y_true: target indicator tensor (multiplied in directly).
        y_pred: predicted values, same shape as `y_true`.

    Returns:
        The per-element penalties reduced by sum over axis -1.
    """
    lamb, margin = 0.5, 0.1
    # Penalty on entries whose target is "on".
    present_term = y_true * tf.math.square(tf.nn.relu(1 - margin - y_pred))
    # Down-weighted penalty on entries whose target is "off".
    absent_term = lamb * (1 - y_true) * tf.math.square(tf.nn.relu(y_pred - margin))
    return tf.math.reduce_sum(present_term + absent_term, axis=-1)

@tf.function
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
    """Euclidean norm of `s` along `axis`, with `epsilon` added under the root.

    Args:
        s: input tensor.
        axis: axis to reduce over.
        epsilon: constant added to the squared norm before the square root.
        keep_dims: whether the reduced axis is kept with size 1.

    Returns:
        sqrt(sum(s ** 2, axis) + epsilon).
    """
    sum_of_squares = tf.reduce_sum(tf.square(s), axis=axis, keepdims=keep_dims)
    return tf.sqrt(sum_of_squares + epsilon)

In [6]:
# Load MNIST and bring it into the layout used by the convolutional layers.

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide the pixel values by 255, then append a trailing channels axis and
# store the result as float64.
x_train = (x_train / 255.0)[..., tf.newaxis].astype("float64")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float64")

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [7]:
# Work on a tiny batch: the first three training images.
X=x_train[:3]

In [8]:
# Sanity check: expect (3, 28, 28, 1), i.e. three single-channel 28x28 images.
X.shape

(3, 28, 28, 1)

In [9]:
# Feature extractor: three 5x5 'valid' ReLU convolutions with 256 filters each
# (strides 1, 2, 2) plus two batch-normalisation layers to sit between them.
conv_kwargs = dict(filters=256, kernel_size=5, padding='valid', activation='relu')
c1, c2, c3 = (tf.keras.layers.Conv2D(strides=stride, **conv_kwargs) for stride in (1, 2, 2))
bn1, bn2 = (tf.keras.layers.BatchNormalization() for _ in range(2))

In [10]:
# Forward pass of the sample batch: conv -> BN -> conv -> BN -> conv.
z = X
for layer in (c1, bn1, c2, bn2, c3):
    z = layer(z)

In [11]:
# Conv stack output: expect (3, 3, 3, 256).
z.shape

TensorShape([3, 3, 3, 256])

In [12]:
# The conv stack returns a channels-last tensor of shape (batch, 3, 3, 256).
# Move the channel axis forward before flattening the 3x3 grid, so that each of
# the 256 rows holds one feature map's 9 spatial values. Reshaping the
# channels-last tensor directly would instead cut the flat buffer into
# arbitrary runs of 9 values that mix channels and straddle pixels.
# The rank-4 `perm` also makes an accidental re-run of this cell fail loudly
# rather than silently reshuffling an already-reshaped z.
z = tf.reshape(tf.transpose(z, perm=[0, 3, 1, 2]), [-1, 256, 9])

In [13]:
# After the reshape: expect (3, 256, 9).
z.shape

TensorShape([3, 256, 9])

In [14]:
# Problem sizes used by the routing cells below.
n=256  # number of input capsules (second axis of z)
k=10  # number of output capsules / mixture components
d=9  # capsule vector length (last axis of z)
batch_size=3  # number of images in X

In [15]:
# Use float64 for z, the dtype given to W and the mixture parameters below.
z=tf.cast(z, tf.float64)

In [16]:
# Transformation matrices: one (d, d) matrix per (input capsule, output capsule)
# pair, with a leading singleton axis so tf.matmul broadcasts them over the batch.
#
# The standard deviation has its own name on purpose: a later cell rebinds
# `init_sigma` to the mixture's covariance array, so re-running this cell
# afterwards would pick up that array instead of 0.1.
w_init_stddev = 0.1

W_init = tf.random.normal(
    shape=(1, n, k, d, d),
    stddev=w_init_stddev, dtype=tf.float64)
W = tf.Variable(W_init)

In [17]:
# Expect (1, n, k, d, d) = (1, 256, 10, 9, 9).
W.shape

TensorShape([1, 256, 10, 9, 9])

In [18]:
# Turn every input capsule into a column vector and make one copy per output
# capsule: (batch, n, d) -> (batch, n, d, 1) -> (batch, n, 1, d, 1) -> (batch, n, k, d, 1).
caps1_output_expanded = z[..., tf.newaxis]
caps1_output_tile = caps1_output_expanded[:, :, tf.newaxis]
caps1_output_tiled = tf.tile(caps1_output_tile, multiples=[1, 1, k, 1, 1])

In [19]:
# Expect (batch, n, k, d, 1) = (3, 256, 10, 9, 1).
caps1_output_tiled.shape

TensorShape([3, 256, 10, 9, 1])

In [20]:
# Prediction vectors: multiply each (d, d) matrix in W by the matching (d, 1)
# input capsule column. W's leading axis of size 1 broadcasts over the batch.
caps2_predicted=tf.matmul(W,caps1_output_tiled)

In [21]:
# Expect (3, 256, 10, 9, 1): one d-vector per (image, input capsule, output capsule).
caps2_predicted.shape

TensorShape([3, 256, 10, 9, 1])

In [22]:
# Reminder of the data the mixture is fitted to: (3, 256, 9).
z.shape

TensorShape([3, 256, 9])

In [23]:
# Initial parameters of a k-component Gaussian mixture, one independent mixture
# per image in the batch. Everything is stored as a float64 tf.Variable.

# Means: uniform in [-10, 10). A seeded numpy generator makes the run
# reproducible and replaces the legacy `scipy.random` alias of `numpy.random`.
rng = np.random.default_rng(0)
init_mu = rng.random((batch_size, k, d)) * 20 - 10
mu = tf.Variable(init_mu, dtype=tf.float64)

# Covariances: the d x d identity for every component.
init_sigma = np.tile(np.eye(d), (k, 1, 1))  # (k, d, d)
sigma = tf.Variable(np.tile(init_sigma, (batch_size, 1, 1, 1)), dtype=tf.float64)

# Mixing weights: uniform over the k components.
init_pi = np.ones(k) / k
pi = tf.Variable(np.tile(init_pi, (batch_size, 1)), dtype=tf.float64)

# Responsibilities of each component for each of the n points, filled by the E-step.
R = tf.Variable(np.zeros((batch_size, n, k)), dtype=tf.float64)

print(mu.shape, pi.shape, sigma.shape, R.shape)

# Scratch buffer for per-point densities under a single component.
N = tf.Variable(np.zeros((batch_size, n)), dtype=tf.float64)

(3, 10, 9) (3, 10) (3, 10, 9, 9) (3, 256, 10)


In [26]:
# One EM iteration of the per-image Gaussian mixture over the n capsule vectors
# in z. Shapes: z (B, n, d), mu (B, k, d), sigma (B, k, d, d), pi (B, k),
# R (B, n, k), with B = batch_size.

# Small ridge added to every covariance so that the Cholesky factorisation in
# the next iteration cannot fail on a (near-)singular matrix.
COV_RIDGE = 1e-6

# E-step.
# Densities are evaluated in log space and normalised with a softmax: in linear
# space they underflow easily in d dimensions, and a point whose density is 0
# under every component would turn its row of R into 0/0 = NaN.
# MultivariateNormalTriL replaces the deprecated MultivariateNormalFullCovariance;
# the size-1 axes make the (B, 1, k) batch of components broadcast against the
# (B, n, 1) batch of points.
components = tfp.distributions.MultivariateNormalTriL(
    loc=tf.expand_dims(mu, axis=1),
    scale_tril=tf.expand_dims(tf.linalg.cholesky(sigma), axis=1))
log_density = components.log_prob(tf.expand_dims(z, axis=2))  # (B, n, k)
log_weighted = log_density + tf.expand_dims(tf.math.log(pi), axis=1)
responsibilities = tf.nn.softmax(log_weighted, axis=-1)
R.assign(responsibilities)
# Keep the scratch buffer as the loop-based version left it: the densities
# under the last component.
N.assign(tf.exp(log_density[:, :, -1]))

# M-step
N_k = tf.reduce_sum(responsibilities, axis=1)  # (B, k) effective point counts
pi = N_k / n
# divide_no_nan yields 0 instead of NaN for a component that received no mass.
mu = tf.math.divide_no_nan(
    tf.matmul(responsibilities, z, transpose_a=True), N_k[:, :, None])

# Responsibility-weighted scatter of the points around each new mean.
centered = tf.expand_dims(z, axis=2) - tf.expand_dims(mu, axis=1)  # (B, n, k, d)
scatter = tf.einsum('bnk,bnki,bnkj->bkij', responsibilities, centered, centered)
covariance = tf.math.divide_no_nan(scatter, N_k[:, :, None, None])
sigma.assign(covariance + COV_RIDGE * tf.eye(d, dtype=covariance.dtype))

In [27]:
R # responsibilities from the E-step, used below as the coupling coefficients.

<tf.Variable 'Variable:0' shape=(3, 256, 10) dtype=float64, numpy=
array([[[6.25504195e-36, 7.72355429e-32, 7.75496266e-37, ...,
         1.67945343e-66, 8.69497372e-47, 1.00000000e+00],
        [5.55936456e-36, 3.95371080e-31, 4.32123098e-36, ...,
         1.25797381e-65, 3.54704519e-45, 1.00000000e+00],
        [1.85381552e-35, 1.06206965e-30, 1.31256265e-35, ...,
         1.44571063e-65, 1.06947436e-44, 1.00000000e+00],
        ...,
        [7.74069738e-36, 1.01337154e-30, 7.17333693e-36, ...,
         1.44395249e-65, 6.70218287e-45, 1.00000000e+00],
        [3.47690098e-36, 1.95952375e-31, 7.77837623e-38, ...,
         3.73396083e-67, 1.01949299e-44, 1.00000000e+00],
        [2.62197672e-34, 2.18448834e-30, 5.98471697e-36, ...,
         1.21156310e-65, 4.87279626e-48, 1.00000000e+00]],

       [[2.66350616e-12, 5.48901058e-09, 1.12164843e-60, ...,
         4.28401556e-25, 9.99999995e-01, 8.40737614e-38],
        [3.14008151e-10, 5.73578443e-09, 1.99499493e-55, ...,
         5.53818

In [29]:
# Prediction vectors: expect (3, 256, 10, 9, 1).
caps2_predicted.shape

TensorShape([3, 256, 10, 9, 1])

In [30]:
# Coupling coefficients: expect (3, 256, 10), so two trailing axes are needed to match the predictions.
R.shape

TensorShape([3, 256, 10])

In [37]:
# Scale each prediction vector by its coupling coefficient. R gets two trailing
# singleton axes so that it broadcasts over the (d, 1) vector dimensions.
coupling = tf.reshape(R, [batch_size, n, k, 1, 1])
weighted_prediction = tf.multiply(caps2_predicted, coupling)

In [38]:
# Same shape as the predictions: expect (3, 256, 10, 9, 1).
weighted_prediction.shape

TensorShape([3, 256, 10, 9, 1])

In [39]:
# Sum the weighted predictions over axis 1 (the n input capsules); keepdims
# leaves a size-1 axis there.
weighted_sum = tf.reduce_sum(weighted_prediction, axis=1, keepdims=True)

In [40]:
# Expect (3, 1, 10, 9, 1): one d-vector per image and output capsule.
weighted_sum.shape

TensorShape([3, 1, 10, 9, 1])

In [41]:
# Squash along axis -2, the length-d vector axis of each output capsule
# (the last axis has size 1).
v=squash(weighted_sum, axis=-2)
print(v.shape)

(3, 1, 10, 9, 1)


In [42]:
# Inspect the squashed output capsule vectors.
# NOTE(review): many entries are vanishingly small; this looks like a consequence
# of R being almost one-hot per point after a single EM step - confirm.
v

<tf.Tensor: shape=(3, 1, 10, 9, 1), dtype=float64, numpy=
array([[[[[ 1.16859643e-34],
          [ 9.88838018e-35],
          [ 3.13300165e-35],
          [ 1.40187216e-35],
          [ 9.21522128e-35],
          [ 3.70461803e-35],
          [-1.51467832e-35],
          [-3.47508434e-35],
          [ 1.12593591e-34]],

         [[-2.21480991e-30],
          [ 1.75767961e-30],
          [-5.21278667e-30],
          [ 4.16336280e-31],
          [ 1.91546001e-31],
          [ 4.64552578e-30],
          [-3.77633742e-30],
          [ 3.23691998e-30],
          [-9.06403971e-31]],

         [[ 5.30468174e-35],
          [-1.00951429e-35],
          [-2.35832705e-35],
          [-1.23984912e-35],
          [ 3.64071415e-35],
          [ 4.80640068e-35],
          [ 5.58173564e-35],
          [ 3.66953921e-35],
          [ 6.08486072e-35]],

         [[ 7.16085124e-24],
          [ 8.92353740e-24],
          [-1.76250177e-24],
          [-1.38224372e-24],
          [-5.71507509e-24],
        