In [None]:
from google.colab import drive

# Colab-only: mounts the user's Google Drive at /content/gdrive.
# This triggers an interactive authorisation prompt on a fresh runtime.
drive.mount('/content/gdrive')


Mounted at /content/gdrive


In [None]:
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
torch.manual_seed(2)
np.random.seed(2)
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler
import torchvision.models as models
!pip install torchinfo
!pip install torchattacks
!pip install pip install grad-cam
# !pip install geomloss[full]
# from geomloss import SamplesLoss
from torchinfo import summary
from torchvision.models.resnet import _resnet,BasicBlock
import torchattacks
import torchvision.utils
import torch.nn.functional as F
from pytorch_grad_cam import GradCAM

Collecting torchinfo
  Downloading torchinfo-1.5.4-py3-none-any.whl (19 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.5.4
Collecting torchattacks
  Downloading torchattacks-3.2.2-py3-none-any.whl (102 kB)
[K     |████████████████████████████████| 102 kB 5.2 MB/s 
[?25hInstalling collected packages: torchattacks
Successfully installed torchattacks-3.2.2
Collecting install
  Downloading install-1.3.4-py3-none-any.whl (3.1 kB)
Collecting grad-cam
  Downloading grad-cam-1.3.5.tar.gz (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 6.5 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting ttach>=0.0.3
  Downloading ttach-0.0.3-py3-none-any.whl (9.8 kB)
Building wheels for collected packages: grad-cam
  Building wheel for grad-cam (PEP 517) ... [?25l[?25hdone
  Created wheel for grad-cam: filename=grad_cam-1.3.5-py3

In [None]:
# Images are only converted to [0, 1] tensors here; per-channel normalisation
# is applied later, right before the model is called.
transform = transforms.Compose([
    # transforms.Resize((224)),
    transforms.ToTensor(),
])

batch_size = 64
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2,pin_memory=True)

# The 10k-image CIFAR-10 test split is divided into two DISJOINT halves:
# indices 0..4999 for validation and 5000..9999 for testing.
# A SequentialSampler built from a sliced array only yields 0..len(slice)-1,
# so it cannot be used to select an index range; Subset is used instead.
valset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(valset, range(0, 5000)),
    batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(testset, range(5000, 10000)),
    batch_size=64, shuffle=False, num_workers=2, pin_memory=True)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


In [None]:
from torch.nn import Conv2d,AvgPool2d,Linear,Sequential,Dropout,BatchNorm2d,ModuleList,BatchNorm1d
import torch.nn.functional as F
import numpy as np
import math
from torch.autograd import Variable

##Source: https://blog.paperspace.com/attention-mechanisms-in-computer-vision-cbam/
class BasicConv(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and ReLU.

    `bn` and `relu` switch the respective stages on or off; a disabled stage
    is stored as None and skipped in `forward`.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        """Apply the convolution, then whichever of BN / ReLU are enabled."""
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        leading = x.shape[0]
        return x.view(leading, -1)

class ChannelGate(nn.Module):
    """Channel attention: gate each channel by a sigmoid of pooled statistics.

    The input is pooled over its full spatial extent (average and/or max),
    each pooled vector goes through a shared two-layer MLP, the MLP outputs
    are summed, and the sigmoid of that sum scales the input channel-wise.

    Args:
        gate_channels: number of channels of the input feature map.
        reduction_ratio: the MLP hidden width is gate_channels // reduction_ratio.
        pool_types: iterable containing 'avg' and/or 'max'.

    Raises:
        ValueError: if `pool_types` is empty or contains an unsupported name.
    """

    _SUPPORTED_POOLS = ('avg', 'max')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super(ChannelGate, self).__init__()
        pool_types = list(pool_types)  # private copy; also accepts tuples
        if not pool_types:
            raise ValueError("pool_types must contain at least one of 'avg', 'max'")
        unknown = [p for p in pool_types if p not in self._SUPPORTED_POOLS]
        if unknown:
            # Without this check an unknown entry would silently re-add the
            # previous branch's logits (or fail on an unbound local).
            raise ValueError("unsupported pool type(s): {}".format(unknown))

        self.gate_channels = gate_channels
        # Index layout (0: flatten, 1: linear, 2: relu, 3: linear) is kept so
        # parameter names in state_dict stay 'mlp.1.*' and 'mlp.3.*'.
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types

    def forward(self, x):
        """Return `x` scaled per channel; output has the same shape as `x`."""
        window = (x.size(2), x.size(3))  # pool over the whole spatial map
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = F.avg_pool2d(x, window, stride=window)
            else:  # 'max' -- validated in __init__
                pooled = F.max_pool2d(x, window, stride=window)
            channel_att_raw = self.mlp(pooled)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    """Log-sum-exp over all dimensions after the first two.

    The trailing dimensions are flattened into one and the per-row maximum is
    subtracted before exponentiating. The result keeps a trailing singleton
    dimension, i.e. it has shape (size(0), size(1), 1).
    """
    flat = tensor.view(tensor.size(0), tensor.size(1), -1)
    peak = torch.max(flat, dim=2, keepdim=True)[0]
    shifted_sum = torch.exp(flat - peak).sum(dim=2, keepdim=True)
    return peak + torch.log(shifted_sum)

class ChannelPool(nn.Module):
    """Reduce dim 1 to two planes: its maximum (index 0) and its mean (index 1)."""

    def forward(self, x):
        strongest = x.max(dim=1, keepdim=True)[0]
        average = x.mean(dim=1, keepdim=True)
        return torch.cat((strongest, average), dim=1)

class SpatialGate(nn.Module):
    """Spatial attention: gate every location by a sigmoid attention map.

    The input is compressed with ChannelPool to two planes, a 7x7 BasicConv
    (batch norm, no ReLU) maps them to one plane, and the sigmoid of that
    plane multiplies the input.
    """

    def __init__(self):
        super().__init__()
        ksize = 7
        same_pad = (ksize - 1) // 2  # keeps the spatial size unchanged at stride 1
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, ksize, stride=1, padding=same_pad, relu=False)

    def forward(self, x):
        attention = torch.sigmoid(self.spatial(self.compress(x)))
        # attention has a single channel and broadcasts over x's channels
        return x * attention

class CBAM(nn.Module):
    """Channel gate followed by an optional spatial gate.

    Args:
        gate_channels, reduction_ratio, pool_types: forwarded to ChannelGate.
        no_spatial: when True only the channel gate is built and applied.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super().__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        refined = self.ChannelGate(x)
        if self.no_spatial:
            return refined
        return self.SpatialGate(refined)

class Base(nn.Module):
    """ResNet backbone that exposes the outputs of layer1..layer4.

    Forward hooks on the four residual stages append their outputs to
    `self.features` each time `self.base_model` is called. `avgpool` and `fc`
    are replaced by `nn.Identity`. `channel_size` holds the channel count of
    each hooked stage, found by a one-off probe pass at construction time.

    Args:
        trainable: when False the backbone parameters are frozen.
        attention: unused; kept for call compatibility.
        base: 9 selects a [1, 1, 1, 1] BasicBlock ResNet, 18 selects
            resnet18, and any other value selects resnet34.
    """

    def freeze(self):
        """Disable gradients for every backbone parameter."""
        for param in self.base_model.parameters():
                param.requires_grad = False

    def unfreeze(self):
        """Enable gradients for every backbone parameter."""
        for param in self.base_model.parameters():
                param.requires_grad = True

    def attach_fea_out(self,classname,input,output):
        """Forward hook: record the hooked module's output."""
        self.features.append(output)

    def attach_fea_in(self,classname,input,output):
        """Forward hook: record the hooked module's first input."""
        self.features.append(input[0])

    def __init__(self,trainable = True,attention=False,base=18):
        super(Base,self).__init__()
        self.features = []
        self.channel_size = []
        print(base)
        if base == 9:
            self.base_model = _resnet('resnet', BasicBlock, [1, 1, 1, 1], False, True)
        elif base==18:
            self.base_model = models.resnet18(pretrained=False)
        else:
            self.base_model = models.resnet34(pretrained=False)

        used_blocks = ['layer1', 'layer2','layer3','layer4']
        unused_blocks = ['avgpool','fc']

        for block in used_blocks:
            getattr(self.base_model,block).register_forward_hook(self.attach_fea_out)

        for block in unused_blocks:
             setattr(self.base_model,block,nn.Identity())

        if not trainable:
            self.freeze()

        # Probe with a fake image to learn the channel count of each stage.
        # Run it in eval mode without autograd so the random input neither
        # updates BatchNorm running statistics nor builds a graph.
        fake_img = torch.rand(1,3,256,256)
        was_training = self.base_model.training
        self.base_model.eval()
        with torch.no_grad():
            self.base_model(fake_img)
        self.base_model.train(was_training)
        self.channel_size = [block.size()[1] for block in self.features]
        self.features = []

    def forward(self,img):
        """Run the backbone for its hook side effect only; returns None.

        The captured stage outputs stay in `self.features`; `get_MLSP`
        discards them before its own pass.
        """
        self.base_model(img)

    def get_MLSP(self,img,feature_type,resize = True):
        """Run the backbone on `img` and return the hooked multi-level features.

        Args:
            img: input batch for the backbone.
            feature_type: only used when `resize` is True. 'narrow'
                global-average-pools every stage to a vector; 'wide'
                bilinearly resizes every stage to 7x7.
            resize: when True the per-stage results are concatenated along
                dim 1 into one tensor; when False the raw list of stage
                outputs is returned.

        Raises:
            ValueError: if `resize` is True and `feature_type` is not
                'narrow' or 'wide'.
        """
        if resize and feature_type not in ('narrow', 'wide'):
            raise ValueError(
                "feature_type must be 'narrow' or 'wide' when resize=True, got {!r}".format(feature_type))

        # Drop anything left over from an earlier forward() call so only the
        # stages of this pass are returned.
        self.features = []
        self.base_model(img)
        blocks, self.features = self.features, []

        if not resize:
            return blocks

        if feature_type == 'narrow':
            pooled = [F.adaptive_avg_pool2d(block, (1, 1)).squeeze(2).squeeze(2) for block in blocks]
        else:
            pooled = [F.interpolate(block,mode = 'bilinear', size = 7) for block in blocks]
        return torch.cat(pooled,dim = 1)


class Head(nn.Module):
    """Classification head on top of a list of multi-level feature maps.

    Two head types are supported:
      * 'mlsp_gap': global-average-pool every feature map, concatenate,
        then apply a linear classifier.
      * 'mlsp_cnn_gap_attn': per level, add a CBAM-refined map to the input
        map, pass the sum through two conv blocks, then pool / concatenate /
        classify as above. The last level uses CBAM without the spatial gate.

    Args:
        head_type: one of the names above.
        num_channel: channel count of each feature level. The classifier
            input width is their sum.
        num_classes: number of output logits (default 10).
    """

    def conv_block(self,inc,outc,ker,padding = 1,avgpool = False):
        """Build Dropout -> [AvgPool2d(3,1,1)] -> Conv2d -> BatchNorm2d -> ReLU."""
        modules = [nn.Dropout(0.5)]
        if avgpool:
            modules.append(AvgPool2d(3,1,1))
        modules.append(Conv2d(inc,outc,ker,padding = padding))
        modules.append(nn.BatchNorm2d(outc))
        modules.append(nn.ReLU())
        return Sequential(*modules)

    def __init__(self,head_type,num_channel,num_classes=10):
        super(Head, self).__init__()
        self.head_type = head_type
        self.num_ch = num_channel

        if head_type == 'mlsp_cnn_gap_attn':
            attn_blocks = []
            conv_blocks = []
            last = len(num_channel) - 1
            for i, channels in enumerate(num_channel):
                # every level but the last also gets the spatial gate
                attn_blocks.append(CBAM(channels, reduction_ratio=16, no_spatial=(i == last)))
                conv_blocks.append(Sequential(
                    self.conv_block(channels, channels, 1, 0),
                    self.conv_block(channels, channels, 3, 1),
                ))
            self.attn = ModuleList(attn_blocks)
            self.conv = ModuleList(conv_blocks)

        # Width follows the actual feature channels (64+128+256+512 = 960 for
        # the four stages of resnet18/34) instead of being hard-coded.
        self.dense = Sequential(Linear(sum(num_channel), num_classes))

    def forward(self,features):
        """Map a list of feature maps (one per level) to class logits.

        Raises:
            ValueError: if `head_type` is not a supported head.
        """
        if self.head_type == 'mlsp_gap':
            pooled = [F.adaptive_avg_pool2d(feature, (1, 1)) for feature in features]
        elif self.head_type == 'mlsp_cnn_gap_attn':
            pooled = [
                F.adaptive_avg_pool2d(conv(attn(feature) + feature), (1, 1))
                for feature, attn, conv in zip(features, self.attn, self.conv)
            ]
        else:
            raise ValueError("unsupported head_type: {!r}".format(self.head_type))
        x = torch.cat(pooled, dim=1)
        x = torch.flatten(x, 1)
        x = self.dense(x)
        return x

class Fmodel(nn.Module):
    """Backbone (`Base`) plus classification head (`Head`).

    The model can also record intermediate tensors through hooks:
      * `gap_fea`  - inputs of `head.dense` (see `hook_gap`)
      * `fea`      - outputs of `head.conv[0..3]` (see `hook_fea`)
      * `gradient` - first grad-output of `head.conv[0..3]` (see `hook_grad`)
    Every registered handle is tracked in `handles`, so `unhook` can remove
    them all.
    """

    def __init__(self, head_type='mlsp_gap',base = 18):
        super().__init__()
        self.bmodel = Base(base=base)
        self.head = Head(head_type, self.bmodel.channel_size)
        self.feature_type = 'narrow'
        self.resize = False
        self.fea = []
        self.gap_fea = []
        self.gradient = []
        self.handles = []

    def forward(self,img):
        """Extract backbone features and classify them with the head."""
        features = self.bmodel.get_MLSP(img, self.feature_type, self.resize)
        return self.head(features)

    def unfreeze(self):
        self.bmodel.unfreeze()

    def freeze(self):
        self.bmodel.freeze()

    def hook_gap(self):
        """Record the input of `head.dense` on every forward pass."""
        def _record(layer, inl, _):
            self.gap_fea.append(inl[0])

        handle = self.head.dense.register_forward_hook(_record)
        self.handles.append(handle)
        return handle

    def hook_grad(self):
        """Record the first grad-output of each of `head.conv[0..3]`."""
        def _record(layer, inl, out):
            self.gradient.append(out[0])

        new_handles = [
            self.head.conv[level].register_full_backward_hook(_record)
            for level in range(4)
        ]
        self.handles.extend(new_handles)
        return new_handles

    def hook_fea(self):
        """Record the output of each of `head.conv[0..3]` on every forward pass."""
        def _record(layer, inl, out):
            self.fea.append(out)

        new_handles = [
            self.head.conv[level].register_forward_hook(_record)
            for level in range(4)
        ]
        self.handles.extend(new_handles)
        return new_handles

    def unhook(self):
        """Remove every tracked hook and drop all recorded tensors."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

        # Rebinding the lists releases the recorded tensors.
        self.gap_fea = []
        self.gradient = []
        self.fea = []



In [9]:
def grad_cam(model, preds, label,ind,ind_ad):
    """Sum of MSEs between the attention maps of two subsets of the batch.

    The scores of the labelled class are back-propagated (graph retained) so
    that `model.gradient` gets filled by the model's backward hooks. The list
    is then reversed and stored back on the model, and entry i is paired with
    `model.fea[i]` (presumably because backward hooks fire from the last
    level to the first - confirm). For each of the four levels the gradient
    is averaged over its last two dims, multiplied with the feature map,
    summed over dim 1 and passed through ReLU.

    Args:
        model: object exposing `gradient`, `fea` and `zero_grad()`.
        preds: logits; row i, column label[i] is the back-propagated score.
        label: class index per sample.
        ind: batch indices of the first subset.
        ind_ad: batch indices of the second subset.

    Returns:
        The sum over the four levels of MSE(map[ind], map[ind_ad]).
    """
    target_scores = preds[torch.arange(len(preds)), label]
    target_scores.sum().backward(retain_graph=True)

    model.gradient = model.gradient[::-1]
    criterion = torch.nn.MSELoss()

    clean_maps = []
    adv_maps = []
    for level in range(4):
        grads = model.gradient[level]
        channel_weight = grads.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
        cam = channel_weight * model.fea[level]
        clean_maps.append(F.relu(cam[ind].sum(dim=1)))
        adv_maps.append(F.relu(cam[ind_ad].sum(dim=1)))

    # Clear the parameter gradients produced by the backward pass above.
    model.zero_grad()

    loss = 0
    for clean_map, adv_map in zip(clean_maps, adv_maps):
        loss += criterion(clean_map, adv_map)
    return loss

In [10]:
from tqdm.auto import tqdm, trange
import torchattacks

def train_model(model, dataloaders, criterion, optimizer, scheduler, path, num_epochs=3,adv = False, ada_reg = False, attn_pre = False):
    """Train `model`, checkpointing the weights with the lowest validation loss.

    Each epoch runs a 'train' and a 'val' phase. With `adv=True` PGD
    adversarial examples are generated in both phases:
      * plain adversarial training (`ada_reg` and `attn_pre` both False):
        a random half of every batch is replaced by adversarial images;
      * paired mode (`ada_reg` or `attn_pre`): the batch is doubled to
        [clean, adversarial], so both versions of each sample are available.

    Args:
        model: network to train. For `ada_reg` / `attn_pre` it must provide
            `hook_gap` / `hook_fea` + `hook_grad`, plus `unhook`.
        dataloaders: dict with 'train' and 'val' loaders yielding
            (inputs, labels); inputs are expected un-normalised.
        criterion: classification loss.
        optimizer: optimizer over the model parameters.
        scheduler: stepped once per epoch with the training-phase loss
            (ReduceLROnPlateau-style `step(metric)`); may be None.
        path: file the best state_dict is saved to.
        num_epochs: number of epochs.
        adv: enable adversarial training.
        ada_reg: add the feature-deviation-weighted penalty on the squared
            classifier weights (requires `adv`).
        attn_pre: add the `grad_cam` attention-map loss (requires `adv`).

    Returns:
        `model` with the best (lowest validation loss) weights loaded.

    Raises:
        ValueError: if `ada_reg` or `attn_pre` is requested without `adv`.
    """
    if (ada_reg or attn_pre) and not adv:
        # Both regularisers compare clean and adversarial halves of the batch.
        raise ValueError("ada_reg / attn_pre require adv=True")

    since = time.time()
    # Use the device the model lives on rather than a notebook global.
    device = next(model.parameters()).device
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1000.0
    # NOTE(review): alpha=2/225 looks like a typo for the usual 2/255; kept
    # unchanged so results stay comparable - confirm.
    # NOTE(review): the attack calls `model` on un-normalised inputs, whereas
    # the training forward pass below normalises first - confirm this is
    # intended.
    atk = torchattacks.PGD(model, eps=8/255, alpha=2/225, steps=7, random_start=True)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']: # Each epoch has a training and validation phase
            is_train = phase == 'train'
            model.train(is_train)  # train mode for 'train', eval mode for 'val'

            running_loss = 0.0
            running_corrects = 0
            size = 0

            for inputs, labels in tqdm(dataloaders[phase]): # Iterate over data
                if adv:
                    if attn_pre or ada_reg:
                        # Paired mode: first half clean, second half adversarial.
                        inputs = torch.cat((inputs.to(device), atk(inputs, labels)),dim=0)
                        labels = torch.cat((labels,labels),dim=0)
                        ind = list(range(len(inputs)//2))
                        ind_ad = list(range(len(inputs)//2,len(inputs)))
                    else:
                        # Replace a random half of the batch by adversarial images.
                        ind_ad = np.random.choice(len(inputs),len(inputs)//2,replace=False)
                        inputs_adv = atk(inputs[ind_ad], labels[ind_ad])
                        inputs = inputs.to(device)
                        inputs[ind_ad] = inputs_adv
                else:
                    inputs = inputs.to(device)

                inputs = transforms.functional.normalize(inputs,(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                labels = labels.to(device)
                optimizer.zero_grad() # Zero the parameter gradients

                hooked = False
                try:
                    with torch.set_grad_enabled(is_train): # Track history only in train
                        if is_train:
                            # Register only the hooks the enabled losses need;
                            # they must be in place before the forward pass.
                            if ada_reg:
                                hooked = True
                                model.hook_gap()
                            if attn_pre:
                                hooked = True
                                model.hook_fea()
                                model.hook_grad()

                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        _, preds = torch.max(outputs, 1)

                        if is_train:
                            if attn_pre:
                                attn_loss = grad_cam(model, outputs ,labels,ind,ind_ad)
                                loss = 1*loss + 1*attn_loss

                            if ada_reg:
                                fea = model.gap_fea[0]
                                out_ori = fea[ind]
                                out_adv = fea[ind_ad]
                                # Per-feature mean |adv - clean|, min-max scaled to [0, 5e-2].
                                diff = (out_adv-out_ori).abs().mean(axis=0).reshape(out_ori.shape[-1])
                                diff = (diff - diff.min())/(diff.max()-diff.min()+1e-6)*(5e-2)
                                # Penalise classifier weights of features that deviate most.
                                loss = loss + (diff*((model.head.dense[0].weight)**2).sum(dim=0)).sum()

                            loss.backward()
                            optimizer.step()
                finally:
                    # Remove hooks and drop recorded tensors even if the step failed.
                    if hooked:
                        model.unhook()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data).item()
                size+=len(preds)

            # max(size, 1) avoids a division by zero on an empty loader
            epoch_loss = running_loss / max(size, 1)

            if is_train and scheduler is not None: # Adjust learning rate based on the training loss
                scheduler.step(epoch_loss)

            epoch_acc = running_corrects / max(size, 1)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(model.state_dict(),path)
                best_model_wts = copy.deepcopy(model.state_dict())


    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('best_loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


In [12]:
# prefix = '/content/gdrive/MyDrive/cifar10/new_saved_model/'
prefix = './'
path =  'mlsp18_attn_def_adareg_attnpres'
epochs = 20
## Plain adversarial training: use batch size 128.
## Ada-reg / attention-preservation: use batch size 64, because every batch is
## duplicated with its adversarial copy for the deviation calculation.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Train four independent runs; each checkpoint is saved with its run index as suffix.
for i in range(4):
    # attn_pre=True needs the attention head: its per-level conv blocks are
    # what the feature / gradient hooks attach to.
    model = Fmodel('mlsp_cnn_gap_attn',18)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor=0.1, patience=1, verbose=True, min_lr =1e-6)

    train_model(model,{"train": train_loader, "val": val_loader}, criterion, optimizer, lr_scheduler,prefix+path+f'_{i}', epochs,adv = True, ada_reg = True,attn_pre=True)

18
Epoch 0/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.6336 Acc: 0.4251
loss2:  1.5673171431207291


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.4100 Acc: 0.4988
loss2:  1.5673171431207291
Epoch 1/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.3173 Acc: 0.5463
loss2:  1.4685179784779658


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.2703 Acc: 0.5530
loss2:  1.4685179784779658
Epoch 2/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.1397 Acc: 0.6121
loss2:  1.4113508926328187


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.1554 Acc: 0.6036
loss2:  1.4113508926328187
Epoch 3/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.0261 Acc: 0.6553
loss2:  1.3978665987853809


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.1108 Acc: 0.6248
loss2:  1.3978665987853809
Epoch 4/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.9441 Acc: 0.6835
loss2:  1.3720158155616897


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0351 Acc: 0.6436
loss2:  1.3720158155616897
Epoch 5/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.8548 Acc: 0.7166
loss2:  1.3365666783984056


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9951 Acc: 0.6617
loss2:  1.3365666783984056
Epoch 6/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.7619 Acc: 0.7497
loss2:  1.3223125425446065


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.8830 Acc: 0.6972
loss2:  1.3223125425446065
Epoch 7/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.6882 Acc: 0.7762
loss2:  1.3013484185309057


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9524 Acc: 0.6853
loss2:  1.3013484185309057
Epoch 8/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.6281 Acc: 0.7969
loss2:  1.2985558497631335


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9909 Acc: 0.6756
loss2:  1.2985558497631335
Epoch 9/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.5437 Acc: 0.8262
loss2:  1.2942631256854749


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.8833 Acc: 0.7150
loss2:  1.2942631256854749
Epoch 10/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.4883 Acc: 0.8447
loss2:  1.270400542737273


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.8992 Acc: 0.7068
loss2:  1.270400542737273
Epoch 11/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.4335 Acc: 0.8658
loss2:  1.275377370817277


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9865 Acc: 0.6946
loss2:  1.275377370817277
Epoch 12/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.4052 Acc: 0.8743
loss2:  1.2743549834736778


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9345 Acc: 0.7166
loss2:  1.2743549834736778
Training complete in 44m 20s
best_loss: 0.882997
18
Epoch 0/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.6364 Acc: 0.4241
loss2:  1.5956735659743209


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.3961 Acc: 0.4959
loss2:  1.5956735659743209
Epoch 1/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.3007 Acc: 0.5499
loss2:  1.4921636928987625


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.2303 Acc: 0.5742
loss2:  1.4921636928987625
Epoch 2/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.1268 Acc: 0.6142
loss2:  1.4391218579333762


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0613 Acc: 0.6388
loss2:  1.4391218579333762
Epoch 3/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.9928 Acc: 0.6654
loss2:  1.4145305556104617


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.1298 Acc: 0.6069
loss2:  1.4145305556104617
Epoch 4/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.8872 Acc: 0.7067
loss2:  1.3880154171868053


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9805 Acc: 0.6642
loss2:  1.3880154171868053
Epoch 5/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.7940 Acc: 0.7403
loss2:  1.3614227744319556


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0085 Acc: 0.6573
loss2:  1.3614227744319556
Epoch 6/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.7311 Acc: 0.7623
loss2:  1.3578501366593343


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9476 Acc: 0.6867
loss2:  1.3578501366593343
Epoch 7/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.6586 Acc: 0.7851
loss2:  1.3342184557024475


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9679 Acc: 0.6831
loss2:  1.3342184557024475
Epoch 8/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.5954 Acc: 0.8078
loss2:  1.3430707619318267


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.8754 Acc: 0.7119
loss2:  1.3430707619318267
Epoch 9/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.5459 Acc: 0.8261
loss2:  1.3213430255880136


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9387 Acc: 0.6981
loss2:  1.3213430255880136
Epoch 10/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.4946 Acc: 0.8430
loss2:  1.3113137514085111


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9265 Acc: 0.7140
loss2:  1.3113137514085111
Epoch 11/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.4253 Acc: 0.8662
loss2:  1.2889919345031309


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0743 Acc: 0.6805
loss2:  1.2889919345031309
Epoch 12/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.3931 Acc: 0.8791
loss2:  1.273061941651737


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9756 Acc: 0.7022
loss2:  1.273061941651737
Training complete in 44m 39s
best_loss: 0.875370
18
Epoch 0/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.6413 Acc: 0.4205
loss2:  1.584393396401954


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.4276 Acc: 0.4832
loss2:  1.584393396401954
Epoch 1/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.2924 Acc: 0.5544
loss2:  1.4942727253565093


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.1595 Acc: 0.5915
loss2:  1.4942727253565093
Epoch 2/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.1287 Acc: 0.6169
loss2:  1.4505152287690535


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.1323 Acc: 0.6044
loss2:  1.4505152287690535
Epoch 3/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.0030 Acc: 0.6638
loss2:  1.4361133551048806


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0940 Acc: 0.6263
loss2:  1.4361133551048806
Epoch 4/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.9057 Acc: 0.6986
loss2:  1.4163956541539457


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0030 Acc: 0.6563
loss2:  1.4163956541539457
Epoch 5/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.8293 Acc: 0.7248
loss2:  1.3956166469227627


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.1267 Acc: 0.6420
loss2:  1.3956166469227627
Epoch 6/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.7388 Acc: 0.7592
loss2:  1.3713622943824515


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0032 Acc: 0.6615
loss2:  1.3713622943824515
Epoch 7/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.6520 Acc: 0.7882
loss2:  1.3431602817057344


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.8826 Acc: 0.7095
loss2:  1.3431602817057344
Epoch 8/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.5987 Acc: 0.8093
loss2:  1.3481901582244717


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9529 Acc: 0.6867
loss2:  1.3481901582244717
Epoch 9/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.5374 Acc: 0.8300
loss2:  1.3269564242619079


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9088 Acc: 0.7090
loss2:  1.3269564242619079
Epoch 10/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.4782 Acc: 0.8501
loss2:  1.3346982331532042


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0754 Acc: 0.6838
loss2:  1.3346982331532042
Epoch 11/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.4348 Acc: 0.8663
loss2:  1.3235499831416724


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0186 Acc: 0.6909
loss2:  1.3235499831416724
Epoch 12/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.3938 Acc: 0.8786
loss2:  1.3254296087547945


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9521 Acc: 0.7141
loss2:  1.3254296087547945
Training complete in 44m 40s
best_loss: 0.882598
18
Epoch 0/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.6290 Acc: 0.4259
loss2:  1.5685477229335425


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.3128 Acc: 0.5307
loss2:  1.5685477229335425
Epoch 1/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.2854 Acc: 0.5558
loss2:  1.48079191632283


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.2364 Acc: 0.5703
loss2:  1.48079191632283
Epoch 2/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.1211 Acc: 0.6203
loss2:  1.4606062411652196


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0613 Acc: 0.6327
loss2:  1.4606062411652196
Epoch 3/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.9956 Acc: 0.6659
loss2:  1.42465795732825


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 1.0756 Acc: 0.6304
loss2:  1.42465795732825
Epoch 4/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.8886 Acc: 0.7041
loss2:  1.4112662628788473


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9612 Acc: 0.6693
loss2:  1.4112662628788473
Epoch 5/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.8090 Acc: 0.7331
loss2:  1.3937629900320108


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9775 Acc: 0.6663
loss2:  1.3937629900320108
Epoch 6/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.7259 Acc: 0.7612
loss2:  1.3872044952324285


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9746 Acc: 0.6709
loss2:  1.3872044952324285
Epoch 7/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.6369 Acc: 0.7930
loss2:  1.368696112157134


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9239 Acc: 0.6945
loss2:  1.368696112157134
Epoch 8/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.5931 Acc: 0.8094
loss2:  1.349601322427735


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9055 Acc: 0.6995
loss2:  1.349601322427735
Epoch 9/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.5042 Acc: 0.8416
loss2:  1.3291835635519393


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.8688 Acc: 0.7143
loss2:  1.3291835635519393
Epoch 10/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.4572 Acc: 0.8572
loss2:  1.30974688859242


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9093 Acc: 0.7139
loss2:  1.30974688859242
Epoch 11/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.4166 Acc: 0.8703
loss2:  1.3082408048307803


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9674 Acc: 0.7014
loss2:  1.3082408048307803
Epoch 12/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.3714 Acc: 0.8850
loss2:  1.28997828161625


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 0.9331 Acc: 0.7242
loss2:  1.28997828161625
Training complete in 44m 43s
best_loss: 0.868767


In [None]:
## Meta evaluator: loads several independently trained runs of one model and scores them together
class Meta():
    """Load ``nmodel`` checkpoints of one architecture and evaluate them as a group.

    The class relies on names defined elsewhere in the notebook: ``device``,
    ``test_loader``, ``prefix`` (checkpoint directory prefix), and whatever model
    factory ``name`` resolves to through ``globals()``.

    Attributes:
        name: model-factory name passed to the constructor.
        models: the loaded models, all switched to eval mode.
        nmodel: number of checkpoints that were loaded.
        acc: one row per model, ``[clean, FGSM, PGD]`` accuracy; filled by ``eval()``.
        fea / adv_fea: per-model clean / adversarial features; filled by ``extract_fea()``.
        hook_handles: handles of the forward hooks registered by ``hook()``.
    """

    # Per-channel statistics applied to every batch right before the forward pass.
    # NOTE(review): presumably the CIFAR-10 statistics used at training time - confirm.
    MEAN = (0.4914, 0.4822, 0.4465)
    STD = (0.2023, 0.1994, 0.2010)

    def __init__(self,name,path,base=18,head_type='mlsp_gap',nmodel = 4):
        """Build ``nmodel`` models and load the checkpoints ``prefix+path+'_<i>'``.

        Args:
            name: name of the model factory looked up in ``globals()``. A name
                containing ``'Fmodel'`` is called with ``base``/``head_type``; a name
                containing ``'resnet9'`` builds a one-block-per-stage ResNet; any
                other name is called with ``pretrained=False``.
            path: checkpoint path stem; the run index ``_<i>`` is appended.
            base: forwarded to the ``Fmodel`` factory.
            head_type: forwarded to the ``Fmodel`` factory.
            nmodel: number of runs (checkpoints) to load.
        """
        self.name = name
        self.models = []
        self.acc = []
        self.nmodel = nmodel
        self.hook_handles = []
        for i in range(nmodel):
            if 'Fmodel' in name:
                model = globals()[name](base=base,head_type=head_type)
            elif 'resnet9' in name:
                model = _resnet('resnet', BasicBlock, [1, 1, 1, 1], False, True)
                # NOTE(review): bare `Linear` must be in scope (e.g. `from torch.nn
                # import Linear`); its import is not visible in this part of the file.
                model.fc = Linear(512,10)
            else:
                model = globals()[name](pretrained=False)
                model.fc = Linear(512,10)

            model.to(device)
            # map_location lets a checkpoint saved on GPU load on a CPU-only runtime too.
            model.load_state_dict(torch.load(prefix+path+f'_{i}', map_location=device))
            model.eval()
            self.models.append(model)

    def _attack_model(self, model):
        """Return ``model`` preceded by input normalisation, for use by torchattacks.

        The attacks receive raw images, while ``model`` is always fed normalised
        images in this class. Wrapping keeps the attack's gradients consistent with
        the forward pass that is actually evaluated.
        """
        return nn.Sequential(transforms.Normalize(self.MEAN, self.STD), model).eval()

    def _accuracy(self, model, attack=None):
        """Return the accuracy of ``model`` over ``test_loader``.

        Args:
            model: the network to score (expects normalised inputs).
            attack: optional torchattacks attack; when given, each batch is replaced
                by its adversarial version before being normalised and classified.
        """
        correct = 0
        size = 0
        for images, labels in test_loader:
            if attack is None:
                images = images.to(device)
            else:
                # The attack needs gradients, so it runs outside no_grad.
                images = attack(images, labels)
            labels = labels.to(device)
            images = transforms.functional.normalize(images, self.MEAN, self.STD)
            with torch.no_grad():
                outputs = model(images)
            correct += (outputs.argmax(dim=1) == labels).float().sum().item()
            size += len(labels)
        return correct / size

    def hook(self):
        """Register forward hooks on ``layer1``..``layer4`` of every loaded model.

        Each hook appends the globally average-pooled, detached layer output (moved
        to CPU) to ``self.fea``. Layers are looked up on ``model.bmodel.base_model``
        when that attribute chain exists, otherwise on the model itself.

        NOTE(review): ``extract_fea()`` rebinds ``self.fea`` to per-model lists, so
        hooks registered here should be removed with ``unhook()`` before calling it.
        """
        if not hasattr(self, 'fea'):
            self.fea = []

        def hook_fn(layer,inl,outl):
            self.fea.append(F.adaptive_avg_pool2d(outl.detach(),(1,1)).cpu())

        handles = getattr(self, 'hook_handles', [])
        for model in self.models:
            # Resolve the backbone first so a failure cannot leave hooks half-registered.
            try:
                backbone = model.bmodel.base_model
            except AttributeError:
                backbone = model
            for layer_name in ('layer1', 'layer2', 'layer3', 'layer4'):
                handles.append(getattr(backbone, layer_name).register_forward_hook(hook_fn))
        self.hook_handles = handles

    def unhook(self):
        """Remove every forward hook previously registered by ``hook()``."""
        for handle in getattr(self, 'hook_handles', []):
            handle.remove()
        self.hook_handles = []

    def eval(self, eps=8/255, alpha=2/255, steps=7):
        """Compute clean, FGSM and PGD accuracy for every model.

        ``self.acc`` is reset on each call, so re-running the cell does not stack
        duplicate rows. After the call it holds one ``[clean, FGSM, PGD]`` row per model.

        Args:
            eps: perturbation budget for both attacks.
            alpha: PGD step size. The previous hard-coded value was ``2/225``, which
                looks like a typo for ``2/255``.
            steps: number of PGD iterations.
        """
        self.acc = []
        for model in self.models:
            target = self._attack_model(model)
            attacks = [
                torchattacks.FGSM(target, eps=eps),
                torchattacks.PGD(target, eps=eps, alpha=alpha, steps=steps, random_start=True),
            ]
            acc = [self._accuracy(model)]
            for atk in attacks:
                acc.append(self._accuracy(model, atk))
            self.acc.append(acc)

    def extract_fea(self, eps=8/255):
        """Collect pooled features of clean and FGSM-perturbed test images per model.

        Uses the model's own ``hook_gap()`` / ``gap_fea`` / ``unhook()`` interface.
        Afterwards ``self.fea[i]`` and ``self.adv_fea[i]`` are arrays with the clean
        and adversarial features of model ``i`` concatenated over all test batches.

        Args:
            eps: FGSM perturbation budget.
        """
        self.adv_fea = [[] for _ in range(self.nmodel)]
        self.fea = [[] for _ in range(self.nmodel)]
        for model_ind, model in enumerate(self.models):
            atk = torchattacks.FGSM(self._attack_model(model), eps=eps)
            for images, labels in test_loader:
                adv_images = atk(images, labels)
                # Clean and adversarial halves are split at the real batch size,
                # which also handles a smaller final batch.
                n_clean = images.size(0)
                model.hook_gap()
                try:
                    images = images.to(device)
                    images = torch.cat((images,adv_images),dim=0)
                    images = transforms.functional.normalize(images, self.MEAN, self.STD)
                    model(images)
                    self.fea[model_ind].append(model.gap_fea[0][:n_clean].cpu().detach())
                    self.adv_fea[model_ind].append(model.gap_fea[0][n_clean:].cpu().detach())
                finally:
                    # Always release the model's hooks, even if the forward pass fails.
                    model.unhook()

            self.fea[model_ind] = np.concatenate(self.fea[model_ind],axis=0)
            self.adv_fea[model_ind] = np.concatenate(self.adv_fea[model_ind],axis=0)


In [None]:
# Checkpoint location: Meta reads the global `prefix` and appends `path` + '_<run index>'.
prefix = './'
path = 'mlsp18_attn_def_adareg_attnpres'

# Load the four runs of the Fmodel variant and score each of them.
mlsp18_def = Meta(
    name='Fmodel',
    path=path,
    base=18,
    head_type='mlsp_cnn_gap_attn',
    nmodel=4,
)
mlsp18_def.eval()

# Average each accuracy column over the runs (rows of `acc` are per-model results).
np.mean(np.array(mlsp18_def.acc), axis=0)
