In [None]:
import torch
from utils.data import get_dataloaders, AdversarialDataset
from utils.models import get_model, Mask, MaskedClf
from torch.utils.data import DataLoader
from utils.data import AdversarialDataset
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
# Pick the compute device: the GPU when one is visible, the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the ResNet-20 classifier and restore the weights stored in clean.pt.
base_model = get_model('resnet20').to(device)
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# that was saved from a GPU still loads when only the CPU is available.
base_model.load_state_dict(
    torch.load("trained_models/resnet20/clean.pt", map_location=device)
)
# Evaluation mode: the classifier is only used for inference in this notebook.
base_model.eval()

# Output directory for the per-class masks and their plots.
# exist_ok=True already makes this a no-op when the directory is present.
os.makedirs('class_specific', exist_ok=True)

In [None]:
# Classification loss used later when optimising the per-class masks.
loss = torch.nn.CrossEntropyLoss()

# Project helper returning a dict of loaders; the 'train' and 'test' entries
# are the ones consumed in this notebook.
dataloaders = get_dataloaders('cifar10', 128, 1, shuffle_train=True, shuffle_test=False, unnorm=True)
dataset = AdversarialDataset(None, 'resnet20', 'FMN', dataloaders['train'], 32, 'train')

# Reorder the samples by label. The permutation is computed ONCE and applied to
# both arrays, so images and labels are guaranteed to stay aligned (the sort is
# not repeated, and cannot yield a different tie order the second time).
# NOTE(review): only clean_imgs and labels are reordered here. If the dataset
# also stores the adversarial images, they keep their original order and the
# `xadv` yielded by this loader no longer matches `x` -- confirm against
# AdversarialDataset (the training loop below does not use `xadv`).
label_order = dataset.labels.argsort()
dataset.clean_imgs = dataset.clean_imgs[label_order]
dataset.labels = dataset.labels[label_order]

# No shuffling, so the label-sorted order is preserved.
# Presumably 5000 == number of training images per class, which would make
# every batch contain exactly one class -- TODO confirm for this dataset.
train_dataloader = DataLoader(dataset, batch_size=5000, shuffle=False)

In [None]:

# Learn one mask per class: for every (label-sorted) batch, optimise the mask
# of a MaskedClf wrapped around the frozen base classifier with
# cross-entropy + an L1 penalty on the mask, then store the mask and its plots.

MAX_EPOCHS = 5000        # upper bound on optimisation steps per class
MIN_EPOCHS = 500         # never stop before this many steps
L1_WEIGHT = 0.01         # weight of the L1 sparsity penalty on the mask
CONVERGENCE_WINDOW = 20  # number of recent losses averaged by the stop test
CONVERGENCE_TOL = 1e-5   # stop when |loss - recent mean| drops below this


def _save_class_mask(mask_param, class_idx):
    """Write the mask of class `class_idx` as a .npy file plus one PNG per channel."""
    mask_cpu = mask_param.detach().cpu()

    # On-disk format: fftshift over ALL dimensions. The evaluation cell restores
    # the mask with a plain torch.fft.ifftshift (also all dimensions), so the
    # stored layout must stay exactly this for the round trip to be lossless.
    # The file name has no prefix so that it matches the `{c}.npy` path the
    # evaluation cell loads.
    np.save(f'class_specific/{class_idx}.npy',
            torch.fft.fftshift(mask_cpu).squeeze().numpy())

    # For the plots only the two spatial axes are shifted. A default fftshift
    # would also roll the channel axis (3 channels -> shift by 1), so index 0
    # would show the last channel instead of the first.
    # Assumes a channel-first mask, i.e. (3, H, W) after squeeze -- TODO confirm
    # against utils.models.Mask.
    centred = torch.fft.fftshift(mask_cpu, dim=(-2, -1)).squeeze().numpy()
    for channel_idx, channel_name in enumerate("RGB"):
        plt.figure()
        plt.imshow(centred[channel_idx], cmap="Blues")
        plt.colorbar()
        plt.savefig(f'class_specific/{class_idx}{channel_name}.png')
        plt.close()


