In [1]:
import shap
import argparse
import torch
import numpy as np
import os
os.chdir('../')

from backdoor.models import FCNN, CNN
from backdoor import dataset
from backdoor.training import Trainer
from backdoor.badnet import BadNetDataPoisoning, Trigger
from backdoor.image_utils import ImageFormat, ScikitImageArray
from backdoor.utils import totensor

In [2]:
ds = dataset.CIFAR10()
data = ds.get_data()

In [3]:
trigger = Trigger.from_string("checkerboard('bottomleft', (3, 3), colours=(255, 0))")

-3 0


In [4]:
def model_format(data):
    return totensor(ImageFormat.torch(data))

In [5]:
test_bd = trigger(data['test'].X)

In [6]:
noattack_model = 'weights/tm3_v3:clean_8cd95b.pth'
badnets_model = 'weights/tm3_v2:run2:clean_e13483.pth'
hc_model = 'weights/tm3_v2:run3:clean_bdb165.pth'
arch_model = 'weights/tm3_v3:imdb:evil_6cb65a.pth'

noattack_model = 'weights/cifar_clean.pth'
badnets_model = 'weights/tm1_v3:badnet_0.0027246907271963137.pth'
hc_model = 'weights/tm1_v3_run2:handcrafted_671e98.pth'
arch_model = 'weights/tm1_v3_run2:handcrafted_7d885f.pth'

In [None]:
os.mkdir('./output/shap')

In [7]:
import matplotlib.pyplot as plt
from shap.plots.colors import red_transparent_blue

def grayscale(img):
    if fmt := ImageFormat.detect_format(img) == 'scikit':
        return img[:,:,0]*0.2125 + img[:,:,1]*0.7154 + img[:,:,2]*0.0721
    else:
        return img[0]*0.2125 + img[1]*0.7154 + img[2]*0.0721

def softlog(x):
    return np.sign(x) * np.log1p(np.abs(x))
    
for model_fn, name in zip([noattack_model, badnets_model, hc_model, arch_model], ['none', 'bn', 'hc', 'arch']):
    print(model_fn)
    try:
        model = torch.load(model_fn)
    except AttributeError:
        mdl_arch_donor = torch.load(model_fn)
        arch_state_dict = mdl_arch_donor.state_dict()
        
        model = models.CNN.EvilVGG11((ds.n_channels, *ds.image_shape), ds.n_classes, batch_norm=True)
        model.load_state_dict(arch_state_dict)
        
    model = model.to('cpu')
    model.eval()
    shap_model = shap.GradientExplainer(model, model_format(data['train'][0]))
    
    for i in [16]: #[0, 1, 10, 20, 30]:
        shap_values = shap_model.shap_values(model_format(data['test'][0][i:i+1]), ranked_outputs=None)
        shap_values_bd = shap_model.shap_values(model_format(test_bd[i:i+1]), ranked_outputs=None)
            
        pred = torch.nn.functional.softmax(model(model_format(data['test'][0][i:i+1])), dim=1)[0]
        pred_bd = torch.nn.functional.softmax(model(model_format(test_bd[i:i+1])), dim=1)[0]

        # For normalizing all the plots
#         shap_min = np.min([shap_values])
#         shap_max = np.max([shap_values])
#         print(shap_min, shap_max)
        magnitude = np.min([np.max(np.abs(shap_values)), np.max(np.abs(shap_values_bd))])
#         shap_min = -softlog(magnitude)
#         shap_max = softlog(magnitude)
        
        actual_cls = data['test'][1][i]
        bd_cls = 6 # frog
        
        image_kwargs = dict(cmap='gray', interpolation='nearest', vmin=0, vmax=255, alpha=1)
        shap_kwargs = dict(cmap=red_transparent_blue, interpolation='bilinear', vmin=shap_min, vmax=shap_max, alpha=0.8)
        
        for cls in [actual_cls, bd_cls]:
            for imgs, shaps, preds, bdname in zip([data['test'][0], test_bd], [shap_values, shap_values_bd], [pred, pred_bd], ['clean', 'bd']):
                plt.imshow(grayscale(imgs[i]), **image_kwargs)
                plt.imshow(shaps[cls][0].sum(axis=0), **shap_kwargs)
                print(shaps[cls][0].sum(axis=0).max(), shaps[cls][0].min(axis=0))
                plt.axis('off')
                plt.savefig(f'./output/shap/{name}_{i}_{cls}_{bdname}.svg', bbox_inches='tight')
                plt.savefig(f'./output/shap/{name}_{i}_{cls}_{bdname}.png', bbox_inches='tight')
                plt.savefig(f'./output/shap/{name}_{i}_{cls}_{bdname}.pdf', bbox_inches='tight')
                print(f'Pred {cls}: {(preds[cls])*100:.2f}%')
#                 plt.title(f'Pred {cls}: {(preds[cls])*100:.2f}%')
                print(f'./output/shap/{name}_{i}_{cls}_{bdname}.svg')
                plt.show()

weights/cifar_clean.pth


NameError: name 'shap_min' is not defined

In [None]:
image_kwargs = dict(cmap='gray', interpolation='nearest', vmin=0, vmax=255, alpha=1)
for i in range(100):
    print(i)
    plt.imshow(grayscale(data['test'][0][i]), **image_kwargs)
    plt.show()