In [None]:
# %%capture
# !pip install --upgrade pip
# !pip install empatches
# !pip install tensorflow
# !pip install torch
# !pip install torchvision

In [None]:
%%capture
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy
import copy
import os
from PIL import Image
import torch
from empatches import EMPatches
import torchvision
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
import torchvision.transforms as transforms
from random import random
from torch.utils.data import Dataset, DataLoader, TensorDataset
import traceback

%matplotlib inline 

In [None]:
# Configuration
BASE_FOLDER_PATH = '/home/.../BadPatches/'

# Create the working sub-folders. os.path.join is used instead of string
# concatenation so the result is correct whether or not BASE_FOLDER_PATH ends
# with a separator (plain concatenation produced a doubled '/').
for _subfolder in ('data_files', 'model_files', 'result_images'):
    os.makedirs(os.path.join(BASE_FOLDER_PATH, _subfolder), exist_ok=True)

dataset = 'GTSRB'  # 'CIFAR-10' , 'GTSRB'
trigger = 'square'  # 'square' , 'blend' , 'warped'
patch_level = False  # False for image-level trigger

# Patch-level triggers are additionally evaluated at two smaller poisoning rates.
if patch_level:
    poisoning_rates = [0.0001, 0.0005, 0.001, 0.005, 0.02, 0.06, 0.1]
else:
    poisoning_rates = [0.001, 0.005, 0.02, 0.06, 0.1]

runs = 1  # How many times the model is trained for each configuration
k_input = 16  # Number of patches each expert gets
patch_size = 4  # Hyperparameter 'l'
badpatches_patch_size = 4  # Size of the patch to which the trigger is applied in the case of BadPatches
square_trigger_size = 2  # Size of the square trigger; for patch-level triggers it should be smaller than the patch size

# Set to False if a model was already trained and only validation should run.
# Note that the poisoning of images can be random due to shuffling, so results
# might slightly differ from the validation performed during training.
train = True
save_pruned_model = True

pruning_rates = [0.1, 0.2, 0.3]

training_epochs = 25
fine_tuning_epochs = 5

# Gating Routers

In [None]:
class gate(tf.keras.layers.Layer):
    """
    Top-k patch-gating layer.

    A single-output-channel convolution assigns one score to every spatial
    location of the input. After the gating activation, only the `k` highest
    scores per sample are kept (all others become zero) and the input is
    multiplied by the resulting score map.

    The layer returns a tuple `(outputs, indices)`: the gated input and the flat
    positions of the `k` selected locations.
    """

    def __init__(self, k, gating_kernel_size, strides=(1, 1), padding='valid',
                 data_format='channels_last', gating_activation=None,
                 gating_kernel_initializer=tf.keras.initializers.RandomNormal, **kwargs):
        """
        :param k: Number of locations kept per sample (passed to `tf.math.top_k`).
        :param gating_kernel_size: Tuple with the spatial size of the scoring kernel.
        :param strides: Strides of the scoring convolution.
        :param padding: Padding mode of the scoring convolution.
        :param data_format: 'channels_last' or 'channels_first'.
        :param gating_activation: Activation applied to the flattened scores
            (resolved through `tf.keras.activations.get`).
        :param gating_kernel_initializer: Initializer of the scoring kernel.
        :param kwargs: Forwarded to `tf.keras.layers.Layer`.
        """
        super(gate, self).__init__(**kwargs)
        self.k = k
        self.gating_kernel_size = gating_kernel_size
        self.strides = strides
        self.padding = padding
        self.data_format = data_format
        self.gating_activation = tf.keras.activations.get(gating_activation)
        self.gating_kernel_initializer = gating_kernel_initializer
        # Only rank-4 inputs are accepted.
        self.input_spec = tf.keras.layers.InputSpec(ndim=4)

    def build(self, input_shape):
        """
        Create the scoring kernel.

        :param input_shape: Shape of the layer input; its channel dimension must be defined.
        :raises ValueError: If the channel dimension is `None`.
        """
        if self.data_format == 'channels_first':
            channel_axis = 1

        else:
            channel_axis = -1

        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')

        input_dim = input_shape[channel_axis]
        # Kernel shape is kernel_size + (input channels, 1): a single output channel, i.e. one score per location.
        gating_kernel_shape = self.gating_kernel_size + (input_dim, 1)
        self.gating_kernel = self.add_weight(shape=gating_kernel_shape,
                                             initializer=self.gating_kernel_initializer,
                                             name='gating_kernel')

    def call(self, inputs):
        """
        Score every location, keep the top-k scores and scale the input with them.

        :param inputs: Rank-4 input tensor.
        :return: Tuple `(outputs, indices)` where `outputs` is `inputs` multiplied by the
            sparse score map and `indices` holds the flat positions of the kept scores,
            reshaped to (dim0 * dim1 of the flattened scores, k).
        """
        gating_outputs = tf.keras.backend.conv2d(inputs, self.gating_kernel,strides=self.strides,
                                                 padding=self.padding, data_format=self.data_format)

        # NOTE(review): the transpose below treats the conv output as channels-last — confirm before using 'channels_first'.
        gating_outputs = tf.transpose(gating_outputs, perm=(0, 3, 1, 2))
        x = tf.shape(gating_outputs)[2]
        y = tf.shape(gating_outputs)[3]
        # Flatten the two spatial axes into one axis of length x * y.
        gating_outputs = tf.reshape(gating_outputs, (tf.shape(gating_outputs)[0], tf.shape(gating_outputs)[1], x * y))

        gating_outputs = self.gating_activation(gating_outputs)
        # Keep only the k largest scores along the flattened spatial axis.
        [values, indices] = tf.math.top_k(gating_outputs, k=self.k, sorted=False)
        indices = tf.reshape(indices, (tf.shape(indices)[0] * tf.shape(indices)[1], tf.shape(indices)[2]))
        values = tf.reshape(values, (tf.shape(values)[0] * tf.shape(values)[1], tf.shape(values)[2]))
        batch_t, k_t = tf.unstack(tf.shape(indices), num=2)

        n = tf.shape(gating_outputs)[2]

        # Scatter the kept values back into a dense, zero-filled (batch_t, n) map:
        # every row's indices are offset by row * n so they address one flat vector.
        indices_flat = tf.reshape(indices, [-1]) + tf.math.floordiv(tf.range(batch_t * k_t), k_t) * n
        ret_flat = tf.math.unsorted_segment_sum(tf.reshape(values, [-1]), indices_flat, batch_t * n)
        ret_rsh = tf.reshape(ret_flat, [batch_t, n])
        ret_rsh_3 = tf.reshape(ret_rsh, (tf.shape(gating_outputs)[0], tf.shape(gating_outputs)[1], tf.shape(gating_outputs)[2]))

        # Restore the spatial layout and move the score channel last again.
        new_gating_outputs = tf.reshape(ret_rsh_3, (tf.shape(ret_rsh_3)[0], tf.shape(ret_rsh_3)[1], x, y))
        new_gating_outputs = tf.transpose(new_gating_outputs, perm=(0, 2, 3, 1))
        # Replicate each score kernel_h * kernel_w * in_channels times, then interleave the copies so that
        # every score covers its kernel-sized window: final shape is (batch, x * kernel_h, y * kernel_w, in_channels).
        new_gating_outputs = tf.repeat(new_gating_outputs, tf.shape(self.gating_kernel)[0] * tf.shape(self.gating_kernel)[1] * tf.shape(self.gating_kernel)[2], axis=3)
        new_gating_outputs = tf.reshape(new_gating_outputs, (tf.shape(new_gating_outputs)[0], tf.shape(new_gating_outputs)[1], tf.shape(new_gating_outputs)[2], tf.shape(self.gating_kernel)[0], tf.shape(self.gating_kernel)[1], tf.shape(self.gating_kernel)[2]))
        new_gating_outputs = tf.transpose(new_gating_outputs, perm=(0, 1, 3, 2, 4, 5))
        new_gating_outputs = tf.reshape(new_gating_outputs, (tf.shape(new_gating_outputs)[0], tf.shape(new_gating_outputs)[1] * tf.shape(new_gating_outputs)[2], tf.shape(new_gating_outputs)[3] * tf.shape(new_gating_outputs)[4], tf.shape(new_gating_outputs)[5]))
        # NOTE(review): presumably the expanded map has the same shape as `inputs` (non-overlapping windows) — confirm for kernels other than (1, 1).
        outputs = inputs * new_gating_outputs

        return outputs, indices

