# imports

In [1]:
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transform
import timm

import torch
from torchvision import models
import torch.nn as nn
import matplotlib.pyplot as plt
import gzip
import os
import pickle
import numpy as np
from tqdm.auto import tqdm
import cv2
import matplotlib
import colorsys
import warnings

from torchvision import transforms

device = 'cuda'

# Build the architecture without pretrained weights; the Imagenette checkpoint below
# supplies every parameter (load_state_dict is strict by default, so the checkpoint
# must match this 10-class model exactly).
model = timm.create_model("resnetv2_50x1_bit_distilled", pretrained=False, num_classes=10)

# Load weights.
# NOTE(review): reset_classifier is redundant here — the head already has 10 outputs
# and is overwritten by load_state_dict below.
model.reset_classifier(num_classes=10)
checkpoint = torch.load(
    "./benchmark/models/resnetv2_50x1_bit_distilled_imagenette.pth",
    map_location=device,  # place tensors directly on the target device
)
model.load_state_dict(checkpoint['state_dict'])

# Move to device and set eval mode.
model = model.to(device)
model.eval()
# cudnn.benchmark autotunes convolution algorithms and may select different ones from
# run to run, which defeats cudnn.deterministic; keep it off for reproducible results.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Loading dataloaders.
# NOTE(review): get_imagenette_dataloaders is not defined or imported in any visible
# cell — this only runs if it already lives in the kernel; import it explicitly.
dataloaders = get_imagenette_dataloaders("imagenette2/", batch_size=64)
num_classes = 10

# sorted_dataset[c] collects every test image that the clean model *predicts* as
# class c (grouping is by prediction, not by ground-truth label).
sorted_dataset = [[] for _ in range(num_classes)]

# Pure inference: no_grad stops autograd from recording a graph and saving
# activations on every forward pass.
with torch.no_grad():
    for images, _labels in dataloaders['test']:
        images = images.to(device)
        _, preds = torch.max(model(images), dim=1)
        preds = preds.cpu().numpy()

        # Iterating the batch tensor yields images[i] views, one per sample.
        for image, pred in zip(images, preds):
            sorted_dataset[pred].append(image)
        
# Imagenette class index -> human-readable class name (index order matters).
class_labels = dict(enumerate([
    "tench",
    "English springer",
    "cassette player",
    "chain saw",
    "church",
    "French horn",
    "garbage truck",
    "gas pump",
    "golf ball",
    "parachute",
]))

