# MUFIA attack

### Notebook for generating the figures in the main paper

In [None]:
import sys
from torchvision.models import *
from torchvision.utils import *
import torch
import os
import matplotlib.pyplot as plt
import ipywidgets as widgets
import seaborn as sns
import pandas as pd


# Pin the notebook to a single GPU. A CUDA_VISIBLE_DEVICES value supplied from
# outside the notebook is respected; otherwise GPU 3 is used. The variable only
# takes effect if it is set before CUDA is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

torch.cuda.empty_cache()

# Seed the CPU and CUDA random number generators for reproducible runs.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# cuDNN reproducibility: force deterministic kernels and disable the autotuner,
# which could otherwise select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Defaults: use the (visible) GPU when available, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


### Define Parameters

In [None]:
# Experiment configuration consumed by the data loader, model loader and attack.
param = {
    'batch_size': 32,
    'threat_model': 'std',  # std, prime, augmix, cc_sota
    'dataset': 'cifar10',   # cifar10, cifar100, imagenet
    'save_mat': False,      # if true, save the filter
    'parallel': False,
}

# Dataset-dependent architecture and block size.
param['model_name'] = 'resnet56' if param['dataset'] == 'cifar100' else 'resnet50'
param['block_size'] = 56 if param['dataset'] == 'imagenet' else 32

# Attack settings (atk_type='mufia').
param.update(
    lambda_reg=20,
    atk_type='mufia',
    n_epochs=100,
    print_every=10,
    lr=0.1,
    verbose=True,
    kappa=0.99,
    sim_loss='cosine',
)


##### Data Loading

In [None]:
# define dataloaders
# Make the parent directory importable so the project packages (misc, data,
# models) can be found. ".." and "../" name the same directory, so a single
# entry at the front of sys.path is enough.
import sys
sys.path.insert(0, "..")
from misc import *
from data import *
from models import *

data_loading = DataLoading(params=param)
# get_data() returns four values; only the test set and its loader are used here.
_, _, testset, testloader = data_loading.get_data()

# Expose the test loader through the shared parameter dict
# (presumably read by a downstream consumer of `param` — verify).
param["dataloader"] = testloader


##### Model Loading

In [None]:
# Build the classifier described by `param`, move it to the selected device
# and switch it to evaluation mode.
model_loading = ModelLoader(params=param, device=device)
net = model_loading.get_model().to(device)
net.eval()
# Silence the model's own printing via its `verbose` flag.
net.verbose = False

In [None]:
# Result containers, keyed by threat model name ('std' here).
adv_acc = {'std': 0}       # correct predictions on adversarial inputs
clean_acc = {'std': 0}     # correct predictions on clean inputs
y_quantize = {'std': []}   # per-batch filters returned by the attack
adv_x = {'std': []}        # per-batch adversarial images
corrupt_x = []             # per-batch clean inputs fed to the attack

In [None]:
from attacks.attacks import *
from attacks.attack_utils import *

### Running MUFIA on one batch

In [None]:
def _num_correct(model, inputs, targets):
    """Return how many samples of `inputs` `model` classifies as `targets`."""
    # Evaluation only: no autograd graph is needed for these forward passes.
    with torch.no_grad():
        outputs = model(inputs)
    # Some models return a tuple whose first element holds the logits.
    logits = outputs[0] if isinstance(outputs, tuple) else outputs
    pred = logits.argmax(dim=1, keepdim=True)
    return pred.eq(targets.view_as(pred)).sum().item()


# Run MUFIA over the test loader, storing clean inputs, adversarial images and
# the returned filters, and accumulating clean / adversarial correct counts.
new_solver = FilterAttack(net, param, device)
n_seen = 0  # samples actually evaluated (the last batch may be smaller)
for i, (x, y) in enumerate(tqdm(testloader)):
    x, y = x.to(device), y.to(device)
    corrupt_x.append(x)
    adv_x_basic, y_quantize_basic = new_solver(x, y)
    adv_x['std'].append(adv_x_basic)
    y_quantize['std'].append(y_quantize_basic)

    adv_acc['std'] += _num_correct(net, adv_x_basic, y)
    clean_acc['std'] += _num_correct(net, x, y)
    n_seen += y.size(0)

    # comment this out if you want to run for all the images
    if i == 0:
        break

