<center><h1>WRN: Cifar10</h1></center>

## Imports

In [1]:
from __future__ import division,print_function

%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
from tqdm import tqdm_notebook as tqdm

import random
import matplotlib.pyplot as plt
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
from torch.autograd import Variable, grad
from torchvision import datasets, transforms
from torch.nn.parameter import Parameter
import pandas as pd
import utils.calculate_log as callog
from utils.wrn import WideResNet

from utils.detector import Detector, gram_margin_loss
import utils.attacks as attacks

import warnings
warnings.filterwarnings('ignore')

In [2]:
torch.cuda.set_device(2)

## Model definition

In [3]:
torch_model = WideResNet(depth=40, widen_factor=2, num_classes=10)

torch_model.load(path="benchmark_ckpts/cifar10_reg_training_99.pt")
# torch_model.load(path="benchmark_ckpts/cifar10_style_epoch_99.pt")
# torch_model.load(path="checkpoints_style_fine_tuning/cifar10_wrn_baseline_epoch_2.pt")
torch_model.cuda()
torch_model.params = list(torch_model.parameters())
torch_model.eval()
print("Done")    

Done


## Datasets

<b>In-distribution Datasets</b>

In [4]:
batch_size = 256
# mean = np.array([[125.3/255, 123.0/255, 113.9/255]]).T

# std = np.array([[63.0/255, 62.1/255.0, 66.7/255.0]]).T
# normalize = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))

normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
        
    ])
transform_test = transforms.Compose([
        transforms.CenterCrop(size=(32, 32)),
        transforms.ToTensor(),
        normalize
    ])

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('~/datasets/cifarpy', train=True, download=True,
                   transform=transform_train),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('~/datasets/cifarpy', train=False, transform=transform_test),
    batch_size=batch_size)


detector_data_transform = transforms.Compose([transforms.ToTensor(), normalize])
data_train = list(torch.utils.data.DataLoader(
        datasets.CIFAR10('~/datasets/cifarpy', 
                     train=True, 
                     transform=detector_data_transform, 
                     download=True),
        batch_size=1, shuffle=False))

data_test = list(torch.utils.data.DataLoader(
        datasets.CIFAR10('~/datasets/cifarpy', 
                     train=False, 
                     transform=detector_data_transform, 
                     download=True),
        batch_size=1, shuffle=False))

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [5]:
def pipeline_batch(bxs):
    pil = transforms.ToPILImage()
    return torch.squeeze(torch.stack([transform_test(pil(bx)) for bx in bxs]), dim=1)

def get_batches(d, batch_size=32):
    bx = []
    by = []
    tens = transforms.ToTensor()
    for idx in range(0,len(d),batch_size):
        bx_batch = torch.squeeze(torch.stack([tens(x[0]) for x in d[idx:idx+batch_size]]),dim=1)
        bx.append(bx_batch)
        by.append(torch.Tensor([x[1] for x in d[idx:idx+batch_size]]).type(torch.LongTensor))
    
    return bx, by

def advs_p(p, bxs, bys, nrof_batches=None):
    if nrof_batches is None:
        nrof_batches = len(bxs)
        
    advs = []
    for i in tqdm(range(len(bxs))):
        if i >= nrof_batches:
            break
        
        _, feats_reg = torch_model.gram_forward((bxs[i]*2 - 1).cuda())
        advs_batch = p(torch_model, bxs[i].cuda(), bys[i].cuda())

        advs.append(advs_batch)

    torch.cuda.empty_cache()
    
    return advs

def adversarial_acc(advs, bys):
    torch_model.eval()
    correct = 0
    total = 0

    for i in range(len(advs)):
        pipelined = pipeline_batch(advs[i].cpu())

        x = pipelined.cuda()
        y = bys[i].numpy()

        correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()
        total += y.shape[0]


    print("Adversarial Test Accuracy: ", correct/total)
    
def ds_grouped(bxs, bys):
    ds = []
    for i in range(len(bxs)):
        pipelined = pipeline_batch(bxs[i].cpu())
        for j in range(len(bxs[i])):
            ds.append((pipelined[j], bys[i][j]))
    return ds

def adversarial_scores(detector, advs_batches, pbar = lambda x, total=None: x):
    auroc = []
    for batch in pbar(advs_batches):
        auroc.append(detector.compute_ood_deviations_batch(batch*2 - 1)["AUROC"])
    
    return np.mean(auroc)

    
    