# Wideresnet

In [None]:
# Initializer passed to every gate's scoring kernel in WideResnet: zero-mean normal with a standard deviation of 1e-4.
initializer_gate = keras.initializers.RandomNormal(mean=0.0, stddev=0.0001)

def WideResnetBlock(x, channels, strides, channel_mismatch=False):
    """
    Pre-activation residual block: two (BatchNorm -> ReLU -> 3x3 conv) stages plus a shortcut.

    :param x: Input tensor.
    :param channels: Number of filters of both convolutions (and of the shortcut projection).
    :param strides: Strides of the first convolution and of the shortcut projection.
    :param channel_mismatch: Unless this is exactly `False`, the shortcut is projected
        with a 1x1 convolution so it can be added to the residual branch.
    :return: Sum of the shortcut and the residual branch.
    """
    shortcut = x
    residual = x

    # Only the first stage uses the caller's strides; the second one always uses stride 1.
    for stage_strides in (strides, 1):
        residual = layers.BatchNormalization()(residual)
        residual = layers.ReLU()(residual)
        residual = layers.Conv2D(filters=channels, kernel_size=3, strides=stage_strides, padding='same')(residual)

    if channel_mismatch is not False:
        projection = layers.Conv2D(filters=channels, kernel_size=1, strides=strides, padding='valid')
        shortcut = projection(shortcut)

    return layers.Add()([shortcut, residual])


def WideResnetGroup(x, num_blocks, channels, strides):
    """
    Stack `num_blocks` residual blocks.

    The first block uses `strides` and a projected shortcut; every following
    block keeps the resolution (strides (1, 1)) and uses an identity shortcut.

    :param x: Input tensor.
    :param num_blocks: Total number of blocks in the group.
    :param channels: Number of filters of every block.
    :param strides: Strides of the first block.
    :return: Output tensor of the last block.
    """
    out = WideResnetBlock(x=x, channels=channels, strides=strides, channel_mismatch=True)

    follow_up_blocks = num_blocks - 1
    for _block in range(follow_up_blocks):
        out = WideResnetBlock(x=out, channels=channels, strides=(1, 1))

    return out


def WideResnet(x, num_blocks, k, num_classes=10):
    """
    Build the network graph: a WideResNet stem followed by four gated expert branches.

    Reads the notebook-level settings `patch_size`, `k_input` and `initializer_gate`.

    :param x: Input tensor.
    :param num_blocks: Number of residual blocks per group.
    :param k: Widening factor applied to the base widths (16, 32, 64).
    :param num_classes: Number of units of the final softmax layer.
    :return: Output tensor holding the class probabilities.
    """
    widths = [int(base * k) for base in (16, 32, 64)]
    expert_count = 4
    pool_window = int(32 / patch_size)

    # Stem: initial convolution and two residual groups.
    x = layers.Conv2D(filters=16, kernel_size=3, strides=1, padding='same')(x)
    x = WideResnetGroup(x, num_blocks, widths[0], strides=(1, 1))
    x = WideResnetGroup(x, num_blocks, widths[1], strides=2)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters=640, kernel_size=3, strides=int(patch_size / 2), padding='same')(x)

    # The branches are built stage by stage (all gates, then all BatchNorms, ...),
    # so the layers are created in the same order as in the unrolled version.
    gated = [
        gate(k_input, (1, 1), (1, 1), gating_activation=tf.nn.softmax, gating_kernel_initializer=initializer_gate)(x)
        for _ in range(expert_count)
    ]
    # Each gate returns (gated features, selected indices); only the features continue.
    branches = [features for features, _selected in gated]
    branches = [layers.BatchNormalization()(branch) for branch in branches]
    branches = [layers.ReLU()(branch) for branch in branches]
    branches = [
        layers.Conv2D(filters=160, kernel_size=1, strides=1, padding='same')(branch)
        for branch in branches
    ]

    # Merge the experts and classify.
    x = tf.keras.layers.concatenate(branches)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.AveragePooling2D((pool_window, pool_window))(x)
    x = layers.Flatten()(x)

    return layers.Dense(units=num_classes, activation='softmax')(x)