if n_seen == 0:
    raise RuntimeError("testloader yielded no batches; nothing to evaluate")
# Normalise by the true sample count rather than (batches * batch_size).
adv_acc['std'] = 100.0 * adv_acc['std'] / n_seen
clean_acc['std'] = 100.0 * clean_acc['std'] / n_seen


### Accuracy

In [None]:
# Report adversarial and clean accuracy (in percent) for the evaluated batches.
for _label, _acc in (("After MUFIA accuracy: ", adv_acc), ("Clean accuracy: ", clean_acc)):
    print(_label, _acc['std'])

### Visualization

In [None]:
@widgets.interact(idx=(0, len(adv_x['std'][0]) - 1))
def f(idx=0):
    """Visualise sample ``idx`` of the first attacked batch.

    Panels, left to right:
    1. the clean image, titled with its dataset label;
    2. the clean image, titled with the model's prediction;
    3. the adversarial image, titled with the model's prediction on it;
    4. the absolute perturbation, scaled by 10;
    5. the learned filter after ``tanh(|filter| - 1)``, with a colour bar.

    Args:
        idx: position within the first stored batch. The slider range is that
            batch's actual length, so a short first batch is handled.
    """
    import matplotlib as mpl

    def _logits(outputs):
        # Some models return a tuple whose first element holds the logits.
        return outputs[0] if isinstance(outputs, tuple) else outputs

    def _to_hwc(img):
        # (1, C, H, W) tensor -> (H, W, C) numpy array for imshow.
        return img.detach().cpu().squeeze().numpy().transpose(1, 2, 0)

    # Take the clean image from the exact batch that was attacked, so the
    # clean/adversarial pair (and therefore their difference) always match.
    data = corrupt_x[0][idx:idx + 1]
    adv_data_new = adv_x['std'][0][idx:idx + 1]
    # NOTE(review): the label is read from the dataset and only corresponds to
    # `data` if the test loader does not shuffle — verify.
    _, targets = testset[idx]

    # Assumes the stored filter is 5-D with the 2-D map in the last two dims
    # (hence the [0, 0, 0] index) — TODO confirm against FilterAttack.
    filt = y_quantize['std'][0][idx:idx + 1][0, 0, 0, :, :]
    filt = torch.tanh(torch.abs(filt) - 1)

    # Scale the perturbation x10 for visibility and clamp it to imshow's valid
    # float RGB range (matplotlib would otherwise clip it with a warning).
    img_diff = (torch.abs(adv_data_new - data) * 10).clamp(0, 1)

    # Predictions only; no autograd graph is needed.
    with torch.no_grad():
        clean_pred = _logits(net(data)).argmax().item()
        adv_pred = _logits(net(adv_data_new)).argmax().item()

    fig, ax = plt.subplots(1, 5)
    fig.set_size_inches(10, 10)

    # Remove borders and ticks from every panel.
    for a in ax:
        a.axis('off')

    ax[0].imshow(_to_hwc(data))
    ax[1].imshow(_to_hwc(data))
    ax[2].imshow(_to_hwc(adv_data_new))
    ax[3].imshow(_to_hwc(img_diff))
    # Draw the filter once with a fixed [-1, 1] colour scale and reuse that
    # image as the colour bar's mappable.
    filt_img = ax[4].imshow(
        filt.detach().cpu().numpy(),
        cmap=plt.get_cmap('tab20c'),
        norm=mpl.colors.Normalize(vmin=-1, vmax=1),
    )
    pos = ax[4].get_position()
    cbar_ax = fig.add_axes([pos.x0 + pos.width + 0.01, pos.y0, 0.01, pos.height])
    fig.colorbar(filt_img, cax=cbar_ax)

    titles = (
        "True Label: " + str(targets),
        "Pred Label: " + str(clean_pred),
        "Adv Label: " + str(adv_pred),
        "Difference",
        "Filter Bank",
    )
    for a, title in zip(ax, titles):
        a.set_title(title, y=1.05)

    fig.subplots_adjust(wspace=0.1)
    plt.show()