def model_accuracy():
    torch_model.eval()
    correct = 0
    total = 0
    for x,y in test_loader:
        x = x.cuda()
        y = y.numpy()
        correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()
        total += y.shape[0]
        
    return correct/total

<center><h1> Results </h1></center>

In [25]:
model_accuracy()

0.9462

In [6]:
detector = Detector(torch_model, data_train, data_test, 512, pbar=None)

In [27]:
adversary = attacks.PGD(epsilon=8./255, num_steps=10, step_size=2./255).cuda()

In [28]:
cifar10 = list(datasets.CIFAR10('~/datasets/cifarpy', train=False))

print("Calculating L_Inf")
xs, ys = get_batches(cifar10, batch_size=128)
# pinf = PGD()
pinf = adversary
# pinf = PGD_margin().cuda()
advs_inf = advs_p(pinf, xs, ys)

adversarial_acc(advs_inf, ys)

adversarial_scores(detector, advs_inf, pbar=tqdm)

Calculating L_Inf


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Adversarial Test Accuracy:  0.0


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




0.9675182555379747

In [39]:
ys[0].shape

torch.Size([128])

In [42]:
with torch.no_grad():
    adv_logits, adv_feats = torch_model.gram_forward(advs_inf[0] * 2 -1)

In [63]:
compute_ood_deviations_advs(detector, adv_logits, adv_feats, ys[0])

(0.9640671874999998, 0)

In [None]:
def advs_stats(detector, advs, ys):
    preds = np.argmax(torch_model(x))

In [36]:
cifar10 = list(datasets.CIFAR10('~/datasets/cifarpy', train=False))
random.shuffle(cifar10)
xs, ys = get_batches(cifar10, batch_size=128)

In [37]:
def gram_matrix(layer):
    b, ch, h, w = layer.size()
    features = layer.view(b, ch, w * h)
    gram = torch.matmul(features, features.transpose(1, 2))
    
    return gram /(ch * h * w)

def style_loss(lhs, rhs):
    loss = 0.0
    for i in range(len(lhs)):
        loss += (gram_matrix(lhs[i]) - gram_matrix(rhs[i])).pow(2).sum()
    
    return loss.mean()

def calc_vals(x, y):
    x, y = x.cuda(), y.cuda()
    
    logits_reg, feats_reg = torch_model.gram_forward(x*2 - 1)
    
#     adv_x = attacker_smart(torch_model, x, y, feats_reg)
    adv_x = attacker_smart(torch_model, x, y)
#     adv_x = attacker_naive(torch_model, x, y)
    logits_adv, feats_adv = torch_model.gram_forward(adv_x * 2 - 1)
    
    x, y = x.cpu(), y.cpu()
    adv_x = adv_x.cpu()
    logits_adv = logits_adv.cpu()
    
    acc = (y==torch.max(logits_adv,dim=1)[1]).numpy().mean()
    auroc, auroc_failed = detector.compute_auroc_advs(logits_adv, feats_adv, y)
    
    return feats_reg, feats_adv, auroc, auroc_failed

In [38]:
def process_batch(x, y, margin=20, smart=True):
    feats_reg, feats_adv, auroc, auroc_failed = calc_vals(x, y)
    
#     return style_loss(feats_reg, feats_adv).cpu()
    return auroc, auroc_failed

In [45]:
with torch.no_grad():
    attacker_smart = PGD_Gram(gram_target=calc_gram_dev_target(), 
                              num_steps=10, 
                              epsilon=8./255, 
                              step_size=2./255, 
                              verbose=False)
    
#     attacker_smart = PGD_margin(style_weight = 0.0,
#                                 epsilon=8./255, 
#                                 num_steps=10, 
#                                 step_size=2/255, 
#                                 verbose=True)
#     attacker_naive = attacks.PGD(epsilon=8./255, num_steps=10, step_size=2./255)
    
    auroc, auroc_failed = [], []
    for i, x in tqdm(enumerate(xs), total=len(xs)):
        if i % 3 != 0:
            continue
        a, a_f = process_batch(x, ys[i])
        auroc.append(a)
        auroc_failed.append(a_f)

HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




In [46]:
np.array(auroc).mean(), np.array(auroc_failed).mean()