# Trigger generation methods

In [None]:
class GenerateSquareTrigger:
    """
    Creates a solid square pattern (all ones) anchored at the top-left corner
    that is used as a trigger for an image dataset.
    """

    def __init__(self, size):
        """
        :param size: (height, width) of the trigger square in pixels.
        :raises ValueError: If the square does not fit into a dataset item.
        """
        self.dims = (32, 32, 3)
        self.size = size

        # Validate before building the trigger: previously the check ran afterwards,
        # so an oversized trigger failed with an IndexError and this message was never shown.
        if size[0] > self.dims[0] or size[1] > self.dims[1]:
            raise ValueError(
                "The size of the trigger is too large for the dataset items.")

        trigger = np.zeros(self.dims, dtype=np.float32)
        self.crafted_trigger = self.create_trigger_square(trigger)

    def create_trigger_square(self, trigger):
        """
        Create a square trigger.

        :param trigger: Array that is modified in place; its top-left
            `size[0]` x `size[1]` region is set to one in every channel.
        :return: The same array.
        """
        height, width = self.size[0], self.size[1]
        trigger[:height, :width] = 1.0

        return trigger

    def apply_trigger(self, img):
        """
        Applies the trigger on the image.

        :param img: Image array; it is modified IN PLACE, so pass a copy if the
            original must stay clean.
        :return: The same array with its top-left region overwritten by the trigger.
        """
        height, width = self.size[0], self.size[1]
        img[:height, :width] = self.crafted_trigger[:height, :width]

        return img

In [None]:
class GenerateBlendedTrigger:
    """
    A class that uses images of the same dimensions as the dataset as triggers
    that will be blended with the clean images.

    The trigger is the hello-kitty image used in the original paper
    (https://arxiv.org/pdf/1712.05526.pdf), loaded from
    `BASE_FOLDER_PATH + 'hello_kitty.jpg'`.
    """

    def __init__(self):
        self.dims = (32, 32, 3)
        # Weight of the clean image in the blend; the trigger gets (1 - alpha).
        self.alpha = 0.8
        self.image_path = BASE_FOLDER_PATH + 'hello_kitty.jpg'
        self.crafted_trigger = self.create_trigger()

    def create_trigger(self):
        """
        Prepare the trigger for blended attack.

        :return: Trigger as a float array scaled to [0, 1]. Its spatial size is
            `self.dims[:-1]`, or `badpatches_patch_size` squared when the
            notebook-level `patch_level` flag is set.
        """

        # The context manager closes the file handle once the pixels are loaded.
        with Image.open(self.image_path) as img:
            # Resize to dimensions
            tmp = img.resize(self.dims[:-1])

            # Force the channel layout expected by self.dims so the reshape below
            # also works for palette / RGBA / grayscale source files.
            # `convert('L')` replaces the former `ImageOps.grayscale` call, which
            # referenced a module that is never imported.
            if self.dims[2] == 1:
                tmp = tmp.convert('L')
            elif self.dims[2] == 3:
                tmp = tmp.convert('RGB')

        tmp = np.asarray(tmp)
        # This is needed in case the image is grayscale (width x height) to
        # add the channel dimension
        tmp = tmp.reshape((self.dims))

        if patch_level:
            # Shrink the trigger to the size of a single patch.
            pil_image = Image.fromarray(tmp)
            resized_pil = pil_image.resize((badpatches_patch_size, badpatches_patch_size))
            tmp = np.array(resized_pil)

        trigger_array = tmp / 255

        return trigger_array

    def apply_trigger(self, img):
        """
        Applies the trigger on the image.

        :param img: Image array with the same shape as the trigger; values above 1
            are treated as 0-255 data and scaled to [0, 1] first.
        :return: `alpha * img + (1 - alpha) * trigger` as a float32 array.
        """

        crafted_trigger_normalized = self.crafted_trigger

        # Defensive: the trigger is normally already in [0, 1].
        if crafted_trigger_normalized.max() > 1:
            crafted_trigger_normalized = crafted_trigger_normalized / 255.0

        # Ensure the input image is normalized to [0, 1]
        if img.max() > 1:
            img = img / 255.0

        img = ((img * self.alpha) + (crafted_trigger_normalized * (1 - self.alpha)))

        return img.astype(np.float32)

