<center><h1>WRN: Cifar10</h1></center>

## Imports

In [1]:
from __future__ import division,print_function

%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
from tqdm import tqdm_notebook as tqdm

import random
import matplotlib.pyplot as plt
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
from torch.autograd import Variable, grad
from torchvision import datasets, transforms
from torch.nn.parameter import Parameter

import calculate_log as callog

from pgd import PGD

import warnings
warnings.filterwarnings('ignore')

In [2]:
torch.cuda.set_device(3)

## Model definition

In [3]:
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes,track_running_stats=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes,track_running_stats=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                               padding=0, bias=False) or None
    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
            torch_model.record(x)
        else:
            out = self.relu1(self.bn1(x))
            torch_model.record(out)
        if self.equalInOut:
            out = self.conv1(out)
        else:
            out = self.conv1(x)
        torch_model.record(out)
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        t = self.relu2(self.bn2(out))
        torch_model.record(t)
        out = self.conv2(t)
        torch_model.record(out)
        if not self.equalInOut:
            return torch.add(self.convShortcut(x), out)
        else:
            return torch.add(x, out)

class NetworkBlock(nn.Module):
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(nb_layers):
            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
        return nn.Sequential(*layers)
    def forward(self, x):
        return self.layer(x)

class WideResNet(nn.Module):
    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert((depth - 4) % 6 == 0)
        self.collecting = False
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
    
    def forward(self, x):
        out = self.conv1(x)
        torch_model.record(out)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        lg = self.fc(out)
        return lg
    
    def gram_forward(self, x):
        self.collecting = True
        self.gram_feats = []
        logits = self.forward(x)
        self.collecting = False
        temp = self.gram_feats
        self.gram_feats = []
        return logits, temp
    
    def record(self, t):
        if self.collecting:
            self.gram_feats.append(t)
    
    def gram_feature_list(self,x):
        self.collecting = True
        self.gram_feats = []
        self.forward(x)
        self.collecting = False
        temp = self.gram_feats
        self.gram_feats = []
        return temp
    
#     def load(self, path="cifar10_wrn_oe_scratch_epoch_99.pt"):
    def load(self, path="model_training/checkpoints/cifar10_wrn_baseline_epoch_99.pt"):
        tm = torch.load(path,map_location="cpu")        
        self.load_state_dict(tm,strict=False)
    
    def get_min_max(self, data, power):
        mins = []
        maxs = []
        
        for i in range(0,len(data),64):
            batch = data[i:i+64].cuda()
            feat_list = self.gram_feature_list(batch)
            for L,feat_L in enumerate(feat_list):
                if L==len(mins):
                    mins.append([None]*len(power))
                    maxs.append([None]*len(power))
                
                for p,P in enumerate(power):
                    g_p = G_p(feat_L,P)
                    
                    current_min = g_p.min(dim=0,keepdim=True)[0]
                    current_max = g_p.max(dim=0,keepdim=True)[0]
                    
                    if mins[L][p] is None:
                        mins[L][p] = current_min
                        maxs[L][p] = current_max
                    else:
                        mins[L][p] = torch.min(current_min,mins[L][p])
                        maxs[L][p] = torch.max(current_max,maxs[L][p])
        
        return mins,maxs
    
    def get_deviations(self,data,power,mins,maxs):
        deviations = []
        
        for i in range(0,len(data),64):            
            batch = data[i:i+64].cuda()
            feat_list = self.gram_feature_list(batch)
            batch_deviations = []
            for L,feat_L in enumerate(feat_list):
                dev = 0
                for p,P in enumerate(power):
                    g_p = G_p(feat_L,P)
                    
                    dev +=  (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)
                    dev +=  (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)
                batch_deviations.append(dev.cpu().detach().numpy())
            batch_deviations = np.concatenate(batch_deviations,axis=1)
            deviations.append(batch_deviations)
        deviations = np.concatenate(deviations,axis=0)
        
        return deviations

torch_model = WideResNet(depth=40, widen_factor=2, num_classes=10)

torch_model.load()
torch_model.cuda()
torch_model.params = list(torch_model.parameters())
torch_model.eval()
print("Done")    

Done


## Datasets

<b>In-distribution Datasets</b>

In [4]:
batch_size = 128
mean = np.array([[125.3/255, 123.0/255, 113.9/255]]).T

std = np.array([[63.0/255, 62.1/255.0, 66.7/255.0]]).T
normalize = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))

# normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
        
    ])
transform_test = transforms.Compose([
        transforms.CenterCrop(size=(32, 32)),
        transforms.ToTensor(),
        normalize
    ])

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('data', train=True, download=True,
                   transform=transform_train),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('data', train=False, transform=transform_test),
    batch_size=batch_size)


Files already downloaded and verified


In [5]:
data_train = list(torch.utils.data.DataLoader(
        datasets.CIFAR10('data', train=True, download=True,
                       transform=transform_test),
        batch_size=1, shuffle=False))

Files already downloaded and verified


In [6]:
data = list(torch.utils.data.DataLoader(
    datasets.CIFAR10('data', train=False, download=True,
                   transform=transform_test),
    batch_size=1, shuffle=False))

Files already downloaded and verified


## Code for Detecting OODs

<b> Extract predictions for train and test data </b>

In [7]:
train_preds = []
train_confs = []
train_logits = []
for idx in range(0,len(data_train),128):
    batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()
    
    logits = torch_model(batch)
    confs = F.softmax(logits,dim=1).cpu().detach().numpy()
    preds = np.argmax(confs,axis=1)
    logits = (logits.cpu().detach().numpy())#**2)#.sum(axis=1)

    train_confs.extend(np.max(confs,axis=1))    
    train_preds.extend(preds)
    train_logits.extend(logits)
print("Done")

test_preds = []
test_confs = []
test_logits = []

for idx in range(0,len(data),128):
    batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()
    
    logits = torch_model(batch)
    confs = F.softmax(logits,dim=1).cpu().detach().numpy()
    preds = np.argmax(confs,axis=1)
    logits = (logits.cpu().detach().numpy())#**2)#.sum(axis=1)

    test_confs.extend(np.max(confs,axis=1))    
    test_preds.extend(preds)
    test_logits.extend(logits)
print("Done")

Done
Done


<b> Code for detecting OODs by identifying anomalies in correlations </b>

In [8]:
import calculate_log as callog

def detect(all_test_deviations,all_ood_deviations, test_confs = None, ood_confs=None, verbose=True, normalize=False):
    if test_confs is not None:
        test_confs = np.array(test_confs)
        ood_confs = np.array(ood_confs)
    
    average_results = {}
    for i in range(1,11):
        random.seed(i)
        
        validation_indices = random.sample(range(len(all_test_deviations)),int(0.1*len(all_test_deviations)))
        test_indices = sorted(list(set(range(len(all_test_deviations)))-set(validation_indices)))

        validation = all_test_deviations[validation_indices]
        test_deviations = all_test_deviations[test_indices]

        t95 = validation.mean(axis=0)+10**-7
        if not normalize:
            t95 = np.ones_like(t95)
        test_deviations = (test_deviations/t95[np.newaxis,:]).sum(axis=1)
        ood_deviations = (all_ood_deviations/t95[np.newaxis,:]).sum(axis=1)

        if test_confs is not None:
            thresh = np.max((validation/t95[np.newaxis,:]).sum(axis=1))
                        
            ood_deviations = ood_deviations - thresh*ood_confs
            test_deviations = test_deviations - thresh*test_confs[test_indices]
        
        results = callog.compute_metric(-test_deviations,-ood_deviations)
        for m in results:
            average_results[m] = average_results.get(m,0)+results[m]
    
    for m in average_results:
        average_results[m] /= i
    if verbose:
        callog.print_results(average_results)
    return average_results

def cpu(ob):
    for i in range(len(ob)):
        for j in range(len(ob[i])):
            ob[i][j] = ob[i][j].cpu()
    return ob
    
def cuda(ob):
    for i in range(len(ob)):
        for j in range(len(ob[i])):
            ob[i][j] = ob[i][j].cuda()
    return ob