2025-08-17 05:35:54.068547: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-17 05:35:58.450688: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX512_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  model = create_fn(


In [2]:
# Attacked test images and their stored predictions, one (dataset, preds) pair per
# attack setting; the evaluation cell labels the _1/_2/_3 settings as 1%/2%/3% patches.
sorted_attacked_dataset_1, sorted_attacked_preds_1 = (
    torch.load("sorted_attacked_dataset_1"),
    torch.load("sorted_attacked_preds_1"),
)
sorted_attacked_dataset_2, sorted_attacked_preds_2 = (
    torch.load("sorted_attacked_dataset_2"),
    torch.load("sorted_attacked_preds_2"),
)
sorted_attacked_dataset_3, sorted_attacked_preds_3 = (
    torch.load("sorted_attacked_dataset_3"),
    torch.load("sorted_attacked_preds_3"),
)

In [3]:
def _total_samples(per_class_buckets):
    """Return the number of samples stored in the first `num_classes` per-class buckets."""
    return sum(len(per_class_buckets[c]) for c in range(num_classes))


# Sample counts per split; used below as progress-bar totals by the evaluation cells.
size_1 = _total_samples(sorted_attacked_dataset_1)
print(size_1)

size_2 = _total_samples(sorted_attacked_dataset_2)
print(size_2)

size_3 = _total_samples(sorted_attacked_dataset_3)
print(size_3)

# Clean test set, grouped by the clean model's predictions.
size = _total_samples(sorted_dataset)
print(size)

3250


# Original (undefended) model accuracy

In [6]:
# Top-1 accuracy of the undefended model on the test split. The progress bar advances
# by the number of *correct* predictions, so its final count over `size` is the accuracy.
correct = 0
with torch.no_grad(), tqdm(total=size, desc="Accuracy") as pbar:
    for images, labels in dataloaders['test']:
        images = images.to(device)
        labels = labels.to(device)
        original_preds = model(images).argmax(dim=1)

        batch_correct = (original_preds == labels).sum().item()
        correct += batch_correct
        pbar.update(batch_correct)

print(f"Clean accuracy: {correct}/{size} = {correct / size if size else 0.0:.4f}")

Accuracy:   0%|          | 0/3925 [00:00<?, ?it/s]

# defense class

In [4]:
from Craft.craft.craft_torch import Craft, torch_to_numpy
import torch.nn.functional as F
# Install an "ignore" filter for all Python warnings in this kernel from here on.
# NOTE(review): this also hides deprecation/numerical warnings from torch, timm and
# Craft; consider narrowing it with `category=` / `module=` once the noisy source is known.
warnings.filterwarnings('ignore')
import gc

def blurring_multi_masking(images, images_u, most_important_concepts, percentile, num_concepts=2):
    """Blur the regions where each image's top concepts are most active.

    For every image, the heatmaps of the first `num_concepts` entries of
    `most_important_concepts` are thresholded at their `percentile`-th percentile,
    resized to the image's spatial size and OR-ed into a single mask. Masked pixels
    (in every channel) are replaced by a Gaussian-blurred copy of the image
    (21x21 kernel, sigma 5). Images whose mask is empty are returned unchanged.

    Args:
        images: batch indexable as images[i] -> (C, H, W) tensor. In this notebook
            the caller passes CPU tensors.
        images_u: concept heatmaps indexed as images_u[i, :, :, concept].
        most_important_concepts: concept indices, most important first.
        percentile: value in [0, 100]; only heatmap values strictly above this
            percentile are masked.
        num_concepts: how many of the top concepts to combine into the mask.

    Returns:
        The processed images stacked into one tensor and moved to the global `device`.
    """
    selected_concepts = most_important_concepts[:num_concepts]
    quantile_level = percentile / 100.0
    outputs = []

    for idx in range(len(images)):
        img = images[idx].clone()
        n_channels = img.shape[0]
        target_size = (img.shape[2], img.shape[1])  # cv2 expects (width, height)

        union = torch.zeros_like(img[0], dtype=torch.bool)
        for concept in selected_concepts:
            heatmap = torch.tensor(images_u[idx, :, :, concept])
            threshold = torch.quantile(heatmap.flatten(), quantile_level)
            binary = (heatmap > threshold).cpu().numpy().astype(np.uint8)
            resized = cv2.resize(binary, target_size).astype(bool)
            union = union | torch.from_numpy(resized)

        if union.any():
            # Broadcast the spatial mask over all channels before copying blurred pixels.
            channel_mask = union.unsqueeze(0).repeat(n_channels, 1, 1)
            blurred = transforms.functional.gaussian_blur(img, kernel_size=(21, 21), sigma=(5.0, 5.0))
            img[channel_mask] = blurred[channel_mask]

        outputs.append(img)

    return torch.stack(outputs).to(device)


    
class ConceptXAIBasedDefense:
    """CRAFT-based input defense.

    Splits `model` into a feature extractor (`forward_features` followed by ReLU) and a
    classifier head (`forward_head`). A CRAFT concept decomposition is fitted on each
    incoming batch, the concepts are ranked by importance for `class_id`, and the
    ranked concepts and their heatmaps are handed to a masking `defense_function`.
    """

    def __init__(self, model, device, batch_size, number_of_concepts=10, patch_size=64):
        """
        Args:
            model: model exposing `forward_features(x)` and `forward_head(z, pre_logits=...)`.
            device: device passed through to Craft.
            batch_size: Craft's batch size.
            number_of_concepts: number of CRAFT concepts to extract (default 10).
            patch_size: CRAFT patch (crop) size (default 64).
        """
        self.model = model
        self.device = device

        def g(x):
            # Input -> latent. ReLU presumably keeps the activations non-negative for
            # CRAFT's factorization — NOTE(review): confirm against the Craft implementation.
            feats = self.model.forward_features(x)
            return F.relu(feats)

        def h(z):
            # Latent -> logits through the model's classifier head.
            return self.model.forward_head(z, pre_logits=False)

        self.craft = Craft(input_to_latent=g,
                           latent_to_logit=h,
                           number_of_concepts=number_of_concepts,
                           patch_size=patch_size,
                           batch_size=batch_size,
                           device=device)

    def __call__(self, images, defense_function, class_id, percentile, num_concepts=2):
        """Run `defense_function` on `images` using their top concepts for `class_id`.

        Args:
            images: input batch (tensor); a detached CPU copy is passed to the defense.
            defense_function: callable with the signature
                (images, images_u, most_important_concepts, percentile, num_concepts).
            class_id: class whose concept importances are estimated.
            percentile: forwarded to `defense_function`.
            num_concepts: forwarded to `defense_function`; at least this many ranked
                concepts are provided.

        Returns:
            Whatever `defense_function` returns.
        """
        # Fit the concept bank on this batch. The returned crops/activations/basis are not
        # needed here, so the previous (unused) crop conversion to NumPy is dropped.
        self.craft.fit(images)

        importances = self.craft.estimate_importance(images, class_id=class_id)
        images_u = self.craft.transform(images)
        images = images.detach().cpu()

        # Concepts in descending importance. Keep the 5 the defense always received, and
        # extend the list when more concepts are requested instead of silently capping at 5.
        top_k = max(5, num_concepts)
        most_important_concepts = np.argsort(importances)[::-1][:top_k]

        return defense_function(images, images_u, most_important_concepts, percentile, num_concepts)
    
    
def robust_accuracy(sorted_dataset, sorted_preds, size, function, batch_size=16, percentile=90, max_concepts=5):
    """Measure how many attacked images the concept defense brings back to the reference prediction.

    For each number of masked concepts k = 1..max_concepts, every image is passed through
    ConceptXAIBasedDefense with `function` as the masking strategy. The model's prediction
    on the defended image is then compared with the stored prediction in `sorted_preds`.

    Args:
        sorted_dataset: per-class lists of image tensors; index j is used as the defense's class_id.
        sorted_preds: per-class lists of reference prediction tensors, aligned with
            sorted_dataset (presumably the clean-model predictions — NOTE(review): confirm).
        size: total number of images (progress-bar total and denominator of the overall rate).
        function: masking function with the signature of blurring_multi_masking.
        batch_size: images per defense/model call; also used as Craft's batch size (default 16).
        percentile: heatmap percentile above which concept regions are masked (default 90).
        max_concepts: largest number of top concepts to mask (default 5).

    Returns:
        dict mapping k -> overall recovered fraction (recovered / size).
    """
    defense = ConceptXAIBasedDefense(model, device, batch_size=batch_size)
    results = {}

    for concepts in range(1, max_concepts + 1):
        print(f"Testing with concept {concepts}")
        total_recovered = 0
        with tqdm(total=size, desc="Recover Rate") as pbar:
            for j in range(num_classes):
                images_class = sorted_dataset[j]
                if len(images_class) == 0:
                    # torch.stack rejects empty lists and the per-class ratio would divide by zero.
                    print(f"Class {j+1} | no samples, skipped")
                    continue

                labels_class = torch.stack(sorted_preds[j])
                images = torch.stack(images_class)

                class_recovered = 0
                for index in range(0, len(images), batch_size):
                    batch_images = images[index:index + batch_size].to(device)
                    batch_labels = labels_class[index:index + batch_size].to(device)

                    # The defense itself may need autograd internally; only the final
                    # evaluation forward pass is wrapped in no_grad.
                    masked_images = defense(batch_images, function, j, percentile, concepts).to(device)
                    with torch.no_grad():
                        _, preds = torch.max(model(masked_images), dim=1)

                    recovered = (preds == batch_labels).sum().item()
                    class_recovered += recovered
                    pbar.update(recovered)

                total_recovered += class_recovered
                print(f"Class {j+1} | Recovered: {class_recovered}/{len(images)} | Accuracy: {class_recovered/len(images):.4f}")

        results[concepts] = total_recovered / size if size else 0.0

    return results
                
                
def clean_accuracy(function, percentiles=range(90, 100), num_concepts=2, batch_size=64):
    """Measure how often the defense leaves the clean model's predictions unchanged.

    For each masking percentile, every clean image in the global `sorted_dataset` is
    classified before and after applying the concept defense. An image counts as
    "recovered" when both predictions agree.

    Args:
        function: masking function with the signature of blurring_multi_masking.
        percentiles: heatmap percentiles to sweep (default 90..99).
        num_concepts: number of top concepts masked by the defense (default 2).
        batch_size: images per model/defense call (default 64).

    Returns:
        dict mapping percentile -> fraction of clean images whose prediction is unchanged.
    """
    # Derive the total from the data instead of a hard-coded count.
    total = sum(len(sorted_dataset[j]) for j in range(num_classes))
    results = {}

    for percentile in percentiles:
        print(f"Testing with percentile {percentile}")
        total_recovered = 0
        with tqdm(total=total, desc="Recover Rate") as pbar:
            for j in range(num_classes):
                torch.cuda.empty_cache()
                gc.collect()

                if len(sorted_dataset[j]) == 0:
                    # torch.stack rejects empty lists and the per-class ratio would divide by zero.
                    print(f"Class Accuracy {j}: no samples, skipped")
                    continue

                images = torch.stack(sorted_dataset[j]).to(device)

                class_recovered = 0
                for index in range(0, len(images), batch_size):
                    batch = images[index:index + batch_size]
                    with torch.no_grad():
                        _, original_preds = torch.max(model(batch), dim=1)

                    # A fresh defense per batch, presumably so Craft's batch size matches
                    # this (possibly shorter, final) batch.
                    defense = ConceptXAIBasedDefense(model, device, len(batch))
                    masked_images = defense(batch, function, j, percentile, num_concepts).to(device)

                    with torch.no_grad():
                        _, preds = torch.max(model(masked_images), dim=1)

                    recovered = (preds == original_preds).sum().item()
                    class_recovered += recovered
                    pbar.update(recovered)

                total_recovered += class_recovered
                print(f"Class Accuracy {j}: {class_recovered/len(images)}")

        results[percentile] = total_recovered / total if total else 0.0

    return results

# Blurring multi-masking

In [6]:
import torch.nn.functional as F

# Evaluate the blurring defense on each attacked split, labelled by patch size.
for patch_label, attacked_images, attacked_preds, n_attacked in (
    ("1% patch", sorted_attacked_dataset_1, sorted_attacked_preds_1, size_1),
    ("2% patch", sorted_attacked_dataset_2, sorted_attacked_preds_2, size_2),
    ("3% patch", sorted_attacked_dataset_3, sorted_attacked_preds_3, size_3),
):
    print(patch_label)
    robust_accuracy(attacked_images, attacked_preds, n_attacked, blurring_multi_masking)

2% patch
Testing with concept 1


Recover Rate:   0%|          | 0/3250 [00:00<?, ?it/s]

Class 1 | Recovered: 47/47 | Accuracy: 1.0000
Class 2 | Recovered: 29/30 | Accuracy: 0.9667
Class 3 | Recovered: 208/222 | Accuracy: 0.9369
Class 4 | Recovered: 198/201 | Accuracy: 0.9851
Class 5 | Recovered: 61/64 | Accuracy: 0.9531
Class 6 | Recovered: 83/85 | Accuracy: 0.9765
Class 7 | Recovered: 96/99 | Accuracy: 0.9697
Class 8 | Recovered: 184/196 | Accuracy: 0.9388
Class 9 | Recovered: 151/152 | Accuracy: 0.9934
Class 10 | Recovered: 2104/2154 | Accuracy: 0.9768
Testing with concept 2


Recover Rate:   0%|          | 0/3250 [00:00<?, ?it/s]

Class 1 | Recovered: 46/47 | Accuracy: 0.9787
Class 2 | Recovered: 30/30 | Accuracy: 1.0000
Class 3 | Recovered: 213/222 | Accuracy: 0.9595
Class 4 | Recovered: 194/201 | Accuracy: 0.9652
Class 5 | Recovered: 59/64 | Accuracy: 0.9219
Class 6 | Recovered: 83/85 | Accuracy: 0.9765
Class 7 | Recovered: 97/99 | Accuracy: 0.9798
Class 8 | Recovered: 187/196 | Accuracy: 0.9541
Class 9 | Recovered: 151/152 | Accuracy: 0.9934
Class 10 | Recovered: 2094/2154 | Accuracy: 0.9721
Testing with concept 3


Recover Rate:   0%|          | 0/3250 [00:00<?, ?it/s]

Class 1 | Recovered: 45/47 | Accuracy: 0.9574
Class 2 | Recovered: 27/30 | Accuracy: 0.9000
Class 3 | Recovered: 212/222 | Accuracy: 0.9550
Class 4 | Recovered: 193/201 | Accuracy: 0.9602
Class 5 | Recovered: 57/64 | Accuracy: 0.8906
Class 6 | Recovered: 82/85 | Accuracy: 0.9647
Class 7 | Recovered: 94/99 | Accuracy: 0.9495
Class 8 | Recovered: 184/196 | Accuracy: 0.9388
Class 9 | Recovered: 150/152 | Accuracy: 0.9868
Class 10 | Recovered: 2110/2154 | Accuracy: 0.9796
Testing with concept 4


Recover Rate:   0%|          | 0/3250 [00:00<?, ?it/s]

Class 1 | Recovered: 44/47 | Accuracy: 0.9362
Class 2 | Recovered: 27/30 | Accuracy: 0.9000
Class 3 | Recovered: 211/222 | Accuracy: 0.9505
Class 4 | Recovered: 189/201 | Accuracy: 0.9403
Class 5 | Recovered: 57/64 | Accuracy: 0.8906
Class 6 | Recovered: 82/85 | Accuracy: 0.9647
Class 7 | Recovered: 91/99 | Accuracy: 0.9192
Class 8 | Recovered: 185/196 | Accuracy: 0.9439
Class 9 | Recovered: 145/152 | Accuracy: 0.9539
Class 10 | Recovered: 2086/2154 | Accuracy: 0.9684
Testing with concept 5


Recover Rate:   0%|          | 0/3250 [00:00<?, ?it/s]

Class 1 | Recovered: 44/47 | Accuracy: 0.9362
Class 2 | Recovered: 27/30 | Accuracy: 0.9000
Class 3 | Recovered: 213/222 | Accuracy: 0.9595
Class 4 | Recovered: 182/201 | Accuracy: 0.9055
Class 5 | Recovered: 56/64 | Accuracy: 0.8750
Class 6 | Recovered: 81/85 | Accuracy: 0.9529
Class 7 | Recovered: 84/99 | Accuracy: 0.8485
Class 8 | Recovered: 184/196 | Accuracy: 0.9388
Class 9 | Recovered: 145/152 | Accuracy: 0.9539
Class 10 | Recovered: 2063/2154 | Accuracy: 0.9578


In [15]:
# Check how often the blurring defense preserves the clean model's predictions.
clean_accuracy(function=blurring_multi_masking)

Testing with concepts 1


Recover Rate:   0%|          | 0/3925 [00:00<?, ?it/s]

Class Accuracy 0: 0.9432432432432433
Class Accuracy 1: 0.9040767386091128
Class Accuracy 2: 0.944954128440367
Class Accuracy 3: 0.7993197278911565
Class Accuracy 4: 0.9622166246851386
Class Accuracy 5: 0.8851540616246498
Class Accuracy 6: 0.9084158415841584
Class Accuracy 7: 0.9084380610412927
Class Accuracy 8: 0.9744897959183674
Class Accuracy 9: 0.9390243902439024
Testing with concepts 2


Recover Rate:   0%|          | 0/3925 [00:00<?, ?it/s]

Class Accuracy 0: 0.927027027027027
Class Accuracy 1: 0.8848920863309353
Class Accuracy 2: 0.9357798165137615
Class Accuracy 3: 0.7721088435374149
Class Accuracy 4: 0.9370277078085643
Class Accuracy 5: 0.8487394957983193
Class Accuracy 6: 0.8811881188118812
Class Accuracy 7: 0.8402154398563735
Class Accuracy 8: 0.951530612244898
Class Accuracy 9: 0.9243902439024391
Testing with concepts 3


Recover Rate:   0%|          | 0/3925 [00:00<?, ?it/s]

Class Accuracy 0: 0.9243243243243243
Class Accuracy 1: 0.8729016786570744
Class Accuracy 2: 0.908256880733945
Class Accuracy 3: 0.7448979591836735
Class Accuracy 4: 0.9269521410579346
Class Accuracy 5: 0.8011204481792717
Class Accuracy 6: 0.8985148514851485
Class Accuracy 7: 0.8150807899461401
Class Accuracy 8: 0.9464285714285714
Class Accuracy 9: 0.9219512195121952
Testing with concepts 4


Recover Rate:   0%|          | 0/3925 [00:00<?, ?it/s]

Class Accuracy 0: 0.9216216216216216
Class Accuracy 1: 0.8465227817745803
Class Accuracy 2: 0.9143730886850153
Class Accuracy 3: 0.717687074829932
Class Accuracy 4: 0.8992443324937027
Class Accuracy 5: 0.7843137254901961
Class Accuracy 6: 0.9034653465346535
Class Accuracy 7: 0.7863554757630161
Class Accuracy 8: 0.9413265306122449
Class Accuracy 9: 0.9048780487804878
Testing with concepts 5


Recover Rate:   0%|          | 0/3925 [00:00<?, ?it/s]

Class Accuracy 0: 0.9135135135135135
Class Accuracy 1: 0.8321342925659473
Class Accuracy 2: 0.9021406727828746
Class Accuracy 3: 0.7244897959183674
Class Accuracy 4: 0.8841309823677582
Class Accuracy 5: 0.7703081232492998
Class Accuracy 6: 0.8960396039603961
Class Accuracy 7: 0.7612208258527827
Class Accuracy 8: 0.9464285714285714
Class Accuracy 9: 0.9219512195121952