In [None]:
class GenerateWarpedTrigger:
    """
    A class that generates a warped trigger using a distortion grid for backdoor attacks.
    Compatible with TensorFlow.

    One random distortion grid is drawn when the object is created and is then
    reused for every image passed to `apply_trigger` / `poison`.
    """

    def __init__(self, input_height):
        """
        Initialize the warped trigger generator.
        :param input_height: Side length of the (square) sampling grid that is generated.

        The warping strength (`s` = 0.25), the side length of the coarse noise
        grid (`k` = 2) and the grid rescaling factor (`grid_rescale` = 1.0) are
        fixed attributes, not parameters.
        """

        self.dims = (32, 32, 3)
        self.s = 0.25
        self.k = 2
        self.input_height = input_height
        self.grid_rescale = 1.0

        # Initialize the identity grid and noise grid for warping
        self.identity_grid, self.noise_grid = self.generate_main_grid()

    def generate_main_grid(self):
        """
        Generate the identity and noise grids for the warped trigger.
        :return: Tuple `(identity_grid, noise_grid)`; both have a leading batch axis of
            size 1, spatial size `input_height` x `input_height` and 2 coordinate channels.
        """

        # Create coarse random noise grid
        grid_noise = tf.random.uniform(shape=(1, self.k, self.k, 2), minval=- 1.0, maxval=1.0)
        # Normalize by the mean absolute value of the noise
        grid_noise = grid_noise / tf.reduce_mean(tf.abs(grid_noise))

        # Upsample the coarse noise to match the input height and width
        noise_grid = tf.image.resize(grid_noise, size=(self.input_height, self.input_height), method="bicubic")
        # Clamp values for stability
        noise_grid = tf.clip_by_value(noise_grid, -1.0, 1.0)

        # Create the identity grid
        array1d = tf.linspace(-1.0, 1.0, self.input_height)
        x, y = tf.meshgrid(array1d, array1d)
        # The y coordinate is stored first, then x; _grid_sample splits the grid in the same order
        identity_grid = tf.stack([y, x], axis=- 1)
        identity_grid = identity_grid[tf.newaxis, ...]  # Add batch dimension

        return identity_grid, noise_grid

    def _grid_sample(self, image, grid):
        """
        TensorFlow implementation of grid sampling for image warping.

        Every output pixel is copied from a single source pixel: the normalized
        coordinates are scaled to pixel positions and turned into integer indices
        by a cast, so no interpolation between neighbouring pixels takes place.

        :param image: The input image tensor with shape (batch_size, height, width, channels).
        :param grid: The grid tensor with shape (batch_size, height, width, 2).
        :return: Warped image tensor.
        """
        
        batch_size, height, width, channels = image.shape

        # Split grid into x and y components
        grid_y, grid_x = tf.split(grid, 2, axis=- 1)

        # Rescale normalized grid coordinates to image pixel indices
        grid_x = tf.cast((grid_x + 1.0) * 0.5 * tf.cast(width - 1, tf.float32), tf.int32)
        grid_y = tf.cast((grid_y + 1.0) * 0.5 * tf.cast(height - 1, tf.float32), tf.int32)

        # Remove the last dimension of grid_x and grid_y to match batch_indices shape
        # Shape: (batch_size, height, width)
        grid_x = tf.squeeze(grid_x, axis=-1)
        # Shape: (batch_size, height, width)
        grid_y = tf.squeeze(grid_y, axis=-1)

        # Create batch indices for gather_nd
        # Shape: (batch_size, 1, 1)
        batch_indices = tf.range(batch_size)[:, tf.newaxis, tf.newaxis]
        # Shape: (batch_size, height, width)
        batch_indices = tf.tile(batch_indices, [1, height, width])

        # Clip grid indices to stay within image bounds
        grid_x = tf.clip_by_value(grid_x, 0, width - 1)
        grid_y = tf.clip_by_value(grid_y, 0, height - 1)

        # Stack indices for gather_nd
        # NOTE(review): the stored grids have a batch axis of size 1, so this presumably only works for a single image — confirm
        indices = tf.stack([batch_indices, grid_y, grid_x], axis=- 1)

        sampled_image = tf.gather_nd(image, indices)

        return sampled_image

    def poison(self, image):
        """
        Apply a warping trigger to the image.
        :param image: A NumPy array representing the input image; values above 1 are
            treated as 0-255 data and divided by 255 first.
        :return: A NumPy array of the warped image.
        """
        
        # Ensure the input image is normalized
        if image.max() > 1.0:
            image = image / 255.0

        # Expand dimensions to (batch_size, height, width, channels)
        image_tensor = tf.convert_to_tensor(image, dtype=tf.float32)
        
        if len(image_tensor.shape) == 3:  # Add batch dimension if missing
            image_tensor = tf.expand_dims(image_tensor, axis=0)

        # Generate the warped grid
        grid_temps = (self.identity_grid + self.s * self.noise_grid / self.input_height) * self.grid_rescale
        grid_temps = tf.clip_by_value(grid_temps, -1.0, 1.0)

        # Warp the image using TensorFlow's grid_sample equivalent
        poisoned_image = self._grid_sample(image_tensor, grid_temps)

        # Squeeze batch dimension and convert back to NumPy
        poisoned_image = tf.squeeze(poisoned_image, axis=0).numpy()

        return poisoned_image

    def apply_trigger(self, img):
        """
        Alias for the poison function for consistency with other trigger generators.
        :param img: Input image as a NumPy array.
        :return: Warped image as a NumPy array.
        """
        
        return self.poison(img)

# Creating backdoor dataset