for x, xadv, y in train_dataloader:
    print("Class: ", y.unique()[0].item())
    x = x.to(device)
    y = y.to(device)

    # Keep only the images the base model already classifies correctly.
    # no_grad: this is a pure filter, no autograd graph is needed for it.
    with torch.no_grad():
        correct = base_model(x).argmax(-1) == y
    x = x[correct]
    y = y[correct]
    if y.numel() == 0:
        # Nothing to optimise on (and y[0] below would raise on an empty batch).
        print("No correctly classified images in this batch; skipping.")
        continue
    c = y[0].cpu().item()  # class id used in the output file names

    model = MaskedClf(Mask((3, 32, 32)).to(device), base_model)
    # Freeze the classifier: only the mask is trained.
    for p in model.clf.parameters():
        p.requires_grad = False
    model.mask.train()
    optimizer = torch.optim.Adam(model.mask.parameters(), lr=0.01)

    losses = []
    converged = False
    for e in range(MAX_EPOCHS):
        print(e, end='\r')
        out = model(x)
        l = loss(out, y)
        penalty = model.mask.M.abs().sum()
        l = l + penalty * L1_WEIGHT
        current_loss = l.item()
        losses.append(current_loss)

        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        # Project the mask back into [0, 1] after every step.
        model.mask.M.data.clamp_(0., 1.)

        # Stop once the loss is indistinguishable from its recent running mean.
        recent_mean = np.mean(losses[-CONVERGENCE_WINDOW:])
        if e > MIN_EPOCHS and abs(current_loss - recent_mean) < CONVERGENCE_TOL:
            converged = True
            break

    if not converged:
        # Previously a class that never met the stop test produced no files at
        # all; the last mask is now saved so the evaluation cell can still run.
        print(f"Class {c}: not converged after {MAX_EPOCHS} epochs, saving the last mask.")

    # Report how many of the retained images are still classified correctly
    # through the mask, and the epoch at which optimisation stopped.
    with torch.no_grad():
        print((model(x).argmax(-1) == y).sum(), e)
    _save_class_mask(model.mask.M, c)
    del model

In [None]:
# Evaluate the stored per-class masks on the adversarial test set. Counts
#   correct     - clean test images the base model classifies correctly,
#   adversarial - of those, how many adversarial versions are misclassified,
#   masked      - of those, how many are still misclassified through the mask.
adv_test_dataloader = DataLoader(
    AdversarialDataset(None, 'resnet20', 'FMN', dataloaders['test'], 32, 'test'),
    batch_size=1000,
    shuffle=False,
)

# One masked classifier is reused; only its mask tensor is swapped per class.
masked_model = MaskedClf(Mask((3, 32, 32)).to(device), base_model)
class_masks = {}  # class id -> mask tensor on `device`, loaded lazily


def _load_class_mask(class_idx):
    """Return the stored mask of `class_idx` on `device`, reading the file at most once."""
    if class_idx not in class_masks:
        stored = torch.tensor(np.load(f'class_specific/{class_idx}.npy'))
        # The training cell stores the mask after an all-dimension fftshift;
        # the matching all-dimension ifftshift restores the original layout.
        class_masks[class_idx] = torch.fft.ifftshift(stored).to(device)
    return class_masks[class_idx]


# Plain ints, so the final print also works if the loader yields no batch.
correct = 0
adversarial = 0
masked = 0

# Pure inference: no autograd graph is needed anywhere in this loop.
with torch.no_grad():
    for x, xadv, y in adv_test_dataloader:
        x = x.to(device)
        xadv = xadv.to(device)
        y = y.to(device)

        correct_images = base_model(x).argmax(-1) == y
        correct += correct_images.sum().item()
        if not correct_images.any():
            continue

        # Restrict the attack evaluation to images that were correct when clean.
        xadv_ok = xadv[correct_images]
        y_ok = y[correct_images]
        adversarial += (base_model(xadv_ok).argmax(-1) != y_ok).sum().item()

        # Apply to each sample the mask of ITS OWN class. For a batch holding a
        # single class this is identical to choosing the mask from y[0]; for a
        # batch that mixes classes it no longer applies the first sample's mask
        # to everything.
        for class_idx in y_ok.unique().tolist():
            same_class = y_ok == class_idx
            masked_model.mask.M.data = _load_class_mask(class_idx)
            masked_out = masked_model(xadv_ok[same_class])
            masked += (masked_out.argmax(-1) != y_ok[same_class]).sum().item()

print("Correctly classified: ", correct, "Adversarial: ", adversarial, "Adversarial after using the mask:", masked)