In [11]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers

# Squashing function
def squash(vectors, axis=-1):
    s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / tf.sqrt(s_squared_norm + tf.keras.backend.epsilon())
    return scale * vectors

# Primary Capsule Layer
class PrimaryCapsLayer(layers.Layer):
    def __init__(self, num_capsules, dim_capsule, kernel_size, strides, padding, **kwargs):
        super(PrimaryCapsLayer, self).__init__(**kwargs)
        self.num_capsules = num_capsules
        self.dim_capsule = dim_capsule
        self.conv2d = layers.Conv2D(filters=num_capsules * dim_capsule,
                                    kernel_size=kernel_size,
                                    strides=strides,
                                    padding=padding)

    def call(self, inputs):
        outputs = self.conv2d(inputs)
        batch_size = tf.shape(outputs)[0]
        outputs = tf.reshape(outputs, (batch_size, -1, self.dim_capsule))  # Reshape to (batch_size, num_capsules, dim_capsule)
        return squash(outputs)

# Digit Capsule Layer with Dynamic Routing
class DigitCapsLayer(layers.Layer):
    def __init__(self, num_capsules, dim_capsule, num_routing, **kwargs):
        super(DigitCapsLayer, self).__init__(**kwargs)
        self.num_capsules = num_capsules
        self.dim_capsule = dim_capsule
        self.num_routing = num_routing

    def build(self, input_shape):
        self.W = self.add_weight(shape=[input_shape[1], self.num_capsules, self.dim_capsule, input_shape[2]],
                                 initializer='glorot_uniform',
                                 trainable=True)

    def call(self, inputs):
        inputs_expand = tf.expand_dims(inputs, 2)  # Expand dimensions to add capsules axis
        inputs_tile = tf.expand_dims(inputs_expand, 3)  # Tile the inputs to match dimensions for routing
        inputs_hat = tf.map_fn(lambda x: tf.matmul(self.W, x), elems=inputs_tile)  # Compute prediction vectors

        b = tf.zeros(shape=[tf.shape(inputs_hat)[0], tf.shape(inputs_hat)[1], self.num_capsules, 1])  # Routing logits

        for i in range(self.num_routing):
            c = tf.nn.softmax(b, axis=2)  # Softmax over routing logits
            outputs = squash(tf.reduce_sum(c * inputs_hat, axis=1, keepdims=True))  # Weighted sum of prediction vectors

            if i < self.num_routing - 1:
                b += tf.reduce_sum(inputs_hat * outputs, axis=-1, keepdims=True)  # Update routing logits

        return tf.squeeze(outputs, axis=1)

# Length Layer (to get the magnitude of capsule vectors)
class LengthLayer(layers.Layer):
    def call(self, inputs):
        return tf.sqrt(tf.reduce_sum(tf.square(inputs), axis=-1))

# Margin Loss
def margin_loss(y_true, y_pred):
    lambda_val = 0.5
    max_l = tf.square(tf.maximum(0., 0.9 - y_pred))
    max_r = tf.square(tf.maximum(0., y_pred - 0.1))
    loss = y_true * max_l + lambda_val * (1 - y_true) * max_r
    return tf.reduce_mean(tf.reduce_sum(loss, axis=1))

# Capsule Network Architecture
def CapsNet(input_shape, num_classes, num_routing):
    inputs = layers.Input(shape=input_shape)

    # Conv layer
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu')(inputs)

    # Primary Capsule Layer
    primary_caps = PrimaryCapsLayer(num_capsules=32, dim_capsule=8, kernel_size=9, strides=2, padding='valid')(conv1)

    # Digit Capsule Layer
    digit_caps = DigitCapsLayer(num_capsules=num_classes, dim_capsule=16, num_routing=num_routing)(primary_caps)

    # Output length layer
    output_caps = LengthLayer()(digit_caps)

    # Decoder Network
    y_true = layers.Input(shape=(num_classes,))
    masked = layers.Lambda(lambda x: tf.keras.backend.batch_dot(x[0], x[1], [1, 1]))([digit_caps, y_true])
    masked = layers.Flatten()(masked)
    
    decoder = models.Sequential([
        layers.Dense(512, activation='relu', input_dim=16 * num_classes),
        layers.Dense(1024, activation='relu'),
        layers.Dense(np.prod(input_shape), activation='sigmoid'),
        layers.Reshape(target_shape=input_shape)
    ])
    decoder_output = decoder(masked)

    # Full model
    model = models.Model([inputs, y_true], [output_caps, decoder_output])

    return model

# Model setup
input_shape = (28, 28, 1)  # Adjust according to your dataset
num_classes = 10  # Adjust according to your dataset
num_routing = 3

capsnet_model = CapsNet(input_shape=input_shape, num_classes=num_classes, num_routing=num_routing)
capsnet_model.compile(optimizer=optimizers.Adam(),
                      loss=[margin_loss, 'mse'],
                      loss_weights=[1., 0.392],
                      metrics={'capsnet': 'accuracy'})

capsnet_model.summary()


ValueError: Shapes used to initialize variables must be fully-defined (no `None` dimensions). Received: shape=(None, 10, 16, 8) for variable path='digit_caps_layer_3/variable_3'