In [None]:
class BackdoorDataset:
    """
    TensorFlow-compatible dataset for backdoor attacks, enabling poisoning of specific samples.
    """

    def __init__(self, clean_data, clean_labels, trigger_obj, epsilon, train, cifar):
        """
        Initialize the backdoor dataset.
        :param clean_data: Original dataset images (NumPy array).
        :param clean_labels: Original dataset labels (one-hot encoded NumPy array).
        :param trigger_obj: Trigger generator exposing `apply_trigger(img)`.
        :param epsilon: Fraction of the training samples to poison.
        :param train: Whether this dataset is for training or testing.
        :param cifar: `True` for 10-class one-hot target labels, otherwise 43 classes are used.

        The target label for poisoned samples is fixed to class 0.
        """

        self.clean_data = clean_data
        self.clean_labels = clean_labels
        self.trigger_obj = trigger_obj
        self.epsilon = epsilon
        self.target_label = 0
        self.train = train
        self.cifar = cifar

        if train:
            self.poisoned_data, self.poisoned_labels = self.get_train_set()
        else:
            self.poisoned_data, self.poisoned_labels = self.get_test_set()

    def poison(self, img):
        """
        Poison an image by applying the trigger.

        With the notebook-level `patch_level` flag set, the image is cut into
        non-overlapping patches of `badpatches_patch_size`, the trigger is applied
        to every patch and the patches are merged back; otherwise the trigger is
        applied to the whole image.

        :param img: Image to poison. Some triggers write into the array they
            receive, so callers should pass a copy.
        :return: The poisoned image.
        """

        if patch_level:
            emp = EMPatches()
            img_patches, indices = emp.extract_patches(
                img, patchsize=badpatches_patch_size, overlap=0)

            for index, patch in enumerate(img_patches):
                img_patches[index] = self.trigger_obj.apply_trigger(patch)
            poisoned_img = emp.merge_patches(img_patches, indices)
        else:
            poisoned_img = self.trigger_obj.apply_trigger(img)

        return poisoned_img

    def get_train_set(self):
        """
        Generate the poisoned training set.

        A random `epsilon` fraction of the samples receives the trigger and is
        relabelled to the target class; all other samples stay unchanged.

        :return: Tuple `(poisoned_data, poisoned_labels)`.
        """

        poisoned_data = np.copy(self.clean_data)

        if isinstance(self.trigger_obj, GenerateBlendedTrigger) or isinstance(self.trigger_obj, GenerateWarpedTrigger):
            poisoned_data = poisoned_data / 255  # Apply normalization

        poisoned_labels = np.copy(self.clean_labels)

        num_samples = self.clean_data.shape[0]
        num_poisoned = int(self.epsilon * num_samples)
        poisoned_indices = np.random.choice(num_samples, size=num_poisoned, replace=False)

        # The target label is the same for every poisoned sample, so build it once.
        num_classes = 10 if self.cifar is True else 43
        target_one_hot = tf.one_hot(self.target_label, depth=num_classes).numpy()

        for idx in poisoned_indices:
            # Poison a COPY of the clean sample: indexing the array yields a view and
            # some triggers write in place, which used to corrupt the caller's clean data.
            poisoned_data[idx] = self.poison(np.copy(self.clean_data[idx]))
            # Always change the label to the target label
            poisoned_labels[idx] = target_one_hot

        return poisoned_data, poisoned_labels

    def get_test_set(self):
        """
        Generate the poisoned test set.

        Every sample whose label differs from the target label is poisoned and
        keeps its original label; samples of the target class are dropped.

        :return: Tuple `(poisoned_data, poisoned_labels)` as NumPy arrays.
        """

        # Work on a copy so in-place triggers never touch the clean data.
        temp = deepcopy(self.clean_data)
        poisoned_data = []
        poisoned_labels = []

        for idx in range(self.clean_data.shape[0]):
            # Convert one-hot label to scalar
            label_idx = np.argmax(self.clean_labels[idx])

            if label_idx != self.target_label:
                poisoned_data.append(self.poison(temp[idx]))
                poisoned_labels.append(self.clean_labels[idx])

        return np.array(poisoned_data), np.array(poisoned_labels)

    def get_data(self):
        """Return the `(poisoned_data, poisoned_labels)` pair built in the constructor."""
        return self.poisoned_data, self.poisoned_labels

# Attack evaluation metrics

In [None]:
def calculate_ASR(model, test_data, test_labels):
    """
    Calculate the Attack Success Rate (ASR) of the backdoored model.

    The attack is source-agnostic: every sample whose true label is not the
    target label (class 0) counts, and the attack succeeds on a sample when the
    model predicts the target label for it.

    :param model: Object exposing `predict(data, batch_size=...)` that returns per-class scores.
    :param test_data: Poisoned inputs passed to `model.predict`.
    :param test_labels: One-hot encoded original labels of `test_data`.
    :return: Percentage (0-100) of non-target samples classified as the target
        label; 0.0 when there are no non-target samples.
    """

    target_label = 0

    # Get model predictions
    predictions = model.predict(test_data, batch_size=128)
    predicted_labels = np.argmax(predictions, axis=1)
    original_labels = np.argmax(test_labels, axis=1)

    # Only samples that do not already belong to the target class are considered.
    non_target = original_labels != target_label
    total = int(np.count_nonzero(non_target))

    if total == 0:
        # Nothing to attack; avoid dividing by zero.
        return 0.0

    correct = int(np.count_nonzero(predicted_labels[non_target] == target_label))

    return (correct * 100.0) / total

# Backdoor training

In [None]:
# Per-dataset settings: file-name prefix of the cached .npy arrays and the number of classes.
DATASET_SPECS = {
    'CIFAR-10': ('cifar_10', 10),
    'GTSRB': ('gtsrb', 43),
}

# Fail early on a misspelled configuration instead of hitting an undefined
# variable (or silently reusing a stale one) in the middle of the loop.
if dataset not in DATASET_SPECS:
    raise ValueError(f"Unknown dataset '{dataset}', expected one of {list(DATASET_SPECS)}")

if trigger not in ('square', 'blend', 'warped'):
    raise ValueError(f"Unknown trigger '{trigger}', expected 'square', 'blend' or 'warped'")

file_prefix, num_classes = DATASET_SPECS[dataset]