(0.689618254289906, 0.5497135921075128)

Gram Loss Detector:

- Auroc: `(0.689618254289906, 0.5497135921075128)`

In [7]:
powers=[1]
def cpu(ob):
    for i in range(len(ob)):
        for j in range(len(ob[i])):
            ob[i][j] = ob[i][j].cpu()
    return ob
    
def cuda(ob):
    for i in range(len(ob)):
        for j in range(len(ob[i])):
            ob[i][j] = ob[i][j].cuda()
    return ob
def calc_gram_dev_target():
    return detector.all_test_deviations.mean(axis=0).sum() 

def G_p_gpu(temp, p):
    temp = temp.reshape(temp.shape[0],temp.shape[1],-1)
    temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2)
    return temp.reshape(temp.shape[0],-1)

# def G_p_gpu(ob, p):
#     temp = ob
    
#     temp = temp**p
#     temp = temp.reshape(temp.shape[0],temp.shape[1],-1)
#     temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) 
#     temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)
    
#     return temp

class PGD_Gram(nn.Module):
    def __init__(self, epsilon=8/255, num_steps=10, step_size=2/255, grad_sign=True, 
                         mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5], nrof_classes=10, gram_target = 247, verbose=True):
        super().__init__()
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.step_size = step_size
        self.grad_sign = grad_sign
        
        if mean is None:
            self.mean = torch.FloatTensor([0.4914, 0.4822, 0.4465]).view(1,3,1,1).cuda()
        else:
            self.mean = torch.FloatTensor(mean).view(1,3,1,1).cuda()
        if std is None:
            self.std = torch.FloatTensor([0.2023, 0.1994, 0.2010]).view(1,3,1,1).cuda()
        else:
            self.std = torch.FloatTensor(std).view(1,3,1,1).cuda()
            
        self.mns = [cuda(detector.mins[i]) for i in range(nrof_classes)]
        self.mxs = [cuda(detector.maxs[i]) for i in range(nrof_classes)]
        self.gram_target = gram_target * 0.85
        self.verbose = verbose
            
    def get_deviation(self, feat_list, idx, mins, maxs, power=powers):
        batch_deviations = []
        for L,feat_L in enumerate(feat_list):
            dev = 0
            for p,P in enumerate(power):
                g_p = G_p_gpu(feat_L,P)[idx]
                
                dev +=  (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)
                dev +=  (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)
                
                batch_deviations.append(dev)
                
        return batch_deviations
        
    def gram_loss(self, feats, logits):
        confs = F.softmax(logits, dim=1)
        _, indices = torch.max(confs, 1)
        
        loss = 0
        for i in range(10):
            idxs = indices == i

            if idxs.sum() == 0:
                continue
            
            batch_dev = self.get_deviation(feats, idxs, mins=self.mns[i], maxs=self.mxs[i])
            batch_dev = torch.squeeze(torch.stack(batch_dev, dim=1))
            
            loss += batch_dev.sum()
                            
        return F.relu((loss/logits.shape[0]) - self.gram_target)
    
    def forward(self, model, bx, by):
        """
        :param model: the classifier's forward method
        :param bx: batch of images
        :param by: true labels
        :return: perturbed batch of images
        """
        model.eval()
        
        adv_bx = bx.detach()
        adv_bx += torch.zeros_like(adv_bx).uniform_(-self.epsilon, self.epsilon)

        for i in range(self.num_steps):
            adv_bx.requires_grad_()
            with torch.enable_grad():
                logits, feats = model.gram_forward((adv_bx - self.mean)/self.std)
                
                cent_loss = F.cross_entropy(logits, by, reduction='mean')
                gram_loss =  self.gram_loss(feats, logits)
                
                loss = cent_loss - gram_loss
                                
            if self.verbose:
                print("Step: {}, Cent: {}, Gram: {}, Total Loss: {}".format(i, cent_loss, gram_loss, loss))
            
            grad = torch.autograd.grad(loss, adv_bx, only_inputs=True)[0]
            adv_bx = adv_bx.detach() + self.step_size * torch.sign(grad.detach())
            adv_bx = torch.min(torch.max(adv_bx, bx - self.epsilon), bx + self.epsilon).clamp(0, 1)

        return adv_bx

In [7]:
def G_p(temp):
    temp = temp.reshape(temp.shape[0],temp.shape[1],-1)
    temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2)
    return temp.reshape(temp.shape[0],-1)

