<a href="https://colab.research.google.com/github/nobeas/ACML-assignment-2025/blob/main/Comparison_Capnet_%26_AE_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Import Libraries**

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
from tensorflow.keras import layers, models, optimizers
from sklearn.model_selection import train_test_split

**Loading and Processing data**

In [None]:
# Fetch Fashion-MNIST through the Keras dataset helper.
fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Cast to float32, divide by 255 and add a trailing channel axis so every
# image has shape (28, 28, 1), matching the Conv2D-based models below.
x_train = (x_train.astype('float32') / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype('float32') / 255.0).reshape(-1, 28, 28, 1)

# Hold out 10,000 training samples for validation; the fixed random_state
# makes the split repeatable.
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=10000, random_state=42
)

# Keep an integer-label copy of the test targets before they are one-hot
# encoded; it is kept for metric calculation.
y_test_orig = y_test.copy()

# One-hot encode the labels of all three splits (10 classes).
y_train, y_val, y_test = (
    tf.keras.utils.to_categorical(labels, 10)
    for labels in (y_train, y_val, y_test)
)

# Human-readable class names, indexed by integer label, for visualizations.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
[1m29515/29515[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
[1m26421880/26421880[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
[1m5148/5148[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
[1m4422102/4422102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


**Define AE-CNN**

In [None]:
# Define Channel Attention Module
def channel_attention(x, ratio=16):
    """Re-weight the channels of a feature map with a learned per-channel gate.

    The map is globally average-pooled, passed through a two-layer MLP
    (ReLU bottleneck, then sigmoid) and the resulting per-channel weights in
    (0, 1) are multiplied back onto `x`.

    Args:
        x: 4-D Keras tensor (batch, height, width, channels) with a statically
            known channel count.
        ratio: Reduction factor of the bottleneck; the hidden layer has
            `channels // ratio` units, but never fewer than one.

    Returns:
        A tensor with the same shape as `x`.

    Raises:
        ValueError: If `ratio` is not positive or the channel dimension of
            `x` is unknown.
    """
    if ratio <= 0:
        raise ValueError(f"ratio must be positive, got {ratio}")

    channel = x.shape[-1]
    if channel is None:
        raise ValueError("channel_attention requires a known channel dimension")

    # Guard against a zero-width bottleneck when channels < ratio.
    hidden_units = max(int(channel // ratio), 1)

    # Squeeze: one descriptor per channel.
    avg_pool = layers.GlobalAveragePooling2D()(x)

    # Excite: bottleneck MLP producing one weight per channel.
    dense1 = layers.Dense(hidden_units, activation='relu')(avg_pool)
    dense2 = layers.Dense(channel, activation='sigmoid')(dense1)

    # Reshape to (1, 1, channels) so the weights broadcast over height/width.
    dense2 = layers.Reshape((1, 1, channel))(dense2)

    # Apply attention
    output = layers.Multiply()([x, dense2])

    return output

# Define Spatial Attention Module
def spatial_attention(x, kernel_size=7):
    """Re-weight the spatial positions of a feature map with a learned gate.

    The channel-wise mean and channel-wise max of `x` are stacked into a
    two-channel map, a single sigmoid convolution turns that into a
    one-channel attention map, and the map is multiplied back onto `x`.

    Args:
        x: 4-D Keras tensor (batch, height, width, channels).
        kernel_size: Kernel size of the convolution producing the map.

    Returns:
        A tensor with the same shape as `x`.
    """
    def _mean_over_channels(t):
        return tf.reduce_mean(t, axis=-1, keepdims=True)

    def _max_over_channels(t):
        return tf.reduce_max(t, axis=-1, keepdims=True)

    # Two single-channel summaries of the input.
    mean_map = layers.Lambda(_mean_over_channels)(x)
    max_map = layers.Lambda(_max_over_channels)(x)
    stacked = layers.Concatenate()([mean_map, max_map])

    # 'same' padding keeps the spatial size, so the gate lines up with x.
    gate = layers.Conv2D(
        1,
        kernel_size,
        padding='same',
        activation='sigmoid',
        kernel_initializer='he_normal',
    )(stacked)

    return layers.Multiply()([x, gate])

# Build the AE-CNN model
def build_ae_cnn_model():
    """Assemble and compile the attention-augmented CNN classifier.

    Layout: two double-convolution stages (32 then 64 filters, each followed
    by 2x2 max-pooling), channel attention, spatial attention, a 128-filter
    convolution with global average pooling, and a 256-unit dense head with
    a 10-way softmax.

    Returns:
        A compiled Keras model taking (28, 28, 1) images; Adam (lr=0.001),
        categorical cross-entropy loss and the accuracy metric.
    """
    def _double_conv_stage(tensor, filters):
        # Conv -> BatchNorm -> Conv -> 2x2 max-pool with 3x3 'same' convolutions.
        tensor = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(tensor)
        tensor = layers.BatchNormalization()(tensor)
        tensor = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(tensor)
        return layers.MaxPooling2D(pool_size=(2, 2))(tensor)

    image_in = layers.Input(shape=(28, 28, 1))

    # Feature extraction stages.
    net = _double_conv_stage(image_in, 32)
    net = _double_conv_stage(net, 64)

    # Attention: first over channels, then over spatial positions.
    net = channel_attention(net, ratio=16)
    net = spatial_attention(net, kernel_size=7)

    # Final convolution, regularised and pooled to one vector per image.
    net = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(net)
    net = layers.BatchNormalization()(net)
    net = layers.Dropout(0.25)(net)
    net = layers.GlobalAveragePooling2D()(net)

    # Classification head.
    net = layers.Dense(256, activation='relu')(net)
    net = layers.BatchNormalization()(net)
    net = layers.Dropout(0.5)(net)
    class_probs = layers.Dense(10, activation='softmax')(net)

    model = models.Model(inputs=image_in, outputs=class_probs)
    model.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

**Define Capsule Network Model**

In [None]:
class CapsuleLayer(layers.Layer):
    """Capsule layer with dynamic routing over scalar inputs.

    Input:  [batch_size, input_dim] - one scalar activation per input unit.
    Output: [batch_size, num_capsule, dim_capsule] - squashed output capsules.

    Each input scalar is projected to a `dim_capsule`-vector for every output
    capsule through the weight tensor `W`; the projections are then combined
    with `routings` iterations of routing-by-agreement.

    Args:
        num_capsule: Number of output capsules.
        dim_capsule: Dimensionality of each output capsule.
        routings: Number of routing iterations (must be >= 1).
    """

    def __init__(self, num_capsule, dim_capsule, routings=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        if routings < 1:
            raise ValueError(f"routings must be >= 1, got {routings}")
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.W = None

    def build(self, input_shape):
        self.input_dim = input_shape[-1]
        self.W = self.add_weight(
            shape=[1, self.input_dim, self.num_capsule, self.dim_capsule, 1],
            initializer='glorot_uniform',
            name='W')
        self.built = True

    @staticmethod
    def _squash(vectors, axis=-1):
        """Shrink vectors to length in [0, 1) while keeping their direction."""
        # Kept local so the layer does not depend on a module-level helper
        # being defined before it is first called.
        squared_norm = tf.reduce_sum(tf.square(vectors), axis=axis, keepdims=True)
        scale = squared_norm / (1.0 + squared_norm) / tf.sqrt(squared_norm + 1e-8)
        return scale * vectors

    def call(self, inputs):
        # W: [1, input_dim, num_capsule, dim_capsule, 1]
        #    -> [1, input_dim, num_capsule, dim_capsule]
        weights = tf.squeeze(tf.convert_to_tensor(self.W), axis=-1)

        # inputs: [batch_size, input_dim] -> [batch_size, input_dim, 1, 1]
        scalars = tf.expand_dims(tf.expand_dims(inputs, -1), -1)

        # Prediction vectors via broadcasting (no explicit tiling needed):
        # [batch_size, input_dim, num_capsule, dim_capsule]
        inputs_hat = weights * scalars

        # Routing logits, one per (input unit, output capsule) pair:
        # [batch_size, input_dim, num_capsule, 1]
        b = tf.zeros_like(inputs_hat[..., :1])

        for i in range(self.routings):
            # Coupling coefficients: softmax over the output capsules.
            c = tf.nn.softmax(b, axis=2)

            # Weighted sum over input units, then squash:
            # [batch_size, 1, num_capsule, dim_capsule]
            outputs = self._squash(
                tf.reduce_sum(c * inputs_hat, axis=1, keepdims=True))

            if i < self.routings - 1:
                # Agreement = dot product between predictions and outputs.
                b = b + tf.reduce_sum(inputs_hat * outputs, axis=-1, keepdims=True)

        # Drop the singleton input axis: [batch_size, num_capsule, dim_capsule]
        return tf.squeeze(outputs, axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsule, self.dim_capsule)

    def get_config(self):
        config = super(CapsuleLayer, self).get_config()
        config.update({
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
        })
        return config

# Build the Capsule Network model
def build_capsule_network():
    """Build and compile a simplified capsule network with a reconstruction decoder.

    The model takes two inputs, a (28, 28, 1) image and a 10-wide one-hot
    label vector, and produces two outputs:

    * ``outputs``: the length of each of the 10 class capsules, trained with
      margin loss against one-hot labels;
    * ``reconstruction``: a (28, 28, 1) image decoded from the label-masked
      capsule lengths, trained with MSE weighted by 0.0005.

    Note: each 8-D primary capsule is reduced to its length before routing,
    so this builder needs a ``CapsuleLayer`` that accepts
    [batch_size, input_dim] input.

    Returns:
        A compiled two-input / two-output Keras model.
    """
    inputs = layers.Input(shape=(28, 28, 1))

    # Conv Block 1
    x = layers.Conv2D(256, (9, 9), padding='valid', activation='relu')(inputs)

    # Primary Capsule Layer: group the conv features into 8-D vectors and keep
    # each vector's length (epsilon keeps the sqrt gradient finite at zero).
    primary_caps = layers.Conv2D(256, (9, 9), strides=2, padding='valid')(x)
    primary_caps = layers.Reshape(target_shape=[-1, 8])(primary_caps)
    primary_caps = layers.Lambda(
        lambda t: tf.sqrt(tf.reduce_sum(tf.square(t), -1) + 1e-7))(primary_caps)

    # Digit Capsule Layer
    digit_caps = CapsuleLayer(num_capsule=10, dim_capsule=16, routings=3)(primary_caps)

    # Output Layer: capsule lengths act as class scores. The layer is named
    # explicitly so the per-output metrics dict below can refer to it.
    outputs = layers.Lambda(
        lambda t: tf.sqrt(tf.reduce_sum(tf.square(t), -1) + 1e-7),
        name='outputs')(digit_caps)

    # Masking: keep only the score of the class given by the label input.
    y = layers.Input(shape=(10,))
    masked = layers.Multiply()([outputs, y])

    # Decoder
    decoder = models.Sequential([
        layers.Dense(512, activation='relu'),
        layers.Dense(1024, activation='relu'),
        layers.Dense(784, activation='sigmoid'),
        layers.Reshape(target_shape=(28, 28, 1))
    ], name='reconstruction')

    # Reconstruct the input
    reconstructed = decoder(masked)

    # Combine the models
    model = models.Model([inputs, y], [outputs, reconstructed])

    # Margin loss: present classes are pushed above 0.9, absent classes below
    # 0.1 (the latter down-weighted by 0.5).
    def margin_loss(y_true, y_pred):
        L = y_true * tf.square(tf.maximum(0., 0.9 - y_pred)) + \
            0.5 * (1 - y_true) * tf.square(tf.maximum(0., y_pred - 0.1))
        return tf.reduce_mean(tf.reduce_sum(L, axis=1))

    # Compile model
    model.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss=[margin_loss, 'mse'],
        loss_weights=[1., 0.0005],
        # The key must match the name of an output layer.
        metrics={'outputs': 'accuracy'}
    )

    return model

**Train & Evaluate models**

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, Model
from tensorflow.keras.optimizers import Adam
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import pandas as pd

# Define model parameters globally for easy access
# input_shape: length of the flattened image vector used for the Input layers
#              and for the reconstruction output of the capsule network.
# num_classes: number of label categories (width of the classifier outputs
#              and number of output capsules).
input_shape = 784  # 28*28
num_classes = 10

# Squash function for Capsule Networks
def squash(vectors, axis=-1):
    """Capsule non-linearity: rescale vectors so their length lies in [0, 1).

    Each vector along `axis` keeps its direction; its length is mapped
    through ||s||^2 / (1 + ||s||^2). A small constant inside the square root
    avoids division by zero for all-zero vectors.

    Args:
        vectors: Tensor holding the capsule vectors.
        axis: Axis along which each vector is laid out.

    Returns:
        A tensor with the same shape as `vectors`.
    """
    norm_sq = tf.reduce_sum(tf.square(vectors), axis=axis, keepdims=True)
    shrink = norm_sq / (1 + norm_sq)
    norm = tf.sqrt(norm_sq + 1e-8)
    return (shrink / norm) * vectors

# Capsule layer with dynamic routing, built from basic TensorFlow operations
class CapsuleLayer(layers.Layer):
    """Fully connected capsule layer with routing-by-agreement.

    Input:  [batch_size, input_num_capsule, input_dim_capsule]
    Output: [batch_size, num_capsule, dim_capsule]

    Args:
        num_capsule: Number of output capsules.
        dim_capsule: Dimensionality of each output capsule.
        routings: Number of routing iterations (must be >= 1).
    """

    def __init__(self, num_capsule, dim_capsule, routings=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        if routings < 1:
            # With zero iterations the routing loop would never produce an output.
            raise ValueError(f"routings must be >= 1, got {routings}")
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.W = None

    def build(self, input_shape):
        # Input shape: [batch_size, input_num_capsule, input_dim_capsule]
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]

        # One (input_dim_capsule x dim_capsule) transform per
        # (input capsule, output capsule) pair.
        self.W = self.add_weight(
            shape=[self.input_num_capsule, self.num_capsule, self.input_dim_capsule, self.dim_capsule],
            initializer='glorot_uniform',
            name='capsule_weights')
        self.built = True

    def call(self, inputs):
        # Prediction vectors:
        #   u_hat[b, i, n, d] = sum_c inputs[b, i, c] * W[i, n, c, d]
        # A single einsum avoids tiling W across the batch and the inputs
        # across the output capsules, which would materialise two
        # [batch, input_num_capsule, num_capsule, input_dim_capsule, dim_capsule]
        # tensors.
        u_hat = tf.einsum('bic,incd->bind', inputs, tf.convert_to_tensor(self.W))

        # Routing logits: [batch_size, input_num_capsule, num_capsule, 1]
        b = tf.zeros_like(u_hat[..., :1])

        for i in range(self.routings):
            # Coupling coefficients: softmax over the output capsules.
            c = tf.nn.softmax(b, axis=2)

            # Weighted sum over input capsules -> [batch_size, num_capsule, dim_capsule]
            s = tf.reduce_sum(c * u_hat, axis=1)
            v = squash(s)

            # Update the logits on every iteration except the last.
            if i < self.routings - 1:
                # Agreement = dot product of each prediction with the current
                # output: [batch_size, input_num_capsule, num_capsule, 1]
                agreement = tf.reduce_sum(tf.expand_dims(v, 1) * u_hat, -1, keepdims=True)
                b = b + agreement

        return v

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsule, self.dim_capsule)

    def get_config(self):
        config = super(CapsuleLayer, self).get_config()
        config.update({
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
        })
        return config

# Custom mask function
def mask(inputs):
    """Zero out every capsule except the one selected by the label, then flatten.

    Args:
        inputs: Pair ``[capsule_output, y_true]`` where
            capsule_output is [batch_size, num_capsule, dim_capsule] and
            y_true is a one-hot tensor [batch_size, num_capsule].

    Returns:
        Tensor of shape [batch_size, num_capsule * dim_capsule].
    """
    capsule_output, y = inputs[0], inputs[1]

    # [batch_size, num_capsule] -> [batch_size, num_capsule, 1] so the label
    # broadcasts across each capsule's components.
    masked = capsule_output * tf.expand_dims(y, -1)

    # Take the flattened width from the tensor itself instead of assuming
    # 16-D capsules; fall back to the dynamic shape if it is not static.
    num_caps, caps_dim = masked.shape[1], masked.shape[2]
    if num_caps is None or caps_dim is None:
        flat_width = tf.reduce_prod(tf.shape(masked)[1:])
    else:
        flat_width = num_caps * caps_dim

    return tf.reshape(masked, [-1, flat_width])

# Margin Loss for Capsule Networks
class MarginLoss(tf.keras.losses.Loss):
    """Margin loss on capsule lengths.

    For each class k with prediction p_k and one-hot target t_k:

        L_k = t_k * max(0, margin - p_k)^2
              + downweight * (1 - t_k) * max(0, p_k - (1 - margin))^2

    Args:
        margin: Target lower bound for the present class; ``1 - margin`` is
            the upper bound for absent classes.
        downweight: Weight of the absent-class term.
    """

    def __init__(self, margin=0.9, downweight=0.5, **kwargs):
        super(MarginLoss, self).__init__(**kwargs)
        self.margin = margin
        self.downweight = downweight

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.convert_to_tensor(y_true)

        # Decide from *static* shapes whether y_true holds class indices;
        # Python-level branching on dynamic tensors fails in graph mode.
        class_count = y_pred.shape[-1]
        true_rank = y_true.shape.rank
        last_dim = y_true.shape[-1] if true_rank else None

        if true_rank == 1:
            is_sparse = True
        else:
            is_sparse = (class_count is not None
                         and last_dim is not None
                         and last_dim != class_count)

        if is_sparse:
            if true_rank is not None and true_rank > 1 and last_dim == 1:
                # [batch, 1] index labels -> [batch] so one_hot yields [batch, classes].
                y_true = tf.squeeze(y_true, axis=-1)
            depth = class_count if class_count is not None else tf.shape(y_pred)[-1]
            y_true = tf.one_hot(tf.cast(y_true, tf.int32), depth=depth)

        y_true = tf.cast(y_true, y_pred.dtype)

        # Calculate L+ and L-
        L_plus = y_true * tf.square(tf.maximum(0., self.margin - y_pred))
        L_minus = (1 - y_true) * tf.square(tf.maximum(0., y_pred - (1 - self.margin))) * self.downweight

        # Sum over classes, average over the batch.
        return tf.reduce_mean(tf.reduce_sum(L_plus + L_minus, 1))

    def get_config(self):
        config = super(MarginLoss, self).get_config()
        config.update({'margin': self.margin, 'downweight': self.downweight})
        return config

# Function to build the Auto-Encoder CNN model
def build_ae_cnn_model():
    """Build and compile the CNN classifier that operates on flattened images.

    The (input_shape,) vector is reshaped to 28x28x1, encoded by three
    Conv2D + MaxPooling2D stages (32, 64, 128 filters) and a 256-unit dense
    latent layer, and classified by a softmax over `num_classes`.

    Note: the model's only output is the classifier head, so no decoder
    (reconstruction) branch is built here; layers that are not connected to
    a model output would never be trained or used.

    Returns:
        A compiled Keras model; Adam (lr=0.001), categorical cross-entropy
        loss and the accuracy metric.
    """
    # Encoder
    inputs = layers.Input(shape=(input_shape,))
    x = layers.Reshape((28, 28, 1))(inputs)
    for filters in (32, 64, 128):
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        # 'same' pooling: spatial size goes 28 -> 14 -> 7 -> 4.
        x = layers.MaxPooling2D((2, 2), padding='same')(x)

    # Latent space
    x = layers.Flatten()(x)  # 4*4*128 = 2048
    encoded = layers.Dense(256, activation='relu')(x)

    # Classification from latent space
    classifier = layers.Dense(num_classes, activation='softmax')(encoded)

    # Create model
    model = Model(inputs=inputs, outputs=classifier)

    # Compile model
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

# Function to build the Capsule Network model
def build_capsule_network():
    """Build and compile the capsule network with a reconstruction decoder.

    Inputs:
        * flattened image, shape (input_shape,)
        * one-hot label, shape (num_classes,), used only to mask the capsules
          fed to the decoder.

    Outputs:
        * capsule lengths, shape (num_classes,), used as class scores
        * reconstructed flattened image, shape (input_shape,)

    The class output is trained with sparse categorical cross-entropy (so
    its target must be integer labels) and the reconstruction with MSE
    weighted by 0.0005.

    Returns:
        A compiled two-input / two-output Keras model.
    """
    def _capsule_length(capsules):
        # The epsilon keeps the gradient of sqrt finite when a capsule is
        # exactly zero; without it the gradient is undefined there.
        return tf.sqrt(tf.reduce_sum(tf.square(capsules), -1) + 1e-7)

    # Input layers
    x_input = layers.Input(shape=(input_shape,))
    y_input = layers.Input(shape=(num_classes,))

    # Reshape inputs to work with Conv2D layers
    x_reshaped = layers.Reshape((28, 28, 1))(x_input)

    # Feature extraction followed by the primary-capsule convolution
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu')(x_reshaped)
    primarycaps = layers.Conv2D(filters=32*8, kernel_size=9, strides=2, padding='valid')(conv1)

    # Reshape to [batch_size, num_capsules, 8] (6*6*32 = 1152 capsules) and squash
    primarycaps_reshaped = layers.Reshape((-1, 8))(primarycaps)
    primarycaps_squashed = layers.Lambda(lambda x: squash(x))(primarycaps_reshaped)

    # DigitCaps layer (CapsuleLayer)
    digitcaps = CapsuleLayer(num_capsule=num_classes, dim_capsule=16, routings=3)(primarycaps_squashed)

    # Length layer - for classification output
    out_caps = layers.Lambda(_capsule_length)(digitcaps)

    # Mask the capsule outputs for reconstruction
    masked = layers.Lambda(lambda x: mask(x))([digitcaps, y_input])

    # Decoder network
    decoder = layers.Dense(512, activation='relu')(masked)
    decoder = layers.Dense(1024, activation='relu')(decoder)
    decoder = layers.Dense(input_shape, activation='sigmoid')(decoder)

    # Models for training and evaluation
    model = Model([x_input, y_input], [out_caps, decoder])

    # Sparse categorical cross-entropy is used instead of margin loss for simplicity
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss=['sparse_categorical_crossentropy', 'mse'],
        loss_weights=[1.0, 0.0005],
        metrics=[['accuracy'], ['mse']]  # Specify metrics for each output
    )

    return model

# Load and preprocess data
# NOTE(review): this loads MNIST digits, while the data-loading cell at the top
# of the notebook uses Fashion-MNIST - confirm which dataset this comparison
# is meant to run on.
(x_train, y_train_orig), (x_test, y_test_orig) = tf.keras.datasets.mnist.load_data()

# Normalize and flatten each image to a 784-vector
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)

# Convert labels to one-hot encoding (the integer labels are kept in *_orig)
y_train = tf.keras.utils.to_categorical(y_train_orig, 10)
y_test = tf.keras.utils.to_categorical(y_test_orig, 10)

# Split into training and validation sets: the last `val_size` training
# samples become the validation set.
val_size = 10000
x_val = x_train[-val_size:]
y_val = y_train[-val_size:]
y_val_orig = y_train_orig[-val_size:]
x_train = x_train[:-val_size]
y_train = y_train[:-val_size]
y_train_orig = y_train_orig[:-val_size]

# Build the AE-CNN model
ae_cnn_model = build_ae_cnn_model()

# Build the Capsule Network model
capsule_model = build_capsule_network()

# Define callback for early stopping (stop after 5 epochs without a better
# val_loss and roll back to the best weights)
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True
)

# Train the AE-CNN model
print("Training the AE-CNN model...")
ae_cnn_history = ae_cnn_model.fit(
    x_train, y_train,
    batch_size=64,
    epochs=20,
    validation_data=(x_val, y_val),
    callbacks=[early_stopping],
    verbose=1
)

# Train the Capsule Network model
print("Training the Capsule Network model...")
capsule_history = capsule_model.fit(
    [x_train, y_train],  # inputs: x_train and one-hot encoded labels
    [y_train_orig, x_train],  # targets: integer labels for sparse_categorical_crossentropy, original images for reconstruction
    batch_size=64,
    epochs=20,
    validation_data=([x_val, y_val], [y_val_orig, x_val]),
    callbacks=[early_stopping],
    verbose=1
)

# Evaluate AE-CNN model on test set
print("\nEvaluating AE-CNN model on test set...")
ae_cnn_test_loss, ae_cnn_test_acc = ae_cnn_model.evaluate(x_test, y_test, verbose=1)
print(f"AE-CNN Test accuracy: {ae_cnn_test_acc:.4f}")

# Evaluate Capsule Network model on test set
print("\nEvaluating Capsule Network model on test set...")
# Request a name -> value mapping: for a multi-output model the position of
# each value in the list form depends on the Keras version and on the
# auto-generated output-layer names, so positional indexing is fragile.
capsule_test_results = capsule_model.evaluate(
    [x_test, y_test],
    [y_test_orig, x_test],
    verbose=1,
    return_dict=True
)
capsule_test_loss = capsule_test_results['loss']  # Total loss
# Only the first output has an accuracy metric; find it by its name suffix.
capsule_test_acc = next(
    value for name, value in capsule_test_results.items()
    if name.endswith('accuracy')
)
print(f"Capsule Network Test accuracy: {capsule_test_acc:.4f}")

# Get predictions
print("Generating predictions...")
ae_cnn_y_pred_prob = ae_cnn_model.predict(x_test)
ae_cnn_y_pred = np.argmax(ae_cnn_y_pred_prob, axis=1)

# The label input only feeds the reconstruction decoder; the class scores
# (first output) are computed from the image alone.
capsule_predictions = capsule_model.predict([x_test, y_test])
capsule_y_pred_prob = capsule_predictions[0]  # First output: capsule lengths used as class scores
capsule_y_pred = np.argmax(capsule_y_pred_prob, axis=1)
# Calculate performance metrics
def calculate_metrics(y_true, y_pred):
    """Score predicted class labels against the true labels.

    Args:
        y_true: True class labels.
        y_pred: Predicted class labels.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)`` where precision, recall
        and F1 are macro-averaged over the classes.
    """
    accuracy = accuracy_score(y_true, y_pred)
    macro_scores = [
        scorer(y_true, y_pred, average='macro')
        for scorer in (precision_score, recall_score, f1_score)
    ]
    return (accuracy, *macro_scores)

# Score both models on the integer test labels.
ae_cnn_metrics = calculate_metrics(y_test_orig, ae_cnn_y_pred)
capsule_metrics = calculate_metrics(y_test_orig, capsule_y_pred)

# Store metrics in a dictionary keyed by model name.
metrics = {
    model_name: {
        'Test Accuracy': accuracy,
        'Precision (Macro)': precision,
        'Recall (Macro)': recall,
        'F1 Score (Macro)': f1,
    }
    for model_name, (accuracy, precision, recall, f1) in (
        ('AE-CNN', ae_cnn_metrics),
        ('Capsule Network', capsule_metrics),
    )
}

# Display metrics as a DataFrame: one row per model, values as percentages.
metrics_df = (
    pd.DataFrame.from_dict(metrics, orient='index') * 100
).set_axis(
    ['Test Accuracy (%)', 'Precision (Macro) (%)', 'Recall (Macro) (%)', 'F1 Score (Macro) (%)'],
    axis=1,
)
print("\nOverall Performance Metrics:")
print(metrics_df)

Training the AE-CNN model...
Epoch 1/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 7ms/step - accuracy: 0.8825 - loss: 0.3632 - val_accuracy: 0.9870 - val_loss: 0.0467
Epoch 2/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - accuracy: 0.9870 - loss: 0.0428 - val_accuracy: 0.9889 - val_loss: 0.0363
Epoch 3/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - accuracy: 0.9917 - loss: 0.0266 - val_accuracy: 0.9907 - val_loss: 0.0329
Epoch 4/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.9934 - loss: 0.0206 - val_accuracy: 0.9894 - val_loss: 0.0375
Epoch 5/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - accuracy: 0.9942 - loss: 0.0175 - val_accuracy: 0.9913 - val_loss: 0.0279
Epoch 6/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.9957 - loss: 0.0137 - val_accuracy: 0.9904 - val_loss: 0.0332

**Alternative code for comparison**

In [9]:
import tensorflow as tf
from tensorflow.keras import layers, models, Model
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import pandas as pd

# Define model parameters globally for easy access
# input_shape is the length of a flattened image vector (used by the Input
# layers and the decoder output); num_classes is the number of label
# categories (classifier width and number of output capsules).
input_shape = 784  # 28*28
num_classes = 10

# Squash function for Capsule Networks
def squash(vectors, axis=-1):
    """Rescale capsule vectors to a length in [0, 1) without changing direction.

    The length of each vector along `axis` is mapped through
    ||s||^2 / (1 + ||s||^2); the 1e-8 inside the square root prevents a
    division by zero for all-zero vectors.
    """
    length_sq = tf.reduce_sum(tf.square(vectors), axis=axis, keepdims=True)
    squash_factor = length_sq / (1 + length_sq)
    safe_length = tf.sqrt(length_sq + 1e-8)
    return (squash_factor / safe_length) * vectors

# CapsuleLayer
class CapsuleLayer(layers.Layer):
    """Fully connected capsule layer with routing-by-agreement.

    Maps [batch_size, input_num_capsule, input_dim_capsule] to
    [batch_size, num_capsule, dim_capsule] using `routings` (>= 1) routing
    iterations.
    """

    def __init__(self, num_capsule, dim_capsule, routings=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        if routings < 1:
            # Zero iterations would leave the output undefined in call().
            raise ValueError(f"routings must be >= 1, got {routings}")
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.W = None

    def build(self, input_shape):
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]
        # Shape: [input_num_capsule, num_capsule, input_dim_capsule, dim_capsule]
        self.W = self.add_weight(
            shape=[self.input_num_capsule, self.num_capsule, self.input_dim_capsule, self.dim_capsule],
            initializer='glorot_uniform',
            name='capsule_weights')
        self.built = True

    def call(self, inputs):
        # u_hat[b, i, n, d] = sum_c inputs[b, i, c] * W[i, n, c, d].
        # einsum contracts directly, instead of tiling both operands to
        # [batch, input_num_capsule, num_capsule, input_dim_capsule, dim_capsule].
        u_hat = tf.einsum('bic,incd->bind', inputs, tf.convert_to_tensor(self.W))

        # Routing logits: [batch_size, input_num_capsule, num_capsule, 1]
        b = tf.zeros_like(u_hat[..., :1])
        for i in range(self.routings):
            c = tf.nn.softmax(b, axis=2)            # coupling over output capsules
            s = tf.reduce_sum(c * u_hat, axis=1)    # [batch_size, num_capsule, dim_capsule]
            v = squash(s)
            if i < self.routings - 1:
                # Raise the logit of every prediction that agrees with v.
                agreement = tf.reduce_sum(tf.expand_dims(v, 1) * u_hat, -1, keepdims=True)
                b = b + agreement
        return v

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsule, self.dim_capsule)

    def get_config(self):
        config = super(CapsuleLayer, self).get_config()
        config.update({
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
        })
        return config

# Custom mask function
def mask(inputs):
    """Keep only the capsule selected by the one-hot label and flatten the result.

    Args:
        inputs: ``[capsule_output, y]`` with capsule_output shaped
            [batch_size, num_capsule, dim_capsule] and y a one-hot tensor
            shaped [batch_size, num_capsule].

    Returns:
        Tensor shaped [batch_size, num_capsule * dim_capsule].
    """
    capsule_output = inputs[0]
    y = inputs[1]

    # Broadcast the label over each capsule's components.
    mask_expanded = tf.expand_dims(y, -1)
    masked = capsule_output * mask_expanded

    # Flattened width comes from the tensor rather than a fixed capsule size
    # of 16; use the dynamic shape only when the static one is unknown.
    static_caps, static_dim = masked.shape[1], masked.shape[2]
    if static_caps is not None and static_dim is not None:
        flat_width = static_caps * static_dim
    else:
        flat_width = tf.reduce_prod(tf.shape(masked)[1:])

    masked_flattened = tf.reshape(masked, [-1, flat_width])
    return masked_flattened

# Function to build the Auto-Encoder CNN model
def build_ae_cnn_model():
    """Build and compile the CNN classifier for flattened 28x28 images.

    Encoder: Reshape to 28x28x1, three Conv2D(3x3, relu) + MaxPooling2D
    stages with 32/64/128 filters, Flatten, Dense(256). The softmax head over
    `num_classes` is the model's single output.

    Note: no decoder branch is created because only the classifier head is
    exposed as a model output; unconnected layers would never be trained.

    Returns:
        A compiled Keras model (Adam lr=0.001, categorical cross-entropy,
        accuracy metric).
    """
    inputs = layers.Input(shape=(input_shape,))
    x = layers.Reshape((28, 28, 1))(inputs)

    # Encoder stages; 'same' pooling takes the spatial size 28 -> 14 -> 7 -> 4.
    for filters in (32, 64, 128):
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.MaxPooling2D((2, 2), padding='same')(x)

    x = layers.Flatten()(x)
    encoded = layers.Dense(256, activation='relu')(x)
    classifier = layers.Dense(num_classes, activation='softmax')(encoded)

    model = Model(inputs=inputs, outputs=classifier)
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model

# Function to build the Capsule Network model
def build_capsule_network():
    """Build and compile the capsule network with a reconstruction decoder.

    Inputs are a flattened image of shape (input_shape,) and a one-hot label
    of shape (num_classes,); the label is used only to mask the capsules
    passed to the decoder. Outputs are the capsule lengths (class scores)
    and the reconstructed flattened image.

    The class output uses sparse categorical cross-entropy, so it expects
    integer labels as its target; the reconstruction uses MSE weighted by
    0.0005.

    Returns:
        A compiled two-input / two-output Keras model.
    """
    def _capsule_length(capsules):
        # Epsilon keeps the sqrt gradient finite for an all-zero capsule.
        return tf.sqrt(tf.reduce_sum(tf.square(capsules), -1) + 1e-7)

    x_input = layers.Input(shape=(input_shape,))
    y_input = layers.Input(shape=(num_classes,))
    x_reshaped = layers.Reshape((28, 28, 1))(x_input)

    # Convolutional features, then primary capsules as 8-D squashed vectors.
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu')(x_reshaped)
    primarycaps = layers.Conv2D(filters=32*8, kernel_size=9, strides=2, padding='valid')(conv1)
    primarycaps_reshaped = layers.Reshape((-1, 8))(primarycaps)
    primarycaps_squashed = layers.Lambda(lambda x: squash(x))(primarycaps_reshaped)

    # One 16-D output capsule per class.
    digitcaps = CapsuleLayer(num_capsule=num_classes, dim_capsule=16, routings=3, name='digitcaps')(primarycaps_squashed)
    out_caps = layers.Lambda(_capsule_length)(digitcaps)

    # Decoder reconstructs the image from the label-masked capsules.
    masked = layers.Lambda(lambda x: mask(x))([digitcaps, y_input])
    decoder = layers.Dense(512, activation='relu')(masked)
    decoder = layers.Dense(1024, activation='relu')(decoder)
    decoder = layers.Dense(input_shape, activation='sigmoid')(decoder)

    model = Model([x_input, y_input], [out_caps, decoder])
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss=['sparse_categorical_crossentropy', 'mse'],
        loss_weights=[1.0, 0.0005],
        metrics=[['accuracy'], ['mse']]
    )
    return model

if __name__ == "__main__":
    # Load and preprocess Fashion MNIST data
    print("Loading Fashion MNIST dataset...")
    (x_train, y_train_orig), (x_test, y_test_orig) = tf.keras.datasets.fashion_mnist.load_data()
    class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    print(f"Dataset shapes: x_train: {x_train.shape}, y_train: {y_train_orig.shape}, x_test: {x_test.shape}, y_test: {y_test_orig.shape}")
    # Convert the images to float32, divide by 255, and flatten each one into
    # a 784-element vector (both models reshape back to 28x28x1 internally).
    x_train = x_train.astype('float32') / 255.
    x_test = x_test.astype('float32') / 255.
    x_train = x_train.reshape(-1, 784)
    x_test = x_test.reshape(-1, 784)
    # One-hot encode the labels over 10 classes; the *_orig arrays keep the
    # integer labels.
    y_train = tf.keras.utils.to_categorical(y_train_orig, 10)
    y_test = tf.keras.utils.to_categorical(y_test_orig, 10)
    # Hold out the last `val_size` training samples for validation. No
    # shuffling is done here; the split follows the order returned by
    # load_data(). The validation slices must be taken BEFORE the training
    # arrays are truncated below.
    val_size = 10000
    x_val = x_train[-val_size:]
    y_val = y_train[-val_size:]
    y_val_orig = y_train_orig[-val_size:]
    x_train = x_train[:-val_size]
    y_train = y_train[:-val_size]
    y_train_orig = y_train_orig[:-val_size]
    # Save a 5x5 grid of the first 25 training images, each captioned with
    # its class name, to 'fashion_mnist_examples.png'.
    plt.figure(figsize=(10, 10))
    for i in range(25):
        plt.subplot(5, 5, i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(x_train[i].reshape(28, 28), cmap=plt.cm.binary)
        plt.xlabel(class_names[y_train_orig[i]])
    plt.tight_layout()
    plt.savefig('fashion_mnist_examples.png')
    plt.close()
    # Build both models and print their layer summaries.
    print("Building AE-CNN model...")
    ae_cnn_model = build_ae_cnn_model()
    ae_cnn_model.summary()
    print("Building Capsule Network model...")
    capsule_model = build_capsule_network()
    capsule_model.summary()
    # Early stopping on val_loss with a patience of 5 epochs and
    # restore_best_weights=True.
    # NOTE(review): this single callback instance is passed to both fit()
    # calls below -- confirm that sharing it between the two runs is intended.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True
    )
    # AE-CNN: inputs are the flattened images, targets the one-hot labels.
    print("Training the AE-CNN model...")
    ae_cnn_history = ae_cnn_model.fit(
        x_train, y_train,
        batch_size=64,
        epochs=20,
        validation_data=(x_val, y_val),
        callbacks=[early_stopping],
        verbose=1
    )
    # Capsule network: inputs are [images, one-hot labels]; targets are
    # [integer labels, images] for its two outputs (class-capsule lengths and
    # the reconstruction).
    print("Training the Capsule Network model...")
    capsule_history = capsule_model.fit(
        [x_train, y_train],
        [y_train_orig, x_train],
        batch_size=64,
        epochs=20,
        validation_data=([x_val, y_val], [y_val_orig, x_val]),
        callbacks=[early_stopping],
        verbose=1
    )
    def plot_training_history(ae_history, caps_history):
        """Plot accuracy and loss curves of both models side by side.

        The figure is written to 'training_history.png' and then closed.

        Args:
            ae_history: Keras History of the AE-CNN run; its ``history`` dict
                must contain 'accuracy', 'val_accuracy', 'loss', 'val_loss'.
            caps_history: Keras History of the capsule-network run; its
                ``history`` dict must contain 'loss' and 'val_loss'. The
                accuracy keys are looked up in this order: 'accuracy',
                'output_1_accuracy', then the first key containing
                'accuracy' (with and without 'val'). If none is found a
                warning is printed and the CapsNet accuracy curves are
                skipped.
        """
        # Add debug print to see available keys
        ae_log = ae_history.history
        print("AE-CNN history keys:", list(ae_log.keys()))
        caps_log = caps_history.history
        print("CapsNet history keys:", list(caps_log.keys()))

        plt.figure(figsize=(15, 5))

        # ---- Left panel: accuracy ----
        plt.subplot(1, 2, 1)
        plt.plot(ae_log['accuracy'], label='AE-CNN Training')
        plt.plot(ae_log['val_accuracy'], label='AE-CNN Validation')

        # Resolve the (training, validation) accuracy key pair for CapsNet.
        if 'accuracy' in caps_log:
            # Option 1: Direct accuracy key
            caps_acc_keys = ('accuracy', 'val_accuracy')
        elif 'output_1_accuracy' in caps_log:
            # Option 2: Output-specific accuracy key
            caps_acc_keys = ('output_1_accuracy', 'val_output_1_accuracy')
        else:
            # Option 3: any key mentioning 'accuracy', split on 'val'
            train_candidates = [key for key in caps_log
                                if 'accuracy' in key and 'val' not in key]
            val_candidates = [key for key in caps_log
                              if 'accuracy' in key and 'val' in key]
            if train_candidates and val_candidates:
                caps_acc_keys = (train_candidates[0], val_candidates[0])
            else:
                caps_acc_keys = None

        if caps_acc_keys is None:
            # Skip plotting accuracy for CapsNet if no suitable keys exist
            print("WARNING: Could not find accuracy metrics for CapsNet")
        else:
            train_key, val_key = caps_acc_keys
            plt.plot(caps_log[train_key], label='CapsNet Training')
            plt.plot(caps_log[val_key], label='CapsNet Validation')

        plt.title('Model Accuracy')
        plt.ylabel('Accuracy')
        plt.xlabel('Epoch')
        plt.legend(loc='lower right')

        # ---- Right panel: loss ----
        plt.subplot(1, 2, 2)
        for log, model_label in ((ae_log, 'AE-CNN'), (caps_log, 'CapsNet')):
            plt.plot(log['loss'], label=f'{model_label} Training')
            plt.plot(log['val_loss'], label=f'{model_label} Validation')
        plt.title('Model Loss')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.legend(loc='upper right')

        plt.tight_layout()
        plt.savefig('training_history.png')
        plt.close()

    plot_training_history(ae_cnn_history, capsule_history)
    print("\nEvaluating AE-CNN model on test set...")
    ae_cnn_test_loss, ae_cnn_test_acc = ae_cnn_model.evaluate(x_test, y_test, verbose=1)
    print(f"AE-CNN Test accuracy: {ae_cnn_test_acc:.4f}")
    print("\nEvaluating Capsule Network model on test set...")
    # Ask for a name -> value dict: the position of the accuracy entry in the
    # plain list returned by evaluate() depends on whether the installed
    # Keras version reports per-output losses, so a fixed index is fragile.
    capsule_test_metrics = capsule_model.evaluate(
        [x_test, y_test],
        [y_test_orig, x_test],
        verbose=1,
        return_dict=True
    )
    # Keep a list view under the original name for any later code that
    # still expects it.
    capsule_test_results = list(capsule_test_metrics.values())
    capsule_test_loss = capsule_test_metrics['loss']
    capsule_acc_keys = [key for key in capsule_test_metrics if 'accuracy' in key]
    if not capsule_acc_keys:
        raise KeyError(
            "No accuracy metric in capsule evaluation results; "
            f"available keys: {list(capsule_test_metrics)}"
        )
    capsule_test_acc = capsule_test_metrics[capsule_acc_keys[0]]
    print(f"Capsule Network Test accuracy: {capsule_test_acc:.4f}")
    print("Generating predictions...")
    ae_cnn_y_pred_prob = ae_cnn_model.predict(x_test)
    ae_cnn_y_pred = np.argmax(ae_cnn_y_pred_prob, axis=1)
    # The one-hot test labels are passed as the model's second input; in
    # build_capsule_network that input feeds only the reconstruction mask,
    # while the class scores (first output) are computed from the image alone.
    capsule_predictions = capsule_model.predict([x_test, y_test])
    capsule_y_pred_prob = capsule_predictions[0]
    capsule_y_pred = np.argmax(capsule_y_pred_prob, axis=1)
    def calculate_metrics(y_true, y_pred):
        """Return (accuracy, macro precision, macro recall, macro F1) as a tuple.

        Args:
            y_true: Ground-truth class labels.
            y_pred: Predicted class labels, aligned with ``y_true``.
        """
        accuracy = accuracy_score(y_true, y_pred)
        # Same call signature for the three macro-averaged scores.
        macro_scores = [
            scorer(y_true, y_pred, average='macro')
            for scorer in (precision_score, recall_score, f1_score)
        ]
        return (accuracy, *macro_scores)
    # Row labels in the same order as the tuple returned by calculate_metrics.
    metric_labels = ('Test Accuracy', 'Precision (Macro)', 'Recall (Macro)', 'F1 Score (Macro)')
    ae_cnn_metrics = calculate_metrics(y_test_orig, ae_cnn_y_pred)
    capsule_metrics = calculate_metrics(y_test_orig, capsule_y_pred)
    metrics = {
        'AE-CNN': dict(zip(metric_labels, ae_cnn_metrics)),
        'Capsule Network': dict(zip(metric_labels, capsule_metrics)),
    }
    # One row per model, one column per metric, expressed in percent.
    metrics_df = pd.DataFrame(metrics).T * 100
    metrics_df.columns = [f'{label} (%)' for label in metric_labels]
    print("\nOverall Performance Metrics:")
    print(metrics_df)
    metrics_df.to_csv('model_metrics.csv')
    # Per-class report for each model.
    for report_heading, model_predictions in (
        ("\nAE-CNN Classification Report:", ae_cnn_y_pred),
        ("\nCapsule Network Classification Report:", capsule_y_pred),
    ):
        print(report_heading)
        print(classification_report(y_test_orig, model_predictions, target_names=class_names))
    def plot_confusion_matrix(y_true, y_pred, class_names, title, filename):
        """Draw a confusion-matrix heatmap and save it to ``filename``.

        Args:
            y_true: Ground-truth integer class labels.
            y_pred: Predicted integer class labels.
            class_names: Display names; label ``i`` is shown as
                ``class_names[i]``.
            title: Figure title.
            filename: Path the figure is written to.
        """
        # Pin the label set so the matrix always has one row/column per entry
        # of class_names. Without it, a class missing from both y_true and
        # y_pred would shrink the matrix and misalign the tick labels.
        label_ids = list(range(len(class_names)))
        cm = confusion_matrix(y_true, y_pred, labels=label_ids)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
        plt.title(title)
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()
    plot_confusion_matrix(y_test_orig, ae_cnn_y_pred, class_names, 'AE-CNN Confusion Matrix', 'ae_cnn_confusion_matrix.png')
    plot_confusion_matrix(y_test_orig, capsule_y_pred, class_names, 'Capsule Network Confusion Matrix', 'capsule_confusion_matrix.png')
    def plot_predictions(x_test, y_test, y_pred, class_names, title, filename):
        """Save a 2x5 grid of sample predictions to ``filename``.

        The top row shows up to five correctly classified samples, the bottom
        row up to five misclassified ones, each titled with its true and
        predicted class name.

        Args:
            x_test: Images; each entry is reshaped to 28x28 for display.
            y_test: Ground-truth integer labels.
            y_pred: Predicted integer labels.
            class_names: Display names indexed by label.
            title: Figure super-title.
            filename: Path the figure is written to.
        """
        # Slicing already copes with fewer than five matches.
        hits = np.where(y_test == y_pred)[0][:5]
        misses = np.where(y_test != y_pred)[0][:5]

        plt.figure(figsize=(15, 10))
        # Subplot numbering starts at 1 for the top row and 6 for the bottom.
        for first_slot, sample_ids in ((1, hits), (6, misses)):
            for column, idx in enumerate(sample_ids):
                plt.subplot(2, 5, column + first_slot)
                plt.imshow(x_test[idx].reshape(28, 28), cmap='gray')
                plt.title(f"True: {class_names[y_test[idx]]}\nPred: {class_names[y_pred[idx]]}")
                plt.axis('off')
        plt.suptitle(title)
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()
    plot_predictions(x_test, y_test_orig, ae_cnn_y_pred, class_names, 'AE-CNN Predictions', 'ae_cnn_predictions.png')
    plot_predictions(x_test, y_test_orig, capsule_y_pred, class_names, 'Capsule Network Predictions', 'capsule_predictions.png')
    def plot_reconstructions(model, x_test, y_test, y_test_orig, class_names):
        [_, reconstructions] = model.predict([x_test[:10], y_test[:10]])
        plt.figure(figsize=(20, 4))
        for i in range(10):
            ax = plt.subplot(2, 10, i + 1)
            plt.imshow(x_test[i].reshape(28, 28), cmap='gray')
            plt.title(f"{class_names[y_test_orig[i]]}")
            plt.gray()
            ax.set_axis_off()
            ax = plt.subplot(2, 10, i + 11)
            plt.imshow(reconstructions[i].reshape(28, 28), cmap='gray')
            plt

Loading Fashion MNIST dataset...
Dataset shapes: x_train: (60000, 28, 28), y_train: (60000,), x_test: (10000, 28, 28), y_test: (10000,)
Building AE-CNN model...


Building Capsule Network model...


Training the AE-CNN model...
Epoch 1/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 7ms/step - accuracy: 0.7534 - loss: 0.6738 - val_accuracy: 0.8744 - val_loss: 0.3353
Epoch 2/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 5ms/step - accuracy: 0.8905 - loss: 0.2990 - val_accuracy: 0.9001 - val_loss: 0.2767
Epoch 3/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 6ms/step - accuracy: 0.9085 - loss: 0.2460 - val_accuracy: 0.9110 - val_loss: 0.2422
Epoch 4/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4ms/step - accuracy: 0.9197 - loss: 0.2124 - val_accuracy: 0.9145 - val_loss: 0.2333
Epoch 5/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - accuracy: 0.9332 - loss: 0.1810 - val_accuracy: 0.9223 - val_loss: 0.2145
Epoch 6/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 5ms/step - accuracy: 0.9423 - loss: 0.1573 - val_accuracy: 0.9225 - val_loss: 0.2161