In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy
import copy
import os
from PIL import Image
import torch
import torchvision
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
from random import random
from torch.utils.data import Dataset, DataLoader, TensorDataset
import traceback
import pandas as pd

# Gating router

In [None]:
class gate(tf.keras.layers.Layer):
    """Top-k spatial gating layer (router).

    A single-output-channel convolution scores every spatial position of a
    4-D input. Per sample, only the `k` highest scores are kept (all other
    positions become zero) and the resulting sparse score map is multiplied
    element-wise with the input. The layer returns both the gated input and
    the indices of the selected positions.

    NOTE(review): `call` transposes the conv output with a hard-coded
    channels-last permutation, so `data_format='channels_first'` does not look
    supported even though `build` handles it -- confirm before relying on it.
    """
    def __init__(self, k, gating_kernel_size, strides=(1,1), padding = 'valid',
                 data_format = 'channels_last', gating_activation = None,
                 gating_kernel_initializer = tf.keras.initializers.RandomNormal, **kwargs):
        """Store the gating configuration.

        Args:
            k: number of positions kept per sample by `tf.math.top_k`.
            gating_kernel_size: tuple (kh, kw); it is concatenated with
                `(input_dim, 1)` in `build`, so it must be a tuple.
            strides, padding, data_format: forwarded to the gating convolution.
            gating_activation: anything accepted by `tf.keras.activations.get`;
                applied to the flattened score map before top-k selection.
            gating_kernel_initializer: initializer passed to `add_weight`.
                NOTE(review): the default is the `RandomNormal` class itself,
                not an instance; presumably Keras instantiates it -- confirm.
            **kwargs: forwarded to `tf.keras.layers.Layer`.
        """

        super(gate, self).__init__(**kwargs)
        self.k = k
        self.gating_kernel_size = gating_kernel_size
        self.strides = strides
        self.padding = padding
        self.data_format = data_format
        self.gating_activation = tf.keras.activations.get(gating_activation)
        self.gating_kernel_initializer = gating_kernel_initializer
        # Only rank-4 inputs are accepted.
        self.input_spec = tf.keras.layers.InputSpec(ndim=4)

    def build(self, input_shape):
        """Create the gating kernel of shape `gating_kernel_size + (input_dim, 1)`.

        Raises:
            ValueError: if the channel dimension of `input_shape` is `None`.
        """
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1

        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')

        input_dim = input_shape[channel_axis]
        # Last kernel dimension is 1: the convolution emits one score per position.
        gating_kernel_shape = self.gating_kernel_size + (input_dim, 1)
        self.gating_kernel = self.add_weight(shape=gating_kernel_shape,
                                      initializer=self.gating_kernel_initializer,
                                      name='gating_kernel')

    def call(self, inputs):
        """Gate `inputs` with a top-k sparse score map.

        Returns:
            A tuple `(outputs, indices)`: `outputs` is `inputs` multiplied by
            the expanded sparse score map; `indices` holds the flat
            (row-major over the x*y score grid) positions chosen by top-k,
            reshaped to `(batch * score_channels, k)`.
        """

        # Score map from the single-filter convolution.
        gating_outputs = tf.keras.backend.conv2d(inputs, self.gating_kernel, strides=self.strides,
                                  padding=self.padding,data_format=self.data_format)

        # Move the last axis to position 1 and flatten the two trailing axes:
        # (b, x, y, c) -> (b, c, x, y) -> (b, c, x*y), assuming a channels-last conv output.
        gating_outputs = tf.transpose(gating_outputs, perm=(0,3,1,2))
        x = tf.shape(gating_outputs)[2]
        y = tf.shape(gating_outputs)[3]
        gating_outputs = tf.reshape(gating_outputs,(tf.shape(gating_outputs)[0],tf.shape(gating_outputs)[1],
                                                    x*y))

        gating_outputs = self.gating_activation(gating_outputs)
        # print("gating output: ", gating_outputs.shape)
        # Keep the k largest scores along the flattened spatial axis.
        [values, indices] = tf.math.top_k(gating_outputs,k=self.k, sorted=False)
        # print("value output: ", values.shape)
        # print("indice before output: ", indices.shape)
        # Collapse the first two axes so each row is one (sample, channel) pair.
        indices = tf.reshape(indices,(tf.shape(indices)[0]*tf.shape(indices)[1],tf.shape(indices)[2]))
        # print("indice after output: ", indices.shape)
        values = tf.reshape(values, (tf.shape(values)[0]*tf.shape(values)[1], tf.shape(values)[2]))
        batch_t, k_t = tf.unstack(tf.shape(indices), num=2)

        # Number of candidate positions per row (x*y).
        n=tf.shape(gating_outputs)[2]

        # Scatter the kept scores back into a dense (batch_t, n) map: each row's
        # indices are offset by row_number * n so they address a flat vector,
        # and unsorted_segment_sum leaves every unselected slot at zero.
        indices_flat = tf.reshape(indices, [-1]) + tf.math.floordiv(tf.range(batch_t * k_t), k_t) * n
        ret_flat = tf.math.unsorted_segment_sum(tf.reshape(values, [-1]), indices_flat, batch_t * n)
        ret_rsh=tf.reshape(ret_flat, [batch_t, n])
        ret_rsh_3=tf.reshape(ret_rsh,(tf.shape(gating_outputs)[0],tf.shape(gating_outputs)[1],tf.shape(gating_outputs)[2]))

        # Restore the 2-D grid and move the score channel back to the last axis.
        new_gating_outputs = tf.reshape(ret_rsh_3,(tf.shape(ret_rsh_3)[0],tf.shape(ret_rsh_3)[1],x,y))
        new_gating_outputs = tf.transpose(new_gating_outputs, perm=(0,2,3,1))
        # Replicate each score kh*kw*input_dim times, then rearrange
        # (b, x, y, kh, kw, cin) -> (b, x, kh, y, kw, cin) -> (b, x*kh, y*kw, cin)
        # so every score covers a kh-by-kw tile across all input channels.
        new_gating_outputs = tf.repeat(new_gating_outputs,tf.shape(self.gating_kernel)[0]*tf.shape(self.gating_kernel)[1]*tf.shape(self.gating_kernel)[2],axis=3)
        new_gating_outputs=tf.reshape(new_gating_outputs,(tf.shape(new_gating_outputs)[0],tf.shape(new_gating_outputs)[1],tf.shape(new_gating_outputs)[2],tf.shape(self.gating_kernel)[0],tf.shape(self.gating_kernel)[1],tf.shape(self.gating_kernel)[2]))
        new_gating_outputs=tf.transpose(new_gating_outputs,perm=(0,1,3,2,4,5))
        new_gating_outputs=tf.reshape(new_gating_outputs,(tf.shape(new_gating_outputs)[0],tf.shape(new_gating_outputs)[1]*tf.shape(new_gating_outputs)[2],tf.shape(new_gating_outputs)[3]*tf.shape(new_gating_outputs)[4],tf.shape(new_gating_outputs)[5]))
        # NOTE(review): the product needs (x*kh, y*kw) to match the input's spatial
        # size, which presumably requires strides == kernel size with 'valid'
        # padding -- confirm for other configurations.
        outputs = inputs*new_gating_outputs
        return outputs, indices