for poisoning_rate in poisoning_rates:
    # Clean accuracy and attack success rate of every run with this poisoning rate
    accuracies = []
    asrs = []

    for run_number in range(runs):
        if trigger == 'square':
            trigger_generator = GenerateSquareTrigger((square_trigger_size, square_trigger_size))
        elif trigger == 'blend':
            trigger_generator = GenerateBlendedTrigger()
        else:  # 'warped'
            trigger_generator = GenerateWarpedTrigger(input_height=badpatches_patch_size if patch_level else 32)

        # The arrays are reloaded for every run so each run starts from unshuffled data.
        print(f"{dataset} as dataset")
        training_data = np.load(BASE_FOLDER_PATH + f'data_files/{file_prefix}_train_data_sorted.npy')
        training_label = np.load(BASE_FOLDER_PATH + f'data_files/{file_prefix}_train_label_sorted.npy')
        testing_data = np.load(BASE_FOLDER_PATH + f'data_files/{file_prefix}_test_data_sorted.npy')
        testing_label = np.load(BASE_FOLDER_PATH + f'data_files/{file_prefix}_test_label_sorted.npy')

        backdoor_training_dataset = BackdoorDataset(
            clean_data=training_data,
            clean_labels=tf.one_hot(training_label, depth=num_classes).numpy(),
            trigger_obj=trigger_generator,
            epsilon=poisoning_rate,
            train=True,
            cifar=(dataset == 'CIFAR-10')
        )
        poisoned_training_data, poisoned_training_label = backdoor_training_dataset.get_data()

        backdoor_test_dataset = BackdoorDataset(
            clean_data=testing_data,
            clean_labels=tf.one_hot(testing_label, depth=num_classes).numpy(),
            trigger_obj=trigger_generator,
            epsilon=poisoning_rate,
            train=False,
            cifar=(dataset == 'CIFAR-10')
        )
        poisoned_testing_data, poisoned_testing_label = backdoor_test_dataset.get_data()

        # 1-of-K encoding
        training_label = tf.reshape(tf.one_hot(training_label, axis=1, depth=num_classes, dtype=tf.float64), (len(training_label), num_classes)).numpy()
        testing_label = tf.reshape(tf.one_hot(testing_label, axis=1, depth=num_classes, dtype=tf.float64), (len(testing_label), num_classes)).numpy()

        # Shuffling the training set (clean and poisoned arrays use the same permutation)
        indices = tf.range(start=0, limit=tf.shape(
            training_data)[0], dtype=tf.int32)
        shuffled_indices = tf.random.shuffle(indices)

        training_data = tf.gather(training_data, shuffled_indices, axis=0)
        training_label = tf.gather(training_label, shuffled_indices, axis=0)
        poisoned_training_data = tf.gather(poisoned_training_data, shuffled_indices, axis=0)
        poisoned_training_label = tf.gather(poisoned_training_label, shuffled_indices, axis=0)

        # Normalizing and reshaping data
        # Blended and warped data is already scaled to [0, 1] by BackdoorDataset / the trigger;
        # only the square trigger leaves the images in their raw value range.
        if isinstance(backdoor_training_dataset.trigger_obj, GenerateSquareTrigger):
            poisoned_training_data = poisoned_training_data / 255
            poisoned_testing_data = poisoned_testing_data / 255

        training_data = training_data / 255
        training_data = tf.cast(training_data, dtype=tf.dtypes.float32)
        poisoned_training_data = tf.cast(poisoned_training_data, dtype=tf.dtypes.float32)
        poisoned_testing_data = tf.cast(poisoned_testing_data, dtype=tf.dtypes.float32)

        testing_data = testing_data / 255
        testing_data = tf.cast(testing_data, dtype=tf.dtypes.float32)

        # Dog and 80 sign images in clean testing dataset
        _index = 5536 if dataset == 'CIFAR-10' else 217
        plt.figure(figsize=(3, 3))
        plt.imshow(testing_data[_index], cmap='gray')
        plt.axis('off')
        plt.title("Clean", size=30, pad=20)
        plt.savefig(BASE_FOLDER_PATH +
                    f'result_images/dog_{dataset}_clean.pdf', bbox_inches='tight')
        plt.show()

        # Dog and 80 sign images in poisoned testing dataset
        # NOTE(review): the poisoned test set drops the target-class samples, so its indices are
        # shifted relative to the clean set — confirm that 217 shows the intended GTSRB image in both plots.
        _index = 4536 if dataset == 'CIFAR-10' else 217
        plt.figure(figsize=(3, 3))
        plt.imshow(poisoned_testing_data[_index], cmap='gray')
        plt.axis('off')
        plt.title(f"{trigger.title()}", size=30, pad=20)
        plt.savefig(BASE_FOLDER_PATH + f'result_images/{dataset}_{trigger}_patchlevel-{patch_level}.pdf', bbox_inches='tight')
        plt.show()

        # Creating the model
        model_input = tf.keras.Input(shape=(
            poisoned_training_data.shape[1], poisoned_training_data.shape[2], poisoned_training_data.shape[3]))
        model_output = WideResnet(model_input, num_blocks=1, k=10, num_classes=num_classes)

        # Model Aggregation
        model = tf.keras.Model(model_input, model_output)

        # Model Compilation
        model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
            loss='categorical_crossentropy',
            metrics=['categorical_accuracy']
        )

        # Callbacks
        # Snapshot of all model weights after every epoch, keyed by epoch number
        weights_dict = {}
        weight_callback = tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch,
            logs: weights_dict.update({epoch: model.get_weights()})
        )

        # Clean test-set evaluation after every epoch
        z = []
        testing_after_epoch = tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch,
            logs: z.append(model.evaluate(testing_data, testing_label, batch_size=1000, verbose=1))
        )

        # Train the model
        # NOTE(review): model_name already ends with '.weights.h5' and the suffix is appended again below.
        # Saving and loading agree with each other, so the doubled suffix is kept to stay
        # compatible with weight files that were already written.
        model_name = f'{dataset}_{trigger}_{poisoning_rate}-poisonrate_{patch_size}-patchsize_{badpatches_patch_size}-badpatches_patch_size_patchlevel-{patch_level}_runnumber{run_number}.weights.h5'

        if train:
            model.fit(
                poisoned_training_data,
                poisoned_training_label,
                batch_size=128,
                epochs=training_epochs,
                callbacks=[testing_after_epoch, weight_callback]
            )
            model.save_weights(BASE_FOLDER_PATH + f'model_files/{model_name}.weights.h5')
        else:
            model.load_weights(BASE_FOLDER_PATH + f'model_files/{model_name}.weights.h5')

        # The last entry returned by evaluate is the compiled metric (categorical accuracy)
        accuracies.append(round(model.evaluate(testing_data, testing_label, batch_size=1000, verbose=1)[-1] * 100, 2))
        asrs.append(round(calculate_ASR(model=model, test_data=poisoned_testing_data, test_labels=poisoned_testing_label), 2))

    print(f"Experiment setup: dataset: '{dataset}', trigger: '{trigger}', poisoning_rate: '{poisoning_rate}', patch_level: '{patch_level}', patch_size: '{patch_size}'")
    print(f"Acc max: '{np.max(accuracies)}'")
    print(f"Acc avg: '{np.average(accuracies)}'")
    print(f'Acc std: {np.std(accuracies)}')
    print(f"Asr max: '{np.max(asrs)}'")
    print(f"Asr avg: '{np.average(asrs)}'")
    print(f'Asr std: {np.std(asrs)}')

