In [1]:
from google.colab import drive

drive.mount('/content/gdrive')


Mounted at /content/gdrive


In [2]:
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
torch.manual_seed(2)
np.random.seed(2)
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler
import torchvision.models as models
!pip install torchinfo
!pip install torchattacks
!pip install pip install grad-cam
from torchinfo import summary
from torchvision.models.resnet import _resnet,BasicBlock
import torchattacks
import torchvision.utils
import torch.nn.functional as F
from pytorch_grad_cam import GradCAM

Collecting torchinfo
  Downloading torchinfo-1.6.3-py3-none-any.whl (20 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.6.3
Collecting torchattacks
  Downloading torchattacks-3.2.4-py3-none-any.whl (102 kB)
[K     |████████████████████████████████| 102 kB 4.5 MB/s 
[?25hInstalling collected packages: torchattacks
Successfully installed torchattacks-3.2.4
Collecting install
  Downloading install-1.3.5-py3-none-any.whl (3.2 kB)
Collecting grad-cam
  Downloading grad-cam-1.3.7.tar.gz (4.5 MB)
[K     |████████████████████████████████| 4.5 MB 5.4 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting ttach
  Downloading ttach-0.0.3-py3-none-any.whl (9.8 kB)
Building wheels for collected packages: grad-cam
  Building wheel for grad-cam (PEP 517) ... [?25l[?25hdone
  Created wheel for grad-cam: filename=grad_cam-1.3.7-py3-none-a

In [3]:
transform = transforms.Compose([
    # transforms.Resize((224)),
    transforms.ToTensor(),
])

trainset = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2,pin_memory=False)

valset = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(
    valset, batch_size=64, shuffle=False, num_workers=2, sampler = SequentialSampler(valset.data[5000:10000]),pin_memory=False)

testset = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2, sampler = SequentialSampler(testset.data[5000:10000]),pin_memory=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


In [4]:
from torch.nn import Conv2d,AvgPool2d,Linear,Sequential,Dropout,BatchNorm2d,ModuleList,BatchNorm1d
import torch.nn.functional as F
import numpy as np
import math
from torch.autograd import Variable

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out) # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

class Base(nn.Module):
    def freeze(self):
        for param in self.base_model.parameters():
                param.requires_grad = False
    
    def unfreeze(self):
        for param in self.base_model.parameters():
                param.requires_grad = True
    
    def attach_fea_out(self,classname,input,output):
        self.features.append(output)

    def attach_fea_in(self,classname,input,output):
        self.features.append(input[0])

    def __init__(self,trainable = True,attention=False,base=18):
        super(Base,self).__init__()
        self.features = []
        self.channel_size = []
        print(base)
        if base == 9:
            self.base_model = _resnet('resnet', BasicBlock, [1, 1, 1, 1], False, True)
        elif base==18:
            self.base_model = models.resnet18(pretrained=False)
        else:
            self.base_model = models.resnet34(pretrained=False)

        used_blocks = ['layer1', 'layer2','layer3','layer4']
        unused_blocks = ['avgpool','fc']

        for block in used_blocks:
            getattr(self.base_model,block).register_forward_hook(self.attach_fea_out)

        for block in unused_blocks:
             setattr(self.base_model,block,nn.Identity())
        
        if not trainable:
            self.freeze()

        fake_img = torch.rand(1,3,256,256) ## pass fake img to the model to get the channel size of each inception block
        self.base_model(fake_img)
        self.channel_size = [block.size()[1] for block in self.features]
        self.features = []

    def forward(self,img):
        self.base_model(img)

    def get_MLSP(self,img,feature_type,resize = True):
        self.base_model(img)
        if resize:
            print(resize)
            if feature_type == 'narrow':
                MLSP = [F.adaptive_avg_pool2d(block, (1, 1)) for block in self.features]
                for i in range(len(MLSP)):
                    MLSP[i] = MLSP[i].squeeze(2).squeeze(2)

            if feature_type == 'wide':
                MLSP = [F.interpolate(block,mode = 'bilinear', size = 7) for block in self.features]
            
            MLSP = torch.cat(MLSP,dim = 1)
            self.features = []
        else:
            MLSP = self.features
            self.features = []
        return MLSP


class Identity(nn.Module):
    def __init__(self):
        super(Identity,self).__init__()
    def forward(self,x):
        return  x
        
class Head(nn.Module):
    def conv_block(self,inc,outc,ker,padding = 1,avgpool = False):
        modules = []
        modules.append(nn.Dropout(0.5))
        if avgpool:
            modules.append(AvgPool2d(3,1,1))
        modules.append(Conv2d(inc,outc,ker,padding = padding))
        modules.append(nn.BatchNorm2d(outc))
        modules.append(nn.ReLU())
        return Sequential(*modules)

    def __init__(self,head_type,num_channel):
        super(Head, self).__init__()
        self.head_type = head_type
        self.num_ch = num_channel
        self.identity = [Identity(),Identity(),Identity(),Identity()]
        if head_type == 'mlsp_cnn_gap_attn':
            self.attn = []
            self.conv = []
            for i in range(4):
                if i!=3:
                    self.attn.append(CBAM(num_channel[i],reduction_ratio=16))
                else:
                    self.attn.append(CBAM(num_channel[i],reduction_ratio=16,no_spatial=True))
                self.conv.append(Sequential(
                                    self.conv_block(num_channel[i],num_channel[i],1,0),
                                    self.conv_block(num_channel[i],num_channel[i],3,1),
                          ))
            self.attn = ModuleList(self.attn)
            self.conv = ModuleList(self.conv)
        self.dense = Sequential(Linear(960,100))

    def forward(self,features):
        if self.head_type == 'mlsp_gap':
            
            x = torch.cat([F.adaptive_avg_pool2d(self.identity[i](feature), (1, 1)) for i,feature in enumerate(features)],dim=1)
        else:
            x = torch.cat([F.adaptive_avg_pool2d(block2(block1(feature)+feature),(1,1)) for feature,block1,block2 in zip(features,self.attn,self.conv)],dim=1)
        x = torch.flatten(x, 1)
        x = self.dense(x)
        return x

class Fmodel(nn.Module):
    def __init__(self, head_type='mlsp_gap',base = 18):
        super(Fmodel,self).__init__()
        self.bmodel = Base(base=base)
        self.head = Head(head_type,self.bmodel.channel_size)
        self.feature_type = 'narrow'    
        self.resize = False
        self.fea = []
        self.gap_fea = []
        self.gradient = []
        self.handles = []
        
    def forward(self,img):
        x = self.bmodel.get_MLSP(img,self.feature_type,self.resize)
        x = self.head(x)
        return x

    def unfreeze(self):
        self.bmodel.unfreeze()
    
    def freeze(self):
        self.bmodel.freeze()

    def hook_gap(self):
        handle = self.head.dense.register_forward_hook(lambda layer, inl, _,: self.gap_fea.append(inl[0]))
        self.handles += [handle]
        return handle

    def hook_grad(self):
        handle = []
        handle.append(self.head.identity[0].register_full_backward_hook(lambda layer, inl, out,: self.gradient.append(out[0])))
        handle.append(self.head.identity[1].register_full_backward_hook(lambda layer, inl, out,: self.gradient.append(out[0])))
        handle.append(self.head.identity[2].register_full_backward_hook(lambda layer, inl, out,: self.gradient.append(out[0])))
        handle.append(self.head.identity[3].register_full_backward_hook(lambda layer, inl, out,: self.gradient.append(out[0])))
        self.handles += handle
        return handle

    def hook_fea(self):
        handle = []
        handle.append(self.head.identity[0].register_forward_hook(lambda layer, inl, out,: self.fea.append(out)))
        handle.append(self.head.identity[1].register_forward_hook(lambda layer, inl, out,: self.fea.append(out)))
        handle.append(self.head.identity[2].register_forward_hook(lambda layer, inl, out,: self.fea.append(out)))
        handle.append(self.head.identity[3].register_forward_hook(lambda layer, inl, out,: self.fea.append(out)))
        self.handles += handle
        return handle

    def hook(self):
        self.hook_gap()
        self.hook_grad()
        self.hook_fea()

    def unhook(self):
        for fea in self.gap_fea:
            fea.detach()
        for grad in self.gradient:
            grad.detach()
        for fea in self.fea:
            fea.detach()
        self.gap_fea = []
        self.gradient = []
        self.fea = []

        for h in self.handles:
            h.remove()
        self.handles = []



In [5]:
def scale(x):
    return (x - x.min(-1)[0].min(-1)[0].reshape(-1,1,1))/(x.max(-1)[0].max(-1)[0].reshape(-1,1,1)+1e-9)

def grad_cam(child, preds_child, label,ind,ind_ad): ##child and preds_child are actually the model and its prediction.just some naming convention
    masks_child_ori = []
    masks_child_adv = []
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    preds_child[torch.arange(len(preds_child)),label].sum().backward(retain_graph=True)
    mse = torch.nn.MSELoss()
    loss = 0
    child.gradient = child.gradient[::-1]
    mask_final_ori = []
    mask_final_adv = []
    for i in range(3): ## omit the last layer since it is impossible to scale a 1x1, 
                        ## and mask a full img with 1x1 map makes little sense. Maybe i am wrong here :)) 
        weight = child.gradient[i].mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
        mask = weight * child.fea[i]
        mask_ori = F.relu(mask[ind].sum(dim=1))
        mask_ori = scale(mask_ori)
        mask_ori = F.interpolate(mask_ori.unsqueeze(1), (32,32)).squeeze()
        mask_final_ori.append(mask_ori.detach())

        mask_adv = F.relu(mask[ind_ad].sum(dim=1))
        mask_adv = scale(mask_adv)
        mask_adv = F.interpolate(mask_adv.unsqueeze(1), (32,32)).squeeze(-1)
        mask_final_adv.append(mask_adv.detach())
    
    activ_map_ori = scale(torch.stack(mask_final_ori,axis=1).mean(1).squeeze())
    activ_map_adv = scale(torch.stack(mask_final_adv,axis=1).mean(1).squeeze())
    mask_final_ori = (activ_map_ori>activ_map_ori.mean(dim=(-1,-2),keepdim=True)).unsqueeze(1)
    mask_final_adv = (activ_map_adv>activ_map_adv.mean(dim=(-1,-2),keepdim=True)).unsqueeze(1)
    mask_final_ori = mask_final_ori.repeat(1,3,1,1)
    mask_final_adv = mask_final_adv.repeat(1,3,1,1)
    child.zero_grad()
    return mask_final_ori.type(preds_child.dtype).to(device), mask_final_adv.type(preds_child.dtype).to(device), \
                              activ_map_ori.type(preds_child.dtype).to(device), activ_map_adv.type(preds_child.dtype).to(device)

In [6]:
from tqdm.auto import tqdm, trange
import torchattacks

def train_model(model, dataloaders, criterion, optimizer, scheduler, path, num_epochs=3,adv = False):
    
    since = time.time()
    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1000.0
    atk = torchattacks.PGD(model, eps=8/255, alpha=2/225, steps=7, random_start=True)
    atk_fgsm = torchattacks.FGSM(model, eps=8/255)
    

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']: # Each epoch has a training and validation phase
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            size = 0

            iterator = iter(dataloaders[phase])
            # for _ in tqdm(range(len(iterator))): # Iterate over data
            for _ in tqdm(range(len(iterator))): # Iterate over data
                inputs, labels = next(iterator)
                if adv:
                    if attn_pre or ada_reg:
                            inputs = torch.cat((inputs.to(device), atk(inputs, labels)),dim=0)
                            labels = torch.cat((labels,labels),dim=0)
                            ind = list(range(len(inputs)//2))
                            ind_ad = list(range(len(inputs)//2,len(inputs)))
                    else:
                        ind_ad = np.random.choice(len(inputs),len(inputs)//2,replace=False)
                        ind = np.delete(np.arange(len(inputs)),ind_ad)
                        inputs_adv = atk(inputs[ind_ad], labels[ind_ad])
                        inputs = inputs.to(device)
                        inputs[ind_ad] = inputs_adv

                else:
                    inputs = inputs.to(device)

                inputs = transforms.functional.normalize(inputs,(0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
                labels = labels.to(device)
                optimizer.zero_grad() # Zero the parameter gradients

                with torch.set_grad_enabled(phase == 'train'): # Forward. Track history if only in train
                    
                    if phase == 'train': # Backward + optimize only if in training phase
                        model.hook()
                        outputs = model(inputs)
                        loss_c = (config['adv_weight']*criterion(outputs[ind], labels[ind]) + (1-config['adv_weight'])*criterion(outputs[ind_ad], labels[ind_ad]))  
                        # loss_c = criterion(old_outputs, labels)                              
                        _, preds = torch.max(outputs, 1)
                        attn_ori, attn_adv,activ_map_ori,activ_map_adv = grad_cam(model, outputs,labels,ind,ind_ad)
                        loss_p = nn.L1Loss()(activ_map_ori,activ_map_adv)

                        new_input = torch.zeros_like(inputs).to(device)
                        new_input[ind] = inputs[ind]*(torch.ones_like(attn_ori)-attn_ori)
                        new_input[ind_ad] = inputs[ind_ad]*(torch.ones_like(attn_adv)-attn_adv)
                        old_fea_ori = model.gap_fea[0][ind]
                        old_fea_adv = model.gap_fea[0][ind_ad]
                        model(new_input)
                        new_fea_ori = model.gap_fea[1][ind]
                        new_fea_adv = model.gap_fea[1][ind_ad]
                        loss_r = nn.MSELoss()(old_fea_ori,new_fea_ori) + nn.MSELoss()(old_fea_adv,new_fea_adv)

                        loss = loss_c + config['alpha']*loss_p - config['beta']*loss_r
                        
                        loss.backward()
                        optimizer.step()
                        model.unhook()
                    
                    if phase == 'val':
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        _, preds = torch.max(outputs, 1)

                
                # Statistics
                
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                size+=len(preds)

            epoch_loss = running_loss / size
            
            
            if phase == 'train': # Adjust learning rate based on val loss
                lr_scheduler.step(epoch_loss)
                
            epoch_acc = running_corrects.double() / size
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(model.state_dict(),path)
                best_model_wts = copy.deepcopy(model.state_dict())


    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('best_loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


In [7]:
prefix = '/content/gdrive/MyDrive/cifar100/'
path =  'Wu_tune'
epochs = 20 #50
for i in range(4):
    config = {'l2': 0.0003464850144870174, 'lr': 0.0006347331182586497, 'adv_weight': 0.5010492180243408, 'alpha': 0.0015003974488041923, 'beta': 0.01430504803197945}
    # config = {'l2': 1e-2, 'lr': 1e-3, 'adv_weight': 0.5, 'alpha': 0.001, 'beta': 0.001}
    model = Fmodel('mlsp_gap',18)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['l2'])
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor=0.1, patience=1, verbose=True, min_lr =1e-6)

    train_model(model,{"train": train_loader, "val": val_loader}, criterion, optimizer, lr_scheduler,prefix+path+f'_{i}', epochs,adv = True)

18
Epoch 0/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 3.7048 Acc: 0.1461


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.5599 Acc: 0.1746
Epoch 1/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 3.1549 Acc: 0.2350


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.1493 Acc: 0.2450
Epoch 2/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.8613 Acc: 0.2890


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.0023 Acc: 0.2825
Epoch 3/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.6144 Acc: 0.3364


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8599 Acc: 0.3121
Epoch 4/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.3993 Acc: 0.3812


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8703 Acc: 0.3129
Epoch 5/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.1999 Acc: 0.4208


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8965 Acc: 0.3108
Epoch 6/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.9890 Acc: 0.4665


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8951 Acc: 0.3329
Epoch 7/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.7854 Acc: 0.5124


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.7875 Acc: 0.3529
Epoch 8/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.5613 Acc: 0.5655


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.9696 Acc: 0.3378
Epoch 9/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.3572 Acc: 0.6131


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.0048 Acc: 0.3533
Epoch 10/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.1596 Acc: 0.6640


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.2808 Acc: 0.3211
Epoch 11/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.9655 Acc: 0.7134


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.2914 Acc: 0.3439
Epoch 12/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.8129 Acc: 0.7541


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.5049 Acc: 0.3344
Training complete in 30m 6s
best_loss: 2.787485
18
Epoch 0/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 3.6805 Acc: 0.1442


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.3907 Acc: 0.1931
Epoch 1/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 3.1210 Acc: 0.2394


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.3666 Acc: 0.2241
Epoch 2/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.8253 Acc: 0.2961


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.0322 Acc: 0.2652
Epoch 3/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.5769 Acc: 0.3456


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.9029 Acc: 0.3002
Epoch 4/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.3729 Acc: 0.3874


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8191 Acc: 0.3215
Epoch 5/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.1360 Acc: 0.4368


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8024 Acc: 0.3340
Epoch 6/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.9384 Acc: 0.4762


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8385 Acc: 0.3372
Epoch 7/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.7405 Acc: 0.5232


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8937 Acc: 0.3442
Epoch 8/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.5198 Acc: 0.5753


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.9775 Acc: 0.3327
Epoch 9/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.3130 Acc: 0.6216


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.1095 Acc: 0.3347
Epoch 10/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.1367 Acc: 0.6704


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.0600 Acc: 0.3493
Epoch 11/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.9422 Acc: 0.7191


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.2774 Acc: 0.3464
Epoch 12/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.8109 Acc: 0.7552


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.3245 Acc: 0.3511
Training complete in 30m 1s
best_loss: 2.802425
18
Epoch 0/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 3.7239 Acc: 0.1396


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.5135 Acc: 0.1756
Epoch 1/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 3.1512 Acc: 0.2343


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.2144 Acc: 0.2357
Epoch 2/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.8429 Acc: 0.2944


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.9782 Acc: 0.2727
Epoch 3/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.6003 Acc: 0.3425


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.9265 Acc: 0.2993
Epoch 4/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.3813 Acc: 0.3845


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8168 Acc: 0.3225
Epoch 5/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.1846 Acc: 0.4262


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8949 Acc: 0.3152
Epoch 6/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.9817 Acc: 0.4669


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8644 Acc: 0.3247
Epoch 7/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.7772 Acc: 0.5158


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.7644 Acc: 0.3528
Epoch 8/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.5576 Acc: 0.5647


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8762 Acc: 0.3480
Epoch 9/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.3454 Acc: 0.6181


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.0035 Acc: 0.3427
Epoch 10/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.1602 Acc: 0.6638


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.1114 Acc: 0.3416
Epoch 11/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.9691 Acc: 0.7143


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.3060 Acc: 0.3370
Epoch 12/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.8206 Acc: 0.7534


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.4485 Acc: 0.3369
Training complete in 30m 23s
best_loss: 2.764404
18
Epoch 0/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 3.7076 Acc: 0.1429


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.4753 Acc: 0.1834
Epoch 1/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 3.1372 Acc: 0.2365


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.1819 Acc: 0.2485
Epoch 2/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.8410 Acc: 0.2934


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.1319 Acc: 0.2640
Epoch 3/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.6037 Acc: 0.3395


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.9407 Acc: 0.2954
Epoch 4/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.4017 Acc: 0.3795


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.9119 Acc: 0.3041
Epoch 5/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 2.1970 Acc: 0.4246


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8233 Acc: 0.3254
Epoch 6/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.9960 Acc: 0.4645


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8017 Acc: 0.3384
Epoch 7/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.7852 Acc: 0.5150


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8737 Acc: 0.3362
Epoch 8/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.5751 Acc: 0.5635


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 2.8749 Acc: 0.3511
Epoch 9/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.3693 Acc: 0.6100


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.1899 Acc: 0.3247
Epoch 10/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 1.1724 Acc: 0.6594


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.0575 Acc: 0.3530
Epoch 11/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.9812 Acc: 0.7099


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.2758 Acc: 0.3384
Epoch 12/12
----------


  0%|          | 0/391 [00:00<?, ?it/s]

train Loss: 0.8510 Acc: 0.7458


  0%|          | 0/79 [00:00<?, ?it/s]

val Loss: 3.3438 Acc: 0.3467
Training complete in 30m 18s
best_loss: 2.801669


In [10]:
test_loader1 = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2, sampler = SequentialSampler(testset.data[9000:10000]))
# test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2, sampler = SequentialSampler(testset.data[5000:10000]))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Meta():
    def hook(self):
        def hook_fn(layer,inl,outl):
            self.fea.append(F.adaptive_avg_pool2d(outl.detach(),(1,1)).cpu())

        for i in range(4):
            try:
                # self.models[i].head.dense.register_forward_hook(lambda layer, inl, _,: self.fea.append(inl[0].detach().cpu()))
                self.models[i].bmodel.base_model.layer1.register_forward_hook(hook_fn)
                self.models[i].bmodel.base_model.layer2.register_forward_hook(hook_fn)
                self.models[i].bmodel.base_model.layer3.register_forward_hook(hook_fn)
                self.models[i].bmodel.base_model.layer4.register_forward_hook(hook_fn)
            except:
                # self.models[i].fc.register_forward_hook(lambda layer, inl, _,: fea_final.append(inl[0].detach().cpu()))
                self.models[i].layer1.register_forward_hook(hook_fn)
                self.models[i].layer2.register_forward_hook(hook_fn)
                self.models[i].layer3.register_forward_hook(hook_fn)
                self.models[i].layer4.register_forward_hook(hook_fn)


    def eval(self):
        for i in range(self.nmodel):
            model = self.models[i]
            acc = []
            correct = 0 
            size = 0
            with  torch.no_grad():
                for inputs, labels in test_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    inputs = transforms.functional.normalize(inputs,(0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
                    prediction= model(inputs)
                    _, predicted_class = torch.max(prediction.data, 1)
                    correct += (predicted_class == labels).float().sum().item()
                    size+=len(prediction)
            test_accuracy = correct / size
            acc.append(test_accuracy)

            atks = [torchattacks.FGSM(model, eps=8/255),   torchattacks.PGD(model, eps=8/255, alpha=2/225, steps=7, random_start=True)]            
            for j in [0,1]:
                correct = 0
                start = time.time()
                size = 0
                for images, labels in test_loader:   
                    fea = []
                    adv_images = atks[j](images, labels)
                    labels = labels.to(device)
                    adv_images = transforms.functional.normalize(adv_images,(0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
                    outputs = model(adv_images)
                    _, pre = torch.max(outputs.data, 1)
                    correct += (pre == labels).float().sum().item()
                    size+=len(labels)

                acc.append(correct / size)  

            self.acc.append(acc)

            print(acc)

    def __init__(self,name,path,base=18,head_type='mlsp_gap',nmodel = 4):
        self.name = name
        self.models = []
        res = 32
        self.acc = []
        self.nmodel = nmodel
        for i in range(nmodel):
            if 'Fmodel' in name:
                model = globals()[name](base=base,head_type=head_type)
            elif 'resnet9' in name:
                model = _resnet('resnet', BasicBlock, [1, 1, 1, 1], False, True)
                model.fc = Linear(512,10)
            else:
                model = globals()[name](pretrained=False)
                model.fc = Linear(512,10)

            model.to(device)
            # model.load_state_dict(torch.load(prefix+path+f'_{i}',map_location=torch.device('cpu')))
            model.load_state_dict(torch.load(prefix+path+f'_{i}'))
            model.eval()
            self.models.append(model)

    def extract_fea(self):
        self.adv_fea = [[],[],[],[]]
        self.fea = [[],[],[],[]]
        for model_ind in range(self.nmodel):
            cnt = 0
            model = self.models[model_ind]
            atk = torchattacks.FGSM(model, eps=8/255)
            for images, labels in test_loader1:  
                adv_images = atk(images, labels)      
                model.hook_gap()
                images = images.to(device)
                images = torch.cat((images,adv_images),dim=0)
                images = transforms.functional.normalize(images,(0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
                out = model(images)
                self.fea[model_ind].append(model.gap_fea[0][:100].cpu().detach())
                self.adv_fea[model_ind].append(model.gap_fea[0][100:].cpu().detach())
                model.unhook()

            self.fea[model_ind] = np.concatenate(self.fea[model_ind],axis=0)
            self.adv_fea[model_ind] = np.concatenate(self.adv_fea[model_ind],axis=0)

prefix = '/content/gdrive/MyDrive/cifar100/'

In [None]:
path = 'Wu_tune'
mlsp18_attn_def_adareg_distill = Meta('Fmodel',path,18,'mlsp_gap',nmodel=4)
mlsp18_attn_def_adareg_distill.eval()
np.array(mlsp18_attn_def_adareg_distill.acc).mean(axis=0)