# Wide Residual Network

In [None]:
# Initializer handed to every gating kernel below: zero-mean normal draws
# with a very small standard deviation.
initializer_gate = keras.initializers.RandomNormal(mean=0.0, stddev=1e-4)

def WideResnetBlock(x, channels, strides, channel_mismatch=False):
    """Residual block made of two BatchNorm -> ReLU -> 3x3 Conv stages.

    Args:
        x: input Keras tensor.
        channels: number of filters of both 3x3 convolutions (and of the
            optional shortcut projection).
        strides: strides of the first 3x3 convolution and of the shortcut
            projection; the second 3x3 convolution always uses stride 1.
        channel_mismatch: anything other than the literal `False` makes the
            shortcut go through a 1x1 convolution so it can be added to the
            main path.

    Returns:
        The sum of the shortcut and the main path.
    """

    def _bn_relu_conv3x3(tensor, conv_strides):
        # One pre-activation stage; layers are created in BN, ReLU, Conv order.
        tensor = layers.BatchNormalization()(tensor)
        tensor = layers.ReLU()(tensor)
        return layers.Conv2D(filters=channels, kernel_size=3, strides=conv_strides, padding='same')(tensor)

    shortcut = x

    residual = _bn_relu_conv3x3(x, strides)
    residual = _bn_relu_conv3x3(residual, 1)

    # Identity comparison on purpose: only the literal False skips the projection.
    if channel_mismatch is not False:
        shortcut = layers.Conv2D(filters=channels, kernel_size=1, strides=strides, padding='valid')(shortcut)

    return layers.Add()([shortcut, residual])