class Detector:
    def __init__(self):
        self.all_test_deviations = None
        self.mins = {}
        self.maxs = {}
        
        self.classes = range(10)
    
    def compute_minmaxs(self,data_train,POWERS=[10]):
        for PRED in tqdm(self.classes):
            train_indices = np.where(np.array(train_preds)==PRED)[0]
            train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)
            mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)
            self.mins[PRED] = cpu(mins)
            self.maxs[PRED] = cpu(maxs)
            torch.cuda.empty_cache()
    
    def compute_test_deviations(self,POWERS=[10]):
        all_test_deviations = None
        all_test_deviations_msp = None
        all_test_confs = []
        for PRED in tqdm(self.classes):
            test_indices = np.where(np.array(test_preds)==PRED)[0]
            test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)
            test_confs_PRED = np.array([test_confs[i] for i in test_indices])
            all_test_confs.extend(test_confs_PRED)
            mins = cuda(self.mins[PRED])
            maxs = cuda(self.maxs[PRED])
            test_deviations = torch_model.get_deviations(test_PRED,power=POWERS,mins=mins,maxs=maxs)
            test_deviations_MSP = test_deviations/test_confs_PRED[:,np.newaxis]
            cpu(mins)
            cpu(maxs)
            if all_test_deviations is None:
                all_test_deviations = test_deviations
                all_test_deviations_MSP = test_deviations_MSP
            else:
                all_test_deviations = np.concatenate([all_test_deviations,test_deviations],axis=0)
                all_test_deviations_MSP = np.concatenate([all_test_deviations_MSP,test_deviations_MSP],axis=0)
            torch.cuda.empty_cache()
        self.all_test_confs = all_test_confs
        self.all_test_deviations = all_test_deviations
        self.all_test_deviations_MSP = all_test_deviations_MSP
    
    def compute_ood_deviations(self,ood,POWERS=[10],msp=False):
        ood_preds = []
        ood_confs = []
        
        for idx in range(0,len(ood),128):
            batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()
            logits = torch_model(batch)
            confs = F.softmax(logits,dim=1).cpu().detach().numpy()
            preds = np.argmax(confs,axis=1)
            
            ood_confs.extend(np.max(confs,axis=1))
            ood_preds.extend(preds)  
            torch.cuda.empty_cache()
#         print("MSP")
#         callog.print_results(callog.compute_metric(np.array(test_confs),np.array(ood_confs)))
        
        all_ood_deviations = None
        all_ood_deviations_MSP = None
        all_ood_confs = []
        for PRED in self.classes:
            ood_indices = np.where(np.array(ood_preds)==PRED)[0]
            if len(ood_indices)==0:
                continue
            ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)
            
            ood_confs_PRED =  np.array([ood_confs[i] for i in ood_indices])
            
            all_ood_confs.extend(ood_confs_PRED)
            
            mins = cuda(self.mins[PRED])
            maxs = cuda(self.maxs[PRED])
            ood_deviations = torch_model.get_deviations(ood_PRED,power=POWERS,mins=mins,maxs=maxs)
            ood_deviations_MSP = ood_deviations/ood_confs_PRED[:,np.newaxis]
            cpu(self.mins[PRED])
            cpu(self.maxs[PRED])            
            if all_ood_deviations is None:
                all_ood_deviations = ood_deviations
                all_ood_deviations_MSP = ood_deviations_MSP
            else:
                all_ood_deviations = np.concatenate([all_ood_deviations,ood_deviations],axis=0)
                all_ood_deviations_MSP = np.concatenate([all_ood_deviations_MSP,ood_deviations_MSP],axis=0)
            torch.cuda.empty_cache()
        
        self.all_ood_confs = all_ood_confs
        
#         print("Ours")
        average_results = detect(self.all_test_deviations,all_ood_deviations)
#         print("Ours+MSP")
#         average_results = detect(self.all_test_deviations,all_ood_deviations,self.all_test_confs,self.all_ood_confs)
        return average_results, self.all_test_deviations, all_ood_deviations

In [9]:
def pipeline_batch(bxs):
    pil = transforms.ToPILImage()
    return torch.squeeze(torch.stack([transform_test(pil(bx)) for bx in bxs]), dim=1)

def get_b(d):
    batch_size = 32
    bx = []
    by = []
    tens = transforms.ToTensor()
    for idx in range(0,len(d),batch_size):
        bx_batch = torch.squeeze(torch.stack([tens(x[0]) for x in d[idx:idx+batch_size]]),dim=1)
        bx.append(bx_batch)
        by.append(torch.Tensor([x[1] for x in d[idx:idx+batch_size]]).type(torch.LongTensor))
    
    return bx, by

def advs_p(p, bxs, bys, nrof_batches=None):
    if nrof_batches is None:
        nrof_batches = len(bxs)
        
    advs = []
    for i in tqdm(range(len(bxs))):
        if i >= nrof_batches:
            break
            
        advs_batch = p(torch_model, bxs[i].cuda(), bys[i].cuda())

        advs.append(advs_batch)

    torch.cuda.empty_cache()
    
    return advs