In [None]:
def visualize_expert_specialization(grid_tracking, sample_index=None):
    """
    Plot a test image next to one heat map per expert showing the patches that expert selected.

    :param grid_tracking: Dict mapping the 1-based expert id to its 2-D selection-count grid.
    :param sample_index: Index into `testing_data` of the image to show; defaults to the
        notebook-level `index` variable for backward compatibility.
    """
    if sample_index is None:
        sample_index = index

    fig, axes = plt.subplots(1, len(grid_tracking) + 1, figsize=(20, 5))
    axes[0].imshow(testing_data[sample_index], cmap='gray')
    axes[0].axis('off')

    for expert_id, grid in grid_tracking.items():
        ax = axes[expert_id]
        ax.imshow(grid, cmap="Blues", interpolation="nearest")
        ax.set_title(f"Expert {expert_id}", size=30, pad=15)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.savefig(BASE_FOLDER_PATH + f'result_images/patch_selection_{dataset}_{sample_index}.pdf', bbox_inches='tight')
    plt.show()


# Test images to inspect
sample_indices = [5536, 3108] if dataset == 'CIFAR-10' else [401, 217]

# Model that exposes the patch indices chosen by every gate (second output of each
# gate layer). It does not depend on the sample, so it is built once.
gate_layers = [layer for layer in model.layers if isinstance(layer, gate)]
intermediate_model = tf.keras.Model(
    inputs=model.input,
    outputs=[layer.output[1] for layer in gate_layers]
)

# Side length of the grid the gates select from.
# NOTE(review): 8 matches the 8 x 8 gate grid obtained with patch_size = 4 — adjust if patch_size changes.
grid_side = 8

