# Exploring Feature Importance: Visualizing Top-Responsive Filters in CNNs

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import os
import copy

class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18/34 in this file.

    Main path: 3x3 conv (carries ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    The shortcut is the identity when shapes already match; otherwise it is a
    strided 1x1 conv + BN projection so the two paths can be summed.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already match: an empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual_in = x
        y = F.relu(self.bn1(self.conv1(residual_in)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + self.shortcut(residual_in))


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block used by ResNet-50/101/152.

    The first 1x1 conv reduces to ``planes`` channels, the 3x3 conv carries
    the stride, and the last 1x1 conv expands to ``expansion * planes``.
    A 1x1 conv + BN projection shortcut is used whenever the stride or the
    channel count changes; otherwise the shortcut is the identity.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already match: an empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + self.shortcut(x))


class ResNet(nn.Module):
    """ResNet with a 3x3 stride-1 stem and no max-pool (suited to small inputs).

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
            It must expose an ``expansion`` attribute and accept
            ``(in_planes, planes, stride)``.
        num_blocks: four ints, the number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.

    Stage strides are 1, 2, 2, 2, so a 3x32x32 input gives a 4x4 map after
    ``layer4``. The head uses global average pooling, so the classifier does
    not depend on that spatial size. Other input resolutions therefore work too.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Updates ``self.in_planes`` so the next stage's first block receives
        the correct number of input channels.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool to 1x1. For 32x32 inputs (4x4 map here) this is
        # the same as avg_pool2d(out, 4). Unlike a fixed 4x4 kernel, it also
        # works for other input resolutions instead of failing at self.linear.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        return self.linear(out)


def ResNet18():
    """Build ResNet-18: ``BasicBlock`` stages of depth [2, 2, 2, 2], 10 classes."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2])


def ResNet34():
    """Build ResNet-34: ``BasicBlock`` stages of depth [3, 4, 6, 3], 10 classes."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3])


def ResNet50():
    """Build ResNet-50: ``Bottleneck`` stages of depth [3, 4, 6, 3], 10 classes."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3])


def ResNet101():
    """Build ResNet-101: ``Bottleneck`` stages of depth [3, 4, 23, 3], 10 classes."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3])


def ResNet152():
    """Build ResNet-152: ``Bottleneck`` stages of depth [3, 8, 36, 3], 10 classes."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3])

def get_model(args):
    """Build an untrained model for the architecture named by ``args.model``.

    Args:
        args: namespace-like object. Always reads ``args.model`` and
            ``args.num_classes``. The 'VGG16' branch also reads the optional
            ``args.pretrained`` (treated as False when the attribute is
            missing). The 'SimpleModel' branch reads ``args.dataset``.

    Returns:
        torch.nn.Module. 'Resnet18' is moved to CUDA when a GPU is available;
        every other architecture is returned on the CPU.

    Raises:
        ValueError: if ``args.model`` is unsupported, or if ``args.dataset``
            is unsupported for 'SimpleModel'.
    """
    num_classes = args.num_classes

    if args.model == 'Lenet':
        # LeNet-style network. The 16 * 4 * 4 flatten size matches a 1x28x28
        # input: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.
        model = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )
    elif args.model == 'ConvNet':
        # Four stride-2 conv blocks on a 1-channel input, then a linear head.
        class ConvNet(nn.Module):
            def __init__(self):
                super(ConvNet, self).__init__()

                def discriminator_block(in_filters, out_filters, bn=True):
                    block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                             nn.LeakyReLU(0.2, inplace=True),
                             nn.Dropout2d(0.25)]
                    if bn:
                        # NOTE(review): BatchNorm2d's 2nd positional argument
                        # is eps, so this sets eps=0.8 (not momentum). Kept
                        # unchanged so outputs match existing checkpoints.
                        block.append(nn.BatchNorm2d(out_filters, 0.8))
                    return block

                self.model = nn.Sequential(
                    *discriminator_block(1, 16, bn=False),
                    *discriminator_block(16, 32),
                    *discriminator_block(32, 64),
                    *discriminator_block(64, 128),
                )

                # Spatial size after four stride-2 convs of a 32x32 input
                # (2x2). A 28x28 input also ends at 2x2.
                ds_size = 32 // 2 ** 4
                self.adv_layer = nn.Sequential(
                    nn.Linear(128 * ds_size ** 2, num_classes))

            def forward(self, img):
                out = self.model(img)
                out = out.view(out.shape[0], -1)
                return self.adv_layer(out)

        model = ConvNet()
    elif args.model == 'Resnet18':
        model = ResNet18()
        # ResNet18() always builds a 10-way head. Swap it when a different
        # class count is requested, as the Resnet34/VGG16 branches do.
        if model.linear.out_features != num_classes:
            model.linear = nn.Linear(model.linear.in_features, num_classes)
        # Move to the GPU only when one exists; an unconditional .cuda()
        # raises on CPU-only machines.
        if torch.cuda.is_available():
            model = model.cuda()

    elif args.model == 'Resnet34':
        # torchvision's ResNet-34, not the ResNet class defined in this file.
        model = models.resnet34()
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    elif args.model == 'VGG16':
        # parse_args() defines no --pretrained flag, so default to False.
        model = models.vgg16(pretrained=getattr(args, 'pretrained', False))
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)

    # TODO: Add the model for the apprentice model
    elif args.model == 'SimpleModel':
        # NOTE(review): SimpleModel is not defined in the code shown here.
        # This branch raises NameError unless it is defined elsewhere.
        if args.dataset == 'MNIST':
            model = SimpleModel(num_classes=num_classes, input_channels=1, input_size=28)
        elif args.dataset == 'CIFAR10' or args.dataset == 'GTSRB':
            model = SimpleModel(num_classes=num_classes, input_channels=3, input_size=32)
        else:
            raise ValueError(f"Dataset {args.dataset} not supported for the SimpleModel.")
    else:
        raise ValueError(
            f"Model {args.model} not supported. Please choose from 'Lenet', "
            "'ConvNet', 'Resnet18', 'Resnet34', 'VGG16', or 'SimpleModel'."
        )

    return model