def adversarial_acc(advs, bys):
    torch_model.eval()
    correct = 0
    total = 0

    for i in range(len(advs)):
        pipelined = pipeline_batch(advs[i].cpu())

        x = pipelined.cuda()
        y = bys[i].numpy()

        correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()
        total += y.shape[0]


    print("Adversarial Test Accuracy: ", correct/total)
    
def ds_grouped(bxs, bys):
    ds = []
    for i in range(len(bxs)):
        pipelined = pipeline_batch(bxs[i].cpu())
        for j in range(len(bxs[i])):
            ds.append((pipelined[j], bys[i][j]))
    return ds

def adversarial_scores(advs, bys, powers, folder=""):
    ds = ds_grouped(advs, bys)
    tds = list(torch.utils.data.DataLoader(ds,batch_size=1,shuffle=True))
    _ = detector.compute_ood_deviations(tds, POWERS=powers)
    
    
def model_accuracy():
    torch_model.eval()
    correct = 0
    total = 0
    for x,y in test_loader:
        x = x.cuda()
        y = y.numpy()
        correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()
        total += y.shape[0]
        
    return correct/total

<center><h1> Results </h1></center>

In [10]:
powers = [1]
def G_p(ob, p):
    temp = ob.detach()
    
    temp = temp.reshape(temp.shape[0],temp.shape[1],-1)
    temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2)
    temp = temp.reshape(temp.shape[0],-1)
    
    return temp

detector = Detector()
detector.compute_minmaxs(data_train,POWERS=powers)

detector.compute_test_deviations(POWERS=powers)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




In [12]:
cifar10 = list(datasets.CIFAR10('data', train=False))

print("Calculating L_Inf")
xs, ys = get_b(cifar10)
pinf = PGD()
advs_inf = advs_p(pinf, xs, ys)

adversarial_acc(advs_inf, ys)

adversarial_scores(advs_inf, ys, powers=powers)

Calculating L_Inf


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Adversarial Test Accuracy:  0.0014
 TNR    AUROC  DTACC  AUIN   AUOUT 
 56.273 75.262 77.094 58.851 84.621


In [15]:
print("Benchmark For a Cifar10 WideResNet Trained With OE\n")

print("Model Accuracy On Test Set:", model_accuracy())
adversarial_acc(advs_inf, ys)
print("Detection Benchmark:")
adversarial_scores(advs_inf, ys, powers=powers)

print("\nAverage Gram Deviations For Test Set: {}\n".format(calc_gram_dev_target()))


Benchmark For a Cifar10 WideResNet Trained With OE

Model Accuracy On Test Set: 0.5494
Adversarial Test Accuracy:  0.0014
Detection Benchmark:
 TNR    AUROC  DTACC  AUIN   AUOUT 
 56.273 75.262 77.094 58.851 84.621

Average Gram Deviations For Test Set: 0.6627426147460938



In [26]:
print("–––– Create Undetectible Adversarial Attacks ––––")
print("Epsilon: 8/255, Num Steps: 10, Step Size: 2/255")

p_gram = PGD_Gram(gram_target=calc_gram_dev_target(), verbose=True)
advs_gram = advs_p(p_gram, xs, ys, nrof_batches = 1)
adversarial_acc(advs_gram, ys)
adversarial_scores(advs_gram, ys, powers)

–––– Create Undetectible Adversarial Attacks ––––
Epsilon: 8/255, Num Steps: 10, Step Size: 2/255


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))

Step: 0, Cent: 0.14773714542388916, Gram: 73.73921203613281, Total Loss: -73.59147644042969
Step: 1, Cent: 0.6315072178840637, Gram: 1.5961074829101562, Total Loss: -0.9646002650260925
Step: 2, Cent: 1.6839289665222168, Gram: 16.002609252929688, Total Loss: -14.318679809570312
Step: 3, Cent: 2.3709537982940674, Gram: 60.39494323730469, Total Loss: -58.023990631103516
Step: 4, Cent: 4.563270568847656, Gram: 0.0, Total Loss: 4.563270568847656
Step: 5, Cent: 6.832106590270996, Gram: 124.79806518554688, Total Loss: -117.96595764160156
Step: 6, Cent: 5.832967281341553, Gram: 0.0, Total Loss: 5.832967281341553
Step: 7, Cent: 8.62597942352295, Gram: 36.21221923828125, Total Loss: -27.586238861083984
Step: 8, Cent: 7.858233451843262, Gram: 0.0, Total Loss: 7.858233451843262
Step: 9, Cent: 10.103129386901855, Gram: 9.694610595703125, Total Loss: 0.40851879119873047
Adversarial Test Accuracy:  0.03125
 TNR    AUROC  DTACC  AUIN   AUOUT 
  3.125 32.498 53.872 98.666  0.325