class PGD_margin(nn.Module):
    def __init__(self, epsilon=8./255, num_steps=10, step_size=2./255, style_weight = 100, grad_sign=True, verbose=False):
        super().__init__()
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.step_size = step_size
        self.grad_sign = grad_sign
        self.verbose = verbose
        self.style_weight = style_weight

    def forward(self, model, bx, by, feats_reg):
        """
        :param model: the classifier's forward method
        :param bx: batch of images
        :param by: true labels
        :return: perturbed batch of images
        """
        adv_bx = bx.detach()
        adv_bx += torch.zeros_like(adv_bx).uniform_(-self.epsilon, self.epsilon)
        
        for i in range(self.num_steps):
            adv_bx.requires_grad_()
            with torch.enable_grad():
                logits, feats_adv = model.gram_forward(adv_bx * 2 - 1)
                s_loss = style_loss(feats_reg, feats_adv)
                cent_loss = F.cross_entropy(logits, by, reduction='mean')
                
                loss = cent_loss - self.style_weight * s_loss
                
                if self.verbose:
                    print("Step: {}, Cent: {}, Style Loss: {}, Total Loss: {}".format(i, cent_loss, s_loss, loss))
            grad = torch.autograd.grad(loss, adv_bx, only_inputs=True)[0]
            adv_bx = adv_bx.detach() + self.step_size * torch.sign(grad.detach())
            adv_bx = torch.min(torch.max(adv_bx, bx - self.epsilon), bx + self.epsilon).clamp(0, 1)
            
        return adv_bx

In [7]:
# def advs_stats(detector, advs_logits, advs_feats, ys):    

def select_features(feat_list, idxs):
    return [f[idxs] for f in feat_list]
    
def get_deviations(feat_list, mins,maxs):
    if len(feat_list[0]) == 0:
        return np.array([])
    deviations = []
    for L,feat_L in enumerate(feat_list):
        
        g_p = G_p(feat_L)

        dev =  (F.relu(mins[L][0]-g_p)/torch.abs(mins[L][0]+10**-6)).sum(dim=1,keepdim=True)
        dev +=  (F.relu(g_p-maxs[L][0])/torch.abs(maxs[L][0]+10**-6)).sum(dim=1,keepdim=True)

        deviations.append(dev.cpu().numpy())
            
    deviations = np.concatenate(deviations, axis=1)
    
    return deviations
    
def compute_ood_deviations_advs(self, adv_logits, adv_feats, adv_ys):
    confs = F.softmax(adv_logits,dim=1).cpu().numpy()
    preds = np.argmax(confs,axis=1)

    adv_deviations = None
    failed_adv_deviations = None

    for PRED in self.classes:
        idxs = np.where(np.array(preds)==PRED)[0]
        class_ys = np.array(adv_ys)[idxs]
        idxs_failed = np.where(class_ys==PRED)[0]
        
        if len(idxs)==0:
            continue

        mins = self.mins[PRED]
        maxs = self.maxs[PRED]
        
        adv_dev_class = get_deviations(select_features(adv_feats,idxs), mins=mins, maxs=maxs)
        failed_adv_dev_class = get_deviations(select_features(adv_feats,idxs_failed), mins=mins, maxs=maxs)

        if adv_deviations is None:
            adv_deviations = adv_dev_class
            failed_adv_deviations = failed_adv_dev_class
            
        else:
            adv_deviations = np.concatenate([adv_deviations, adv_dev_class],axis=0)
            failed_adv_deviations = np.concatenate([failed_adv_deviations, failed_adv_dev_class], axis=0)
            
    failed_results, results = 0,0
    if len(failed_adv_deviations) != 0:
        failed_results = detect(self.all_test_deviations, failed_adv_deviations)["AUROC"]
        
    if len(adv_deviations) != 0:
        results = detect(self.all_test_deviations, adv_deviations)["AUROC"]
    
    return results, failed_results

import utils.calculate_log as callog

def calc_auroc(all_test_deviations,all_ood_deviations):
    average_results = {}

    test_deviations = all_test_deviations.sum(axis=1)
    ood_deviations = all_ood_deviations.sum(axis=1)

    results = callog.compute_metric(-test_deviations,-ood_deviations)

    return results["AUROC"]