def WideResnetGroup(x, num_blocks, channels, strides):
    """Stack `num_blocks` residual blocks.

    The first block receives `strides` and always projects its shortcut
    (`channel_mismatch=True`); every following block uses stride (1, 1) and
    a plain identity shortcut.
    """
    out = WideResnetBlock(x=x, channels=channels, strides=strides, channel_mismatch=True)

    # Blocks 2..num_blocks keep the resolution and width unchanged.
    for _block in range(1, num_blocks):
        out = WideResnetBlock(x=out, channels=channels, strides=(1, 1))

    return out

def WideResnet(x, num_blocks, k, num_classes=10):
    """Wide ResNet trunk followed by four gated expert branches and a classifier.

    Args:
        x: input Keras tensor (image batch).
        num_blocks: residual blocks per group.
        k: widening factor applied to the base widths (16, 32, 64).
        num_classes: units of the final softmax layer.

    Returns:
        The class-probability tensor.

    Note: only `widths[0]` and `widths[1]` are consumed; the convolution that
    feeds the experts uses a fixed 640 filters and each expert a fixed 160.
    """
    widths = [int(base * k) for base in (16, 32, 64)]

    # Trunk: stem convolution, two residual groups, then a strided convolution.
    trunk = layers.Conv2D(filters=16, kernel_size=3, strides=1, padding='same')(x)
    trunk = WideResnetGroup(trunk, num_blocks, widths[0], strides=(1, 1))
    trunk = WideResnetGroup(trunk, num_blocks, widths[1], strides=(2, 2))
    trunk = layers.BatchNormalization()(trunk)
    trunk = layers.ReLU()(trunk)
    trunk = layers.Conv2D(filters=640, kernel_size=3, strides=2, padding='same')(trunk)

    # Four routers over the same trunk output. Layers are created stage by
    # stage (all gates, then all BNs, ...) to keep the original creation order.
    routed = [
        gate(16, (1, 1), (1, 1), gating_activation=tf.nn.softmax,
             gating_kernel_initializer=initializer_gate)(trunk)
        for _ in range(4)
    ]
    branches = [gated for gated, _selected in routed]
    branches = [layers.BatchNormalization()(branch) for branch in branches]
    branches = [layers.ReLU()(branch) for branch in branches]
    branches = [
        layers.Conv2D(filters=160, kernel_size=1, strides=1, padding='same')(branch)
        for branch in branches
    ]

    # Merge the experts and classify.
    merged = tf.keras.layers.concatenate(branches)
    merged = layers.BatchNormalization()(merged)
    merged = layers.ReLU()(merged)
    merged = layers.AveragePooling2D((8,8))(merged)
    merged = layers.Flatten()(merged)
    return layers.Dense(units=num_classes, activation='softmax')(merged)

### Choosing the correct dataset

In [None]:
# Dataset switch read by the training cell, which tests `CIFAR10 is True`:
# True trains on CIFAR-10; any other value selects the GTSRB dataset.
CIFAR10 = True

# Clean training