In [106]:
adversarial_acc(advs_gram, ys)
adversarial_scores(advs_gram, ys, powers)

Adversarial Test Accuracy:  0.0805
 TNR    AUROC  DTACC  AUIN   AUOUT 
 10.962 49.469 59.682 43.090 58.825


In [14]:
def calc_gram_dev_target():
    return detector.all_test_deviations.mean(axis=0).sum() 

def G_p_gpu(ob, p):
    temp = ob
    
    temp = temp**p
    temp = temp.reshape(temp.shape[0],temp.shape[1],-1)
    temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2) 
    temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)
    
    return temp

class PGD_Gram(nn.Module):
    def __init__(self, epsilon=8/255, num_steps=10, step_size=2/255, grad_sign=True, 
                         mean = None, std = None, nrof_classes=10, gram_target = 247, verbose=True):
        super().__init__()
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.step_size = step_size
        self.grad_sign = grad_sign
        
        if mean is None:
            self.mean = torch.FloatTensor([0.4914, 0.4822, 0.4465]).view(1,3,1,1).cuda()
        else:
            self.mean = torch.FloatTensor(mean).view(1,3,1,1).cuda()
        if std is None:
            self.std = torch.FloatTensor([0.2023, 0.1994, 0.2010]).view(1,3,1,1).cuda()
        else:
            self.std = torch.FloatTensor(std).view(1,3,1,1).cuda()
            
        self.mns = [cuda(detector.mins[i]) for i in range(nrof_classes)]
        self.mxs = [cuda(detector.maxs[i]) for i in range(nrof_classes)]
        self.gram_target = gram_target * 0.85
        self.verbose = verbose
            
    def get_deviation(self, feat_list, idx, mins, maxs, power=powers):
        batch_deviations = []
        for L,feat_L in enumerate(feat_list):
            dev = 0
            for p,P in enumerate(power):
                g_p = G_p_gpu(feat_L,P)[idx]
                
                dev +=  (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)
                dev +=  (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)
                
                batch_deviations.append(dev)
                
        return batch_deviations
        
    def gram_loss(self, feats, logits):
        confs = F.softmax(logits, dim=1)
        _, indices = torch.max(confs, 1)
        
        loss = 0
        for i in range(10):
            idxs = indices == i

            if idxs.sum() == 0:
                continue
            
            batch_dev = self.get_deviation(feats, idxs, mins=self.mns[i], maxs=self.mxs[i])
            batch_dev = torch.squeeze(torch.stack(batch_dev, dim=1))
            
            loss += batch_dev.sum()
                
        return F.relu(loss - logits.shape[0] * self.gram_target)
    
    def forward(self, model, bx, by):
        """
        :param model: the classifier's forward method
        :param bx: batch of images
        :param by: true labels
        :return: perturbed batch of images
        """
        model.eval()
        
        adv_bx = bx.detach()
        adv_bx += torch.zeros_like(adv_bx).uniform_(-self.epsilon, self.epsilon)

        for i in range(self.num_steps):
            adv_bx.requires_grad_()
            with torch.enable_grad():
                logits, feats = model.gram_forward((adv_bx - self.mean)/self.std)
                
                cent_loss = F.cross_entropy(logits, by, reduction='mean')
                gram_loss = self.gram_loss(feats, logits)
                
                loss = cent_loss - gram_loss
                
            if self.verbose:
                print("Step: {}, Cent: {}, Gram: {}, Total Loss: {}".format(i, cent_loss, gram_loss, loss))
            
            grad = torch.autograd.grad(loss, adv_bx, only_inputs=True)[0]
            adv_bx = adv_bx.detach() + self.step_size * torch.sign(grad.detach())
            adv_bx = torch.min(torch.max(adv_bx, bx - self.epsilon), bx + self.epsilon).clamp(0, 1)

        return adv_bx