def load_model_dict(model, model_path=None):
    """Return a deep copy of ``model`` with weights loaded from ``model_path``.

    The ``model`` passed in is left untouched; the weights go into a copy.

    Args:
        model: torch.nn.Module whose architecture matches the checkpoint.
        model_path: path to a saved state dict. The loaded object is passed
            directly to ``load_state_dict``, so presumably it was written with
            ``torch.save(model.state_dict(), path)``.

    Returns:
        The deep-copied model with the checkpoint weights loaded.

    Raises:
        ValueError: if ``model_path`` is None.
    """
    if model_path is None:
        raise ValueError("model_path must be provided to load a checkpoint.")

    model_loaded = copy.deepcopy(model)

    # Map checkpoint tensors onto the device the model already lives on.
    # A checkpoint saved from a GPU can then be loaded on a CPU-only machine.
    first_param = next(model_loaded.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # weights_only=True restricts unpickling to tensors and plain containers.
    # That is all a state dict needs, and it avoids running arbitrary pickled code.
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model_loaded.load_state_dict(checkpoint)

    return model_loaded


In [None]:
import argparse
import configparser
import os
from time import time
import torch
import numpy as np
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.cm as cm



def parse_args():
    """Parse command-line options for the feature-map visualisation script.

    Uses ``parse_known_args``, so unrecognised flags are ignored rather than
    rejected. This presumably lets the function run inside a notebook kernel
    that adds its own argv entries.

    If ``--config`` is set (default 'program_config.txt'), the path is
    printed and the file is read with ``configparser``. Its contents are not
    applied to the returned namespace.

    Returns:
        argparse.Namespace with ``num_classes``, ``model``,
        ``load_model_path`` and ``config``.
    """
    parser = argparse.ArgumentParser(description="Visualize CNN feature maps")
    option_specs = [
        ('--num_classes',
         dict(type=int, default=10, help="Number of classes in the dataset")),
        ('--model',
         dict(type=str, default='Resnet18',
              choices=['Lenet', 'ConvNet', 'Resnet18', 'Resnet34', 'VGG16'],
              help="Select the model architecture")),
        ('--load_model_path',
         dict(type=str, help="Path to the clean and backdoor model")),
        ('--config',
         dict(type=str, default='program_config.txt', help="Path to the config file")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args, _ignored = parser.parse_known_args()

    if not args.config:
        return args

    print(f'Using config file: {args.config}')
    # NOTE(review): the parsed config is never consulted, so its values have
    # no effect on the returned args. Decide how keys should map to options.
    cfg = configparser.ConfigParser()
    cfg.read(args.config)
    return args


# Save kernel overlay with color
def save_kernel_overlay_with_color(layer_name, filter_idx, activations, save_dir, serial_number, image, original_image):
    feature_map = activations[layer_name][0, filter_idx].cpu().detach().numpy()
    feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min())

    colormap = cm.jet(feature_map)
    colormap = (colormap[..., :3] * 255).astype(np.uint8)

    feature_map_image = Image.fromarray(colormap).resize(original_image.size, Image.LANCZOS)
    feature_map_tensor = transforms.ToTensor()(feature_map_image).unsqueeze(0)

    original_image_tensor = transforms.ToTensor()(original_image).unsqueeze(0)
    alpha = 0.5
    blended_image_tensor = (1 - alpha) * original_image_tensor + alpha * feature_map_tensor
    blended_image = transforms.ToPILImage()(blended_image_tensor.squeeze(0))

    save_path = os.path.join(save_dir, f'filter_{serial_number}_layer_{layer_name}_idx_{filter_idx}.png')
    blended_image.save(save_path)
    print(f"Saved overlay visualization to {save_path}")


# Module-level store filled by forward hooks: layer name -> most recent
# output of that layer, detached from the autograd graph.
activations = {}


def get_activation(name):
    """Return a forward hook that records a module's output under ``name``.

    The hook writes ``output.detach()`` into the module-level ``activations``
    dict, overwriting any earlier entry for the same name. It returns None,
    so the module's forward output is left unchanged.
    """
    def _record(module, inputs, output):
        activations[name] = output.detach()

    return _record


def _resolve_hook_layer(model, preferred_name):
    """Return ``(name, module)`` for the layer whose activations are visualised.

    Uses ``preferred_name`` when the model has a submodule with that name.
    Otherwise falls back to the last ``nn.Conv2d`` in ``named_modules()``
    (registration) order, so architectures without that layer still work.

    Raises:
        ValueError: if the model contains no ``nn.Conv2d`` at all.
    """
    modules = dict(model.named_modules())
    if preferred_name in modules:
        return preferred_name, modules[preferred_name]
    conv_layers = [(name, module) for name, module in modules.items()
                   if isinstance(module, nn.Conv2d)]
    if not conv_layers:
        raise ValueError("Model has no Conv2d layer to visualize.")
    return conv_layers[-1]


def main():
    """Save colour overlays for the most strongly activated filters of one conv layer.

    Steps:
        1. Seed the RNGs and build the model from the CLI args.
        2. Load the checkpoint and hook the target conv layer.
        3. Run one image through the network.
        4. Rank the layer's channels by mean activation.
        5. Save an overlay for each of the top (up to five) channels in ``./kernels``.
    """
    random_seed = 88
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    args = parse_args()
    model = get_model(args)
    model = model.cuda() if torch.cuda.is_available() else model

    # Use the checkpoint given with --load_model_path; otherwise fall back to
    # the default file name.
    clean_model_path = args.load_model_path or 'model.pth'
    model = load_model_dict(model, clean_model_path)
    model.eval()

    # 'layer4.1.conv2' is the last conv of the ResNet-18 built by ResNet18().
    # Other architectures fall back to their last registered Conv2d.
    layer_name, target_layer = _resolve_hook_layer(model, 'layer4.1.conv2')
    hook_handle = target_layer.register_forward_hook(get_activation(layer_name))

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    image_path = '0002.jpg'
    # Decode eagerly and force 3 channels: Normalize above expects RGB, and
    # the file handle is closed as soon as the pixels are loaded.
    with Image.open(image_path) as img:
        original_image = img.convert('RGB')
    image = transform_test(original_image).unsqueeze(0)
    image = image.cuda() if torch.cuda.is_available() else image

    try:
        with torch.no_grad():
            model(image)
    finally:
        # The activation has been captured; detach the hook from the model.
        hook_handle.remove()

    activation = activations[layer_name].squeeze(0)  # (C, H, W) for a batch of one
    num_kernels = activation.size(0)
    # Score each channel by its spatially averaged response.
    kernel_responses = activation.reshape(num_kernels, -1).mean(dim=1)
    # Never request more channels than the layer has.
    top_kernels = torch.topk(kernel_responses, k=min(5, num_kernels)).indices

    save_dir = './kernels'
    os.makedirs(save_dir, exist_ok=True)

    for serial_number, filter_idx in enumerate(top_kernels, 1):
        save_kernel_overlay_with_color(layer_name, filter_idx.item(), activations, save_dir, serial_number, image, original_image)


if __name__ == '__main__':
    main()


Using config file: program_config.txt
Saved overlay visualization to ./kernels/filter_1_layer_layer4.1.conv2_idx_144.png
Saved overlay visualization to ./kernels/filter_2_layer_layer4.1.conv2_idx_385.png
Saved overlay visualization to ./kernels/filter_3_layer_layer4.1.conv2_idx_477.png
Saved overlay visualization to ./kernels/filter_4_layer_layer4.1.conv2_idx_196.png
Saved overlay visualization to ./kernels/filter_5_layer_layer4.1.conv2_idx_96.png


  checkpoint = torch.load(model_path)