for index in sample_indices:
    patch_indices = intermediate_model.predict(np.expand_dims(testing_data[index], axis=0), batch_size=128)
    patch_indices = np.array(patch_indices)

    # One entry per expert; a distinct loop name is used so the list being
    # iterated by the outer loop is no longer shadowed.
    epoch_assignments = []
    for expert_idx, expert_patch_indices in enumerate(patch_indices):
        flattened_indices = expert_patch_indices.flatten()
        # Convert flat positions to (row, col)
        grid_coordinates = [(i // grid_side, i % grid_side) for i in flattened_indices]
        epoch_assignments.append({
            "expert": expert_idx + 1,
            "grid_coordinates": grid_coordinates
        })

    patch_assignments = [epoch_assignments]

    # Count how often every grid cell was selected by each expert
    grid_tracking = {
        expert_id: np.zeros((grid_side, grid_side), dtype=int)
        for expert_id in range(1, len(gate_layers) + 1)
    }

    for assignments in patch_assignments:
        for assignment in assignments:
            expert_id = assignment["expert"]

            for (row, col) in assignment["grid_coordinates"]:
                grid_tracking[expert_id][row, col] += 1

    visualize_expert_specialization(grid_tracking, index)

In [None]:
class KerasPruning:
    """
    Activation-based filter pruning for selected layers of a Keras model.

    For every requested layer the filters with the lowest mean activation over
    `x_train` are "pruned" by setting their kernel slice (and bias entry, if the
    layer has one) to zero. The model is modified in place.
    """

    def __init__(self, x_train, y_train, model, layer_names, prune_rate, batch_size=8):
        """
        Stores the pruning configuration.

        Args:
            x_train: Samples used to compute activations. Must support `len()` and
                slicing (e.g. a NumPy array), since it is consumed in slices of `batch_size`.
            y_train: Labels belonging to `x_train`. Stored only; not used by the pruning itself.
            model (tf.keras.Model): The trained Keras model.
            layer_names (list): List of names of layers to prune.
            prune_rate (float): Fraction of filters to zero out, in the range [0, 1].
            batch_size (int): Number of samples per forward pass when collecting activations.

        Raises:
            ValueError: If `prune_rate` is outside [0, 1] or `batch_size` is smaller than 1.
        """

        # A negative rate would turn the "[:k]" selection below into "[:-k]" and
        # silently prune almost every filter, so reject it up front.
        if not 0.0 <= prune_rate <= 1.0:
            raise ValueError(f"prune_rate must be in [0, 1], got '{prune_rate}'.")
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got '{batch_size}'.")

        self.x_train = x_train
        self.y_train = y_train
        self.model = model
        self.layer_names = layer_names  # List of layer names to prune
        self.prune_rate = prune_rate
        self.batch_size = batch_size

    def _iter_batch_activations(self, layer_name):
        """
        Yields the target layer's output for consecutive slices of `x_train`.
        """

        activation_model = tf.keras.Model(
            inputs=self.model.input,
            outputs=self.model.get_layer(layer_name).output
        )

        for start in range(0, len(self.x_train), self.batch_size):
            batch = self.x_train[start:start + self.batch_size]
            yield activation_model(batch, training=False)

    def get_layer_activations(self, layer_name):
        """
        Runs the dataset through the model and collects activations of the target layer.

        Returns all batches concatenated along axis 0. This keeps every activation in
        memory at once; `prune` uses a streaming mean instead.
        """

        return tf.concat(list(self._iter_batch_activations(layer_name)), axis=0)

    def _mean_channel_activations(self, layer_name):
        """
        Computes the mean activation per channel (last axis) without storing all activations.

        Returns:
            np.ndarray: Array of shape (C,) averaged over every other axis and all samples.

        Raises:
            ValueError: If `x_train` yields no samples.
        """

        channel_sums = None
        num_positions = 0

        for batch_activations in self._iter_batch_activations(layer_name):
            # Accumulate in float64 so the running sum does not lose precision.
            values = np.asarray(batch_activations).astype(np.float64)
            values = values.reshape(-1, values.shape[-1])
            batch_sums = values.sum(axis=0)
            channel_sums = batch_sums if channel_sums is None else channel_sums + batch_sums
            num_positions += values.shape[0]

        if num_positions == 0:
            raise ValueError("Cannot compute activations: x_train is empty.")

        return channel_sums / num_positions

    def prune(self):
        """
        Prunes filters in the selected layers based on their average activation.

        Returns:
            The same model instance that was passed in, with the pruned weights applied.

        Raises:
            ValueError: If a layer does not hold exactly a kernel or a kernel and a bias.
        """

        for layer_name in self.layer_names:
            print(f"Pruning layer: '{layer_name}'")

            # Get the layer
            layer = self.model.get_layer(layer_name)

            # Mean activation per filter, shape: (C,)
            mean_activations = self._mean_channel_activations(layer_name)

            # Select the least active filters
            num_filters = mean_activations.shape[0]
            num_pruned_filters = int(num_filters * self.prune_rate)
            pruned_filters = np.argsort(mean_activations, kind="stable")[:num_pruned_filters]

            # Get layer weights: [kernel] or [kernel, bias]; filters live on the kernel's last axis
            layer_weights = layer.get_weights()
            if len(layer_weights) not in (1, 2):
                raise ValueError(
                    f"Layer '{layer_name}' has '{len(layer_weights)}' weight arrays; expected kernel and optional bias."
                )

            # Set pruned filters to zero
            layer_weights[0][..., pruned_filters] = 0
            if len(layer_weights) == 2:
                layer_weights[1][pruned_filters] = 0

            # Assign updated weights back to the layer
            layer.set_weights(layer_weights)

            print(f"Pruned '{num_pruned_filters}/{num_filters}' filters in '{layer_name}'")

        return self.model  # Return pruned model

In [None]:
def get_layer_by_index(model, layer_type, index):
    """
    Looks up the name of the n-th layer of a given type in a model.

    Layers are considered in the order they appear in `model.layers`; only
    instances of `layer_type` are counted.

    Args:
        model (tf.keras.Model): The trained model.
        layer_type (tf.keras.layers.Layer): The layer class to filter by (e.g., tf.keras.layers.Conv2D).
        index (int): Position among the matching layers (0-based). Negative values are not
            rejected by the bounds check and follow normal Python list indexing.

    Returns:
        str: The name of the selected layer.

    Raises:
        ValueError: If `index` is not smaller than the number of matching layers.
    """

    matching_names = []
    for candidate in model.layers:
        if isinstance(candidate, layer_type):
            matching_names.append(candidate.name)

    total_matches = len(matching_names)
    if total_matches <= index:
        raise ValueError(f"Model has only '{total_matches}' layers of type '{layer_type}', but index '{index}' was requested.")

    return matching_names[index]

In [None]:
def _report_backdoor_metrics(evaluated_model, stage):
    """
    Evaluates attack success rate and clean accuracy, prints both, and returns them.

    Args:
        evaluated_model: Model to evaluate on the poisoned and clean test sets.
        stage (str): Label inserted into the printed lines (e.g. 'before pruning').

    Returns:
        tuple: (asr, acc), both rounded to two decimals.
    """
    asr = round(calculate_ASR(model=evaluated_model, test_data=poisoned_testing_data, test_labels=poisoned_testing_label), 2)
    acc = round(evaluated_model.evaluate(testing_data, testing_label, batch_size=1000, verbose=1)[-1] * 100, 2)
    print(f"ASR {stage}: '{asr}%'")
    print(f"Acc {stage}: '{acc}'")
    return asr, acc


# The layers to prune do not depend on the pruning rate, so resolve them once.
layer_names_to_prune = [
    get_layer_by_index(model, tf.keras.layers.Conv2D, conv_index)
    for conv_index in (8, 9, 10, 11)
]

# KerasPruning.prune() modifies `model` in place. Snapshot the trained weights so
# every pruning rate starts from the same unpruned model instead of from the
# previous rate's pruned and fine-tuned weights.
# NOTE(review): re-running this cell snapshots whatever weights `model` holds at that time.
baseline_weights = model.get_weights()

for pruning_rate in pruning_rates:
    model.set_weights(baseline_weights)

    asr, acc = _report_backdoor_metrics(model, "before pruning")

    print(f"Selected layers to prune: '{layer_names_to_prune}'")
    pruner = KerasPruning(
        x_train=training_data,
        y_train=training_label,
        model=model,
        layer_names=layer_names_to_prune,
        prune_rate=pruning_rate
    )
    pruned_model = pruner.prune()  # Same object as `model`

    asr, acc = _report_backdoor_metrics(pruned_model, "before fine tuning")

    pruned_model.fit(
        training_data,
        training_label,
        batch_size=128,
        epochs=fine_tuning_epochs,
        callbacks=[testing_after_epoch, weight_callback]
    )

    asr, acc = _report_backdoor_metrics(pruned_model, "after fine tuning")

    # Honour the configuration flag instead of always writing the weights.
    # NOTE(review): the path does not include the pruning rate, so each rate
    # overwrites the previous file; kept unchanged in case later cells load it.
    if save_pruned_model:
        pruned_model.save_weights(BASE_FOLDER_PATH + f'model_files/pruned_{model_name}')