In [None]:
for s in [50000]:
    # ------------------------------------------------------------------
    # Load the selected dataset. Only the file names, the class count and
    # the CIFAR-10 subsampling differ; the rest of the pipeline is shared.
    # ------------------------------------------------------------------
    if CIFAR10 is True:
        print("CIFAR-10 as dataset")
        num_classes = 10
        training_data_all = np.load('cifar_10_train_data_sorted.npy')
        training_label_all = np.load('cifar_10_train_label_sorted.npy')
        testing_data = np.load('cifar_10_test_data_sorted.npy')
        testing_label = np.load('cifar_10_test_label_sorted.npy')

        # Take the first s//10 samples from each of ten consecutive
        # 5000-sample chunks (presumably one chunk per class, since the
        # files are named "sorted" -- verify against the data files).
        per_class = s // 10
        class_starts = range(0, 50000, 5000)
        training_data = np.concatenate(
            [training_data_all[start:start + per_class] for start in class_starts], axis=0)
        training_label = np.concatenate(
            [training_label_all[start:start + per_class] for start in class_starts], axis=0)
    else:
        print("GTSRB as dataset")
        num_classes = 43
        training_data = np.load('gtsrb_train_data_sorted.npy')
        training_label = np.load('gtsrb_train_label_sorted.npy')
        testing_data = np.load('gtsrb_test_data_sorted.npy')
        testing_label = np.load('gtsrb_test_label_sorted.npy')

    # 1-of-K encoding. The reshape target is derived from the label arrays
    # themselves instead of hard-coded sample counts.
    training_label = tf.reshape(
        tf.one_hot(training_label, axis=1, depth=num_classes, dtype=tf.float64),
        (len(training_label), num_classes)).numpy()
    testing_label = tf.reshape(
        tf.one_hot(testing_label, axis=1, depth=num_classes, dtype=tf.float64),
        (len(testing_label), num_classes)).numpy()

    # Shuffle the training set (data and labels with the same permutation).
    indices = tf.range(start=0, limit=tf.shape(training_data)[0], dtype=tf.int32)
    shuffled_indices = tf.random.shuffle(indices)
    training_data = tf.gather(training_data, shuffled_indices, axis=0)
    training_label = tf.gather(training_label, shuffled_indices, axis=0)

    # Scale pixel values by 1/255 and cast to float32.
    training_data = training_data/255
    training_data = tf.cast(training_data, dtype=tf.dtypes.float32)
    testing_data = testing_data/255
    testing_data = tf.cast(testing_data, dtype=tf.dtypes.float32)

    # Single test image whose routing is tracked after every epoch.
    image_patch = testing_data[1531:1532]
    plt.figure(figsize=(3, 3))
    plt.imshow(image_patch[0], cmap='gray')
    plt.axis('off')
    plt.title("Crafted Trigger")
    plt.show()

    for i in range(1):
        # Create and compile the model.
        model_input = tf.keras.Input(
            shape=(training_data.shape[1], training_data.shape[2], training_data.shape[3]))
        model_output = WideResnet(model_input, num_blocks=1, k=10, num_classes=num_classes)
        model = tf.keras.Model(model_input, model_output)
        model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
                      loss='categorical_crossentropy',
                      metrics=['categorical_accuracy'])

        # State filled by the callbacks.
        z = []                  # one model.evaluate() result per epoch
        weights_dict = {}       # epoch -> full weight snapshot (memory grows with epochs)
        patch_assignments = []  # one list of per-expert routing records per epoch

        def capture_patch_assignments(epoch, logs):
            """Record which positions each gate selects for `image_patch`.

            Best effort: any failure is reported and training continues.
            """
            try:
                # Second output of every gate layer = selected indices.
                intermediate_model = tf.keras.Model(
                    inputs=model.input,
                    outputs=[layer.output[1] for layer in model.layers if isinstance(layer, gate)]
                )
                print("Intermediate model outputs:", intermediate_model.outputs)

                patch_indices = np.array(intermediate_model.predict(image_patch, batch_size=128))

                epoch_assignments = []
                for expert_idx, expert_indices in enumerate(patch_indices):
                    flattened_indices = expert_indices.flatten()
                    # Flat index -> (row, col) on a grid that is 8 positions wide.
                    grid_coordinates = [(idx // 8, idx % 8) for idx in flattened_indices]
                    epoch_assignments.append({
                        "expert": expert_idx + 1,
                        "flat_indices": flattened_indices,
                        "grid_coordinates": grid_coordinates
                    })

                patch_assignments.append(epoch_assignments)
                print(f"Epoch {epoch+1}: Captured patch assignments.")

                plt.figure(figsize=(3, 3))
                plt.imshow(image_patch[0])
                plt.title(f"Image at Epoch {epoch+1}")
                plt.axis("off")
                plt.show()

            except Exception as e:
                # Epoch is reported 1-based, like every other message here.
                print(f"Error capturing patch assignments at epoch {epoch+1}: {e}")
                traceback.print_exc()

        assignment_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=capture_patch_assignments)
        weight_callback = tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, logs: weights_dict.update({epoch: model.get_weights()}))
        testing_after_epoch = tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, logs: z.append(
                model.evaluate(testing_data, testing_label, batch_size=1000, verbose=1)))

        # Train the model, then persist the per-epoch test results.
        x = model.fit(training_data, training_label, batch_size=128, epochs=25,
                      callbacks=[testing_after_epoch, weight_callback, assignment_callback])
        f = 'test_acc_loss_cifar_10_no_noise_wideresnet_moe_s_'+str(s//1000)+'k_v'+str(i+1)
        np.save(f, z)

# Get test accuracy

In [None]:
def avg_stddev_calc(f, no_v):
    """Average per-epoch results over several saved runs.

    Loads `<f>_v1.npy` ... `<f>_v<no_v>.npy`, each expected to be a 2-D array,
    and reduces them element-wise across runs. Missing files are reported and
    skipped.

    Args:
        f: file-name prefix (without the `_v<i>.npy` suffix).
        no_v: number of run versions to look for.

    Returns:
        `(t_av, t_std)`: element-wise mean and population standard deviation
        (ddof=0) over the runs that were found, as NumPy arrays with the shape
        of a single run.

    Raises:
        ValueError: if none of the files exists.
    """
    runs = []
    for version in range(1, no_v + 1):
        f_t = f + '_v' + str(version) + '.npy'  # Construct the file name
        print(f"Looking for file: {f_t}")

        try:
            t = np.load(f_t)
        except FileNotFoundError:
            print(f"File {f_t} not found. Skipping.")
            continue

        print(f"Loaded file: {f_t}")
        # Add a leading "run" axis so the runs can be joined along it.
        runs.append(t.reshape(1, t.shape[0], t.shape[1]))

    if not runs:
        raise ValueError("No valid files found for computation.")

    # One concatenate at the end instead of growing an array inside the loop.
    t_1 = np.concatenate(runs, axis=0)
    t_av = np.mean(t_1, axis=0)
    t_std = np.std(t_1, axis=0)
    return t_av, t_std


def last_epoch_result_collection(f, no_sample_points, points, no_v=5, last_epoch=50):
    """Collect the run-averaged results at one epoch for each training-set size.

    For every sample size in `points[:no_sample_points]` the runs named
    `<f>_s_<size // 1000>k_v*.npy` are averaged with `avg_stddev_calc`, and the
    row `last_epoch - 1` of the averaged array is extracted.

    Returns:
        `(t_av_s, t_std_s)`, both of shape `(no_sample_points, 3)`. Columns 0
        and 1 are copied from columns 0 and 1 of the per-epoch arrays; column
        2 is the sample size.

    NOTE(review): if the files hold `model.evaluate` outputs for a model with
    a single metric, columns 0/1 are presumably [loss, metric] -- verify
    against the code that wrote the files.
    """
    epoch_row = last_epoch - 1
    t_av_s = np.zeros((no_sample_points, 3), dtype=np.float64)
    t_std_s = np.zeros((no_sample_points, 3), dtype=np.float64)

    for row in range(no_sample_points):
        sample_size = points[row]
        f_1 = f + '_s_' + str(sample_size // 1000) + 'k'  # base file name for this sample size
        print(f"Processing sample point {sample_size}: {f_1}")

        t_av, t_std = avg_stddev_calc(f_1, no_v)

        # Fill a whole output row at once: two result columns + sample size.
        t_av_s[row] = (t_av[epoch_row, 0], t_av[epoch_row, 1], sample_size)
        t_std_s[row] = (t_std[epoch_row, 0], t_std[epoch_row, 1], sample_size)

    return t_av_s, t_std_s

# File prefix and sampling setup; these must match what the training cell saved.
f_moe = 'test_acc_loss_cifar_10_no_noise_wideresnet_moe'
points = [50000]                 # training-set sizes that were run
no_sample_points = len(points)   # derived, so it cannot drift from `points`
no_v = 1                         # number of repeated runs per size
last_epoch = 25                  # epoch whose results are reported

# Perform analysis
try:
    wideresnet_moe_av_s, wideresnet_moe_std_s = last_epoch_result_collection(
        f_moe, no_sample_points, points, no_v, last_epoch)

    # Print results
    print("Average Accuracy and Loss (last epoch):")
    print(wideresnet_moe_av_s)
    print("Standard Deviation (last epoch):")
    print(wideresnet_moe_std_s)

    # Plot column 1 of the collected results against the sample size
    # (column 2), with the across-run spread as error bars. matplotlib is
    # already imported at the top of the notebook.
    plt.errorbar(wideresnet_moe_av_s[:, 2], wideresnet_moe_av_s[:, 1], wideresnet_moe_std_s[:, 1],
                 marker='^', label='Wideresnet-MoE')
    plt.legend()
    plt.xlabel('No. of training samples')
    plt.ylabel('Test accuracy')
    plt.title('CIFAR-10: Wideresnet vs. Wideresnet-MoE')
    plt.show()

except ValueError as e:
    print(f"Error during analysis: {e}")


# Evaluation metrics

In [None]:
# Per-expert 8x8 counters: how often each grid position was routed to the
# expert, accumulated over every recorded epoch.
grid_tracking = {}
for expert_id in range(1, 5):
    grid_tracking[expert_id] = np.zeros((8, 8), dtype=int)

for epoch_assignments in patch_assignments:
    for assignment in epoch_assignments:
        counts = grid_tracking[assignment["expert"]]
        for row, col in assignment["grid_coordinates"]:
            counts[row, col] += 1

## Visualization of patch routing

In [None]:
def visualize_expert_specialization(grid_tracking):
    """Draw one heat map per expert showing how often each grid cell was selected.

    Args:
        grid_tracking: dict mapping an expert id to a 2-D count array.

    The figure gets one panel per entry (ordered by expert id) instead of a
    fixed four, so any number of experts can be displayed.
    """
    if not grid_tracking:
        print("No expert grids to visualize.")
        return

    num_experts = len(grid_tracking)
    # squeeze=False keeps `axes` 2-D even when there is a single expert.
    fig, axes = plt.subplots(1, num_experts, figsize=(5 * num_experts, 5), squeeze=False)

    ordered = sorted(grid_tracking.items(), key=lambda item: item[0])
    for ax, (expert_id, grid) in zip(axes[0], ordered):
        cax = ax.imshow(grid, cmap="Blues", interpolation="nearest")
        ax.set_title(f"Expert {expert_id} Specialization")
        ax.set_xlabel("Columns")
        ax.set_ylabel("Rows")
        fig.colorbar(cax, ax=ax)

    plt.tight_layout()
    plt.show()

visualize_expert_specialization(grid_tracking)

In [None]:
def compute_patch_sums(images, patch_size=(8, 8)):
    """Sum the pixel values inside each cell of a regular grid laid over the images.

    Args:
        images: 4-D batch indexed as (image, height, width, channel).
        patch_size: number of grid cells along (height, width). Each cell
            spans `height // patch_size[0]` by `width // patch_size[1]`
            pixels; leftover pixels at the bottom/right edge are ignored.

    Returns:
        Float array of shape `(num_images, patch_size[0], patch_size[1])`
        holding, per image and cell, the sum over the cell's pixels and all
        channels.
    """
    num_images, img_height, img_width, _ = images.shape
    grid_rows, grid_cols = patch_size[0], patch_size[1]
    patch_sums = np.zeros((num_images, grid_rows, grid_cols))

    # Visit every (row, col) cell of the grid in row-major order.
    for r, c in np.ndindex(grid_rows, grid_cols):
        cell_h = img_height // grid_rows
        cell_w = img_width // grid_cols
        cell = images[:, r * cell_h: (r + 1) * cell_h, c * cell_w: (c + 1) * cell_w, :]
        patch_sums[:, r, c] = np.sum(cell, axis=(1, 2, 3))

    return patch_sums

## Patch value distribution

In [None]:
# Per-cell pixel sums of the single tracked image (`image_patch`).
patch_sums = compute_patch_sums(image_patch)

# Routing recorded after the final epoch: one record per expert.
last_epoch_data = patch_assignments[-1]

# Expert id -> pixel sums of the grid cells routed to that expert.
# (The comprehension uses its own row/col names, so it does not overwrite
# notebook-level variables such as `x`.)
expert_patch_assignments = {
    expert_data['expert']: [
        patch_sums[0, row, col] for (row, col) in expert_data['grid_coordinates']
    ]
    for expert_data in last_epoch_data
}

# Display the expert patch assignments for the last epoch
print(patch_assignments[-1])
print(patch_sums)
for expert, patches in expert_patch_assignments.items():
    print(f"Expert {expert}:")
    print(f"Assigned Patch Values: {patches}\n")

plt.figure(figsize=(12, 8))

# One histogram per expert, reusing the values gathered above.
for i, expert_data in enumerate(last_epoch_data):
    expert = expert_data['expert']
    patch_values = expert_patch_assignments[expert]

    plt.subplot(2, 2, i + 1)  # 2x2 layout: room for four experts
    plt.hist(patch_values, bins=10, color='skyblue', edgecolor='black')
    plt.title(f'Patch Value Distribution for Expert {expert}')
    plt.xlabel('Patch Value')
    plt.ylabel('Frequency')

# Adjust layout and show the plot
plt.tight_layout()
plt.show()