## Siamese

In [1]:
import os
os.chdir('../../src/')

In [2]:
import torch
import torch.utils.data as data
from PIL import Image, ImageOps
import pickle
import numpy as np
from torchvision import transforms
import json
from credential_classifier.bit_pytorch.dataloader import HybridLoader
from credential_classifier.bit_pytorch.models import KNOWN_MODELS
from tqdm import tqdm
os.environ["CUDA_VISIBLE_DEVICES"]="1"

- Define dataloader

In [3]:
val_set = HybridLoader(img_folder='../datasets/val_merge_imgs',
                         annot_path='../datasets/val_merge_coords.txt')

valid_loader = torch.utils.data.DataLoader(
  val_set, batch_size=8, shuffle=False, pin_memory=True, drop_last=False)

In [4]:
len(valid_loader)

467

- Accuracy function

In [15]:
def compute_acc(dataloader, model, device):
    correct = 0
    total = 0

    for b, (x, y) in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            x = x.to(device, non_blocking=True, dtype=torch.float)
            y = y.to(device, non_blocking=True)
            logits = model(x)
            pred_cls = torch.argmax(logits, dim=1)

            correct += torch.sum(torch.eq(pred_cls, y)).item()
            total += y.shape[0]
            
    print('Accuracy after changing relu function: {:.2f}'.format(correct/total))    
    return correct/total

- Load model (original)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize model
model = KNOWN_MODELS["BiT-M-R50x1V2"](head_size=2, zero_head=True)

# Load weights
checkpoint = torch.load('/home/l/liny/ruofan/PhishIntention/src/credential_classifier/output/hybrid/hybrid_lr0.005/BiT-M-R50x1V2_0.005.pth.tar', 
                        map_location="cpu")["model"]

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

model.to(device)
model.eval()

ResNetV2Coord(
  (root): Sequential(
    (conv): StdConv2d(8, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (pad): ConstantPad2d(padding=(1, 1, 1, 1), value=0)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (body): Sequential(
    (block1): Sequential(
      (unit01): PreActBottleneck(
        (gn1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (gn2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (gn3): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv3): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu): ReLU(inplace=True)
        (downsample): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (unit02): PreActBottleneck(
        (gn1): GroupNorm(32, 256, eps=1e-05

In [16]:
compute_acc(valid_loader, model, device)

3731it [11:43,  5.31it/s]

Accuracy after changing relu function: 0.95





0.9471991423210936

## Load model (change activation function)

In [5]:
import torch
import torch.nn.functional as F
from torch import nn

In [6]:
class QuantizeRelu(nn.Module):
    def __init__(self, step_size = 0.01):
        super().__init__()
        self.step_size = step_size

    def forward(self, x):
        mask = torch.ge(x, 0).bool() # mask for positive values
        quantize = torch.ones_like(x) * self.step_size
        out = torch.mul(torch.floor(torch.div(x, quantize)), self.step_size) # quantize by step_size
        out = torch.mul(out, mask) # zero-out negative values
        out = torch.abs(out) # remove sign
        return out

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize model
model = KNOWN_MODELS["BiT-M-R50x1V2"](head_size=2, zero_head=True)
# Load weights
checkpoint = torch.load('/home/l/liny/ruofan/PhishIntention/src/credential_classifier/output/hybrid/hybrid_lr0.005/BiT-M-R50x1V2_0.005.pth.tar', 
                        map_location="cpu")["model"]

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

# replace relu with defenselayer 
# model.body.block1.unit01.relu = QuantizeRelu()
# model.body.block1.unit02.relu = QuantizeRelu()
# model.body.block1.unit03.relu = QuantizeRelu()

# model.body.block2.unit01.relu = QuantizeRelu()
# model.body.block2.unit02.relu = QuantizeRelu()
# model.body.block2.unit03.relu = QuantizeRelu()
# model.body.block2.unit04.relu = QuantizeRelu()

# model.body.block3.unit01.relu = QuantizeRelu()
# model.body.block3.unit02.relu = QuantizeRelu()
# model.body.block3.unit03.relu = QuantizeRelu()
# model.body.block3.unit04.relu = QuantizeRelu()
# model.body.block3.unit05.relu = QuantizeRelu()
# model.body.block3.unit06.relu = QuantizeRelu()

model.body.block4.unit01.relu = QuantizeRelu()
model.body.block4.unit02.relu = QuantizeRelu()
model.body.block4.unit03.relu = QuantizeRelu()

model.to(device)
model.eval()

ResNetV2Coord(
  (root): Sequential(
    (conv): StdConv2d(8, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (pad): ConstantPad2d(padding=(1, 1, 1, 1), value=0)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (body): Sequential(
    (block1): Sequential(
      (unit01): PreActBottleneck(
        (gn1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (gn2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (gn3): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv3): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu): ReLU(inplace=True)
        (downsample): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (unit02): PreActBottleneck(
        (gn1): GroupNorm(32, 256, eps=1e-05

In [19]:
# compute_acc(valid_loader, model, device)

Accuracy after changing relu function: 0.93


0.9345417925478349

## Attack

In [11]:
from adv_attack.attack.Attack import *

criterion = nn.CrossEntropyLoss()
check = adversarial_attack(method='fgsm', model=model, dataloader=valid_loader, 
                           device=device, num_classes=2, save_data=True)

acc, _ = check.batch_attack()

3558it [6:33:19,  6.94s/it]

3557
Test Accuracy = 0.9471613265879708


3559it [6:33:25,  6.84s/it]

3558
Test Accuracy = 0.9471761730823265


3560it [6:33:33,  6.94s/it]

3559
Test Accuracy = 0.9471910112359551


3561it [6:33:39,  6.92s/it]

3560
Test Accuracy = 0.9472058410558831


3562it [6:33:49,  7.63s/it]

3561
Test Accuracy = 0.9472206625491297


3563it [6:33:56,  7.37s/it]

3562
Test Accuracy = 0.9472354757227056


3564it [6:34:02,  7.17s/it]

3563
Test Accuracy = 0.9472502805836139


3565it [6:34:09,  7.03s/it]

3564
Test Accuracy = 0.9472650771388499


3566it [6:34:16,  7.15s/it]

3565
Test Accuracy = 0.947279865395401


3567it [6:34:23,  7.10s/it]

3566
Test Accuracy = 0.9472946453602467


3568it [6:34:31,  7.14s/it]

3567
Test Accuracy = 0.9473094170403588


3569it [6:34:39,  7.65s/it]

3568
Test Accuracy = 0.9473241804427011


3570it [6:34:46,  7.24s/it]

3569
Test Accuracy = 0.9473389355742297


3571it [6:34:53,  7.18s/it]

3570
Test Accuracy = 0.947353682441893


3572it [6:35:00,  7.10s/it]

3571
Test Accuracy = 0.9473684210526315


3573it [6:35:06,  6.92s/it]

3572
Test Accuracy = 0.9473831514133781


3574it [6:35:13,  6.83s/it]

3573
Test Accuracy = 0.9473978735310576


3575it [6:35:20,  6.81s/it]

3574
Test Accuracy = 0.9474125874125874


3576it [6:35:26,  6.80s/it]

3575
Test Accuracy = 0.9474272930648769


3577it [6:35:27,  4.83s/it]

3576


3578it [6:35:33,  5.35s/it]

3577
Test Accuracy = 0.9471771939631078


3579it [6:35:39,  5.64s/it]

3578
Test Accuracy = 0.9471919530595139


3580it [6:35:47,  6.24s/it]

3579
Test Accuracy = 0.9472067039106146


3581it [6:35:54,  6.54s/it]

3580
Test Accuracy = 0.9472214465233175


3582it [6:36:01,  6.56s/it]

3581
Test Accuracy = 0.9472361809045227


3583it [6:36:09,  6.88s/it]

3582
Test Accuracy = 0.9472509070611219


3584it [6:36:16,  7.02s/it]

3583
Test Accuracy = 0.947265625


3585it [6:36:22,  6.88s/it]

3584
Test Accuracy = 0.9472803347280335


3586it [6:36:29,  6.93s/it]

3585
Test Accuracy = 0.9472950362520914


3587it [6:36:36,  6.78s/it]

3586
Test Accuracy = 0.9473097295790354


3588it [6:36:43,  6.87s/it]

3587
Test Accuracy = 0.947324414715719


3589it [6:36:51,  7.19s/it]

3588
Test Accuracy = 0.9473390916689886


3590it [6:36:58,  7.03s/it]

3589
Test Accuracy = 0.9473537604456824


3591it [6:37:05,  6.99s/it]

3590
Test Accuracy = 0.9473684210526315


3592it [6:37:11,  6.89s/it]

3591
Test Accuracy = 0.9473830734966593


3593it [6:37:21,  7.73s/it]

3592
Test Accuracy = 0.9473977177845812


3594it [6:37:28,  7.43s/it]

3593
Test Accuracy = 0.9474123539232053


3595it [6:37:35,  7.45s/it]

3594
Test Accuracy = 0.9474269819193324


3596it [6:37:42,  7.19s/it]

3595
Test Accuracy = 0.9474416017797553


3597it [6:37:48,  6.98s/it]

3596
Test Accuracy = 0.9474562135112594


3598it [6:37:54,  6.77s/it]

3597
Test Accuracy = 0.9474708171206225


3599it [6:38:01,  6.69s/it]

3598
Test Accuracy = 0.9474854126146152


3600it [6:38:08,  6.89s/it]

3599
Test Accuracy = 0.9475


3601it [6:38:15,  6.84s/it]

3600
Test Accuracy = 0.9475145792835323


3602it [6:38:22,  6.94s/it]

3601
Test Accuracy = 0.9475291504719601


3603it [6:38:29,  6.91s/it]

3602
Test Accuracy = 0.9475437135720233


3604it [6:38:29,  4.90s/it]

3603


3605it [6:38:29,  3.49s/it]

3604


3606it [6:38:36,  4.51s/it]

3605
Test Accuracy = 0.9470327232390461


3607it [6:38:43,  5.23s/it]

3606
Test Accuracy = 0.9470474078181315


3608it [6:38:50,  5.65s/it]

3607
Test Accuracy = 0.9470620842572062


3609it [6:38:57,  6.06s/it]

3608
Test Accuracy = 0.9470767525630368


3610it [6:39:04,  6.43s/it]

3609
Test Accuracy = 0.9470914127423823


3611it [6:39:11,  6.48s/it]

3610
Test Accuracy = 0.9471060648019939


3612it [6:39:18,  6.67s/it]

3611
Test Accuracy = 0.9471207087486158


3613it [6:39:25,  6.84s/it]

3612
Test Accuracy = 0.9471353445889842


3614it [6:39:32,  6.79s/it]

3613
Test Accuracy = 0.9471499723298284


3616it [6:39:39,  4.95s/it]

3614
Test Accuracy = 0.94716459197787
3615


3617it [6:39:47,  5.59s/it]

3616
Test Accuracy = 0.9469173348078518


3618it [6:39:53,  5.94s/it]

3617
Test Accuracy = 0.9469320066334992


3619it [6:40:00,  6.31s/it]

3618
Test Accuracy = 0.9469466703509257


3620it [6:40:07,  6.49s/it]

3619
Test Accuracy = 0.9469613259668508


3621it [6:40:14,  6.53s/it]

3620
Test Accuracy = 0.9469759734879868


3622it [6:40:21,  6.72s/it]

3621
Test Accuracy = 0.9469906129210381


3623it [6:40:28,  6.72s/it]

3622
Test Accuracy = 0.9470052442727022


3624it [6:40:35,  6.76s/it]

3623
Test Accuracy = 0.9470198675496688


3625it [6:40:42,  6.97s/it]

3624
Test Accuracy = 0.9470344827586207


3626it [6:40:49,  6.90s/it]

3625
Test Accuracy = 0.9470490899062327


3627it [6:40:56,  6.81s/it]

3626
Test Accuracy = 0.9470636889991728


3628it [6:41:02,  6.75s/it]

3627
Test Accuracy = 0.9470782800441014


3629it [6:41:09,  6.69s/it]

3628
Test Accuracy = 0.9470928630476715


3630it [6:41:15,  6.64s/it]

3629
Test Accuracy = 0.947107438016529


3631it [6:41:22,  6.68s/it]

3630
Test Accuracy = 0.947122004957312


3632it [6:41:29,  6.67s/it]

3631
Test Accuracy = 0.947136563876652


3633it [6:41:35,  6.68s/it]

3632
Test Accuracy = 0.9471511147811725


3634it [6:41:44,  7.34s/it]

3633
Test Accuracy = 0.9471656576774904


3635it [6:41:51,  7.06s/it]

3634
Test Accuracy = 0.9471801925722145


3636it [6:42:00,  7.61s/it]

3635
Test Accuracy = 0.9471947194719472


3637it [6:42:05,  7.08s/it]

3636
Test Accuracy = 0.947209238383283


3638it [6:42:12,  7.06s/it]

3637
Test Accuracy = 0.9472237493128093


3639it [6:42:19,  6.98s/it]

3638
Test Accuracy = 0.9472382522671063


3640it [6:42:26,  6.96s/it]

3639
Test Accuracy = 0.9472527472527472


3641it [6:42:33,  6.81s/it]

3640
Test Accuracy = 0.9472672342762977


3642it [6:42:33,  4.83s/it]

3641


3643it [6:42:40,  5.41s/it]

3642
Test Accuracy = 0.947021685424101


3644it [6:42:46,  5.63s/it]

3643
Test Accuracy = 0.9470362239297475


3645it [6:42:52,  5.79s/it]

3644
Test Accuracy = 0.9470507544581619


3646it [6:42:58,  5.92s/it]

3645
Test Accuracy = 0.9470652770159078


3647it [6:43:04,  5.88s/it]

3646
Test Accuracy = 0.9470797916095421


3648it [6:43:11,  6.21s/it]

3647
Test Accuracy = 0.9470942982456141


3649it [6:43:18,  6.46s/it]

3648
Test Accuracy = 0.947108796930666


3650it [6:43:25,  6.54s/it]

3649
Test Accuracy = 0.9471232876712329


3651it [6:43:25,  4.65s/it]

3650


3652it [6:43:31,  5.24s/it]

3651
Test Accuracy = 0.9468784227820373


3653it [6:43:38,  5.66s/it]

3652
Test Accuracy = 0.946892964686559


3654it [6:43:45,  5.93s/it]

3653
Test Accuracy = 0.9469074986316366


3655it [6:43:52,  6.28s/it]

3654
Test Accuracy = 0.946922024623803


3656it [6:43:58,  6.38s/it]

3655
Test Accuracy = 0.9469365426695843


3657it [6:44:05,  6.43s/it]

3656
Test Accuracy = 0.9469510527754991


3658it [6:44:11,  6.35s/it]

3657
Test Accuracy = 0.946965554948059


3659it [6:44:17,  6.37s/it]

3658
Test Accuracy = 0.9469800491937688


3660it [6:44:24,  6.31s/it]

3659
Test Accuracy = 0.9469945355191257


3661it [6:44:31,  6.71s/it]

3660
Test Accuracy = 0.9470090139306201


3662it [6:44:38,  6.87s/it]

3661
Test Accuracy = 0.9470234844347352


3663it [6:44:46,  6.95s/it]

3662
Test Accuracy = 0.9470379470379471


3664it [6:44:52,  6.84s/it]

3663
Test Accuracy = 0.9470524017467249


3665it [6:44:59,  6.89s/it]

3664
Test Accuracy = 0.9470668485675307


3666it [6:45:06,  6.98s/it]

3665
Test Accuracy = 0.9470812875068194


3667it [6:45:15,  7.32s/it]

3666
Test Accuracy = 0.947095718571039


3668it [6:45:26,  8.52s/it]

3667
Test Accuracy = 0.9471101417666303


3669it [6:45:33,  8.03s/it]

3668
Test Accuracy = 0.9471245571000273


3670it [6:45:39,  7.58s/it]

3669
Test Accuracy = 0.9471389645776567


3671it [6:45:47,  7.49s/it]

3670
Test Accuracy = 0.9471533642059384


3672it [6:45:54,  7.57s/it]

3671
Test Accuracy = 0.9471677559912854


3673it [6:46:01,  7.44s/it]

3672
Test Accuracy = 0.9471821399401035


3674it [6:46:08,  7.20s/it]

3673
Test Accuracy = 0.9471965160587915


3675it [6:46:15,  7.07s/it]

3674
Test Accuracy = 0.9472108843537415


3676it [6:46:22,  6.98s/it]

3675
Test Accuracy = 0.9472252448313384


3677it [6:46:29,  7.11s/it]

3676
Test Accuracy = 0.9472395974979603


3678it [6:46:36,  6.94s/it]

3677
Test Accuracy = 0.9472539423599783


3679it [6:46:42,  6.87s/it]

3678
Test Accuracy = 0.9472682794237565


3680it [6:46:49,  6.93s/it]

3679
Test Accuracy = 0.9472826086956522


3681it [6:46:57,  7.17s/it]

3680
Test Accuracy = 0.9472969301820158


3682it [6:47:04,  7.13s/it]

3681
Test Accuracy = 0.9473112438891906


3683it [6:47:11,  6.99s/it]

3682
Test Accuracy = 0.9473255498235135


3684it [6:47:18,  6.99s/it]

3683
Test Accuracy = 0.9473398479913138


3685it [6:47:25,  7.02s/it]

3684
Test Accuracy = 0.9473541383989145


3686it [6:47:31,  6.88s/it]

3685
Test Accuracy = 0.9473684210526315


3687it [6:47:38,  6.74s/it]

3686
Test Accuracy = 0.9473826959587741


3688it [6:47:48,  7.84s/it]

3687
Test Accuracy = 0.9473969631236443


3689it [6:47:56,  7.75s/it]

3688
Test Accuracy = 0.9474112225535375


3690it [6:48:07,  8.84s/it]

3689
Test Accuracy = 0.9474254742547426


3691it [6:48:14,  8.32s/it]

3690
Test Accuracy = 0.947439718233541


3692it [6:48:21,  7.88s/it]

3691
Test Accuracy = 0.9474539544962081


3693it [6:48:28,  7.68s/it]

3692
Test Accuracy = 0.9474681830490116


3694it [6:48:35,  7.44s/it]

3693
Test Accuracy = 0.9474824038982134


3695it [6:48:35,  5.27s/it]

3694


3696it [6:48:36,  3.76s/it]

3695


3697it [6:48:42,  4.58s/it]

3696
Test Accuracy = 0.9469840411144171


3698it [6:48:49,  5.27s/it]

3697
Test Accuracy = 0.9469983775013521


3699it [6:48:57,  5.97s/it]

3698
Test Accuracy = 0.9470127061367938


3700it [6:49:05,  6.59s/it]

3699
Test Accuracy = 0.947027027027027


3701it [6:49:11,  6.50s/it]

3700
Test Accuracy = 0.9470413401783302


3702it [6:49:18,  6.53s/it]

3701
Test Accuracy = 0.9470556455969746


3703it [6:49:24,  6.61s/it]

3702
Test Accuracy = 0.947069943289225


3704it [6:49:33,  7.20s/it]

3703
Test Accuracy = 0.9470842332613391


3705it [6:49:40,  7.02s/it]

3704
Test Accuracy = 0.9470985155195681


3706it [6:49:50,  8.19s/it]

3705
Test Accuracy = 0.9471127900701565


3707it [6:49:58,  7.88s/it]

3706
Test Accuracy = 0.9471270569193417


3708it [6:50:05,  7.63s/it]

3707
Test Accuracy = 0.9471413160733549


3709it [6:50:11,  7.37s/it]

3708
Test Accuracy = 0.94715556753842


3710it [6:50:18,  7.21s/it]

3709
Test Accuracy = 0.9471698113207547


3711it [6:50:25,  6.98s/it]

3710
Test Accuracy = 0.9471840474265697


3712it [6:50:31,  6.65s/it]

3711
Test Accuracy = 0.947198275862069


3713it [6:50:37,  6.72s/it]

3712
Test Accuracy = 0.9472124966334501


3714it [6:50:44,  6.72s/it]

3713
Test Accuracy = 0.9472267097469036


3715it [6:50:51,  6.69s/it]

3714
Test Accuracy = 0.9472409152086138


3716it [6:50:57,  6.61s/it]

3715
Test Accuracy = 0.9472551130247578


3717it [6:51:04,  6.63s/it]

3716
Test Accuracy = 0.9472693032015066


3718it [6:51:11,  6.81s/it]

3717
Test Accuracy = 0.9472834857450242


3719it [6:51:18,  6.83s/it]

3718
Test Accuracy = 0.9472976606614681


3720it [6:51:26,  7.10s/it]

3719
Test Accuracy = 0.9473118279569892


3721it [6:51:32,  6.97s/it]

3720
Test Accuracy = 0.9473259876377318


3722it [6:51:39,  6.87s/it]

3721
Test Accuracy = 0.9473401397098334


3723it [6:51:46,  6.80s/it]

3722
Test Accuracy = 0.9473542841794252


3724it [6:51:52,  6.81s/it]

3723
Test Accuracy = 0.9473684210526315


3725it [6:51:59,  6.72s/it]

3724
Test Accuracy = 0.9473825503355705


3726it [6:52:06,  6.71s/it]

3725
Test Accuracy = 0.9473966720343532


3727it [6:52:13,  6.79s/it]

3726
Test Accuracy = 0.9474107861550846


3728it [6:52:20,  6.83s/it]

3727
Test Accuracy = 0.9474248927038627


3729it [6:52:20,  4.85s/it]

3728


3730it [6:52:27,  5.61s/it]

3729
Test Accuracy = 0.9471849865951742


3731it [6:52:35,  6.64s/it]

3730
Test Accuracy = 0.9471991423210936
Test Accuracy = 3534 / 3731 = 0.9471991423210936





In [12]:
acc

0.9471991423210936

## BPDA

- With defense

In [7]:
from advertorch.bpda import BPDAWrapper
from advertorch.attacks import LinfPGDAttack

In [8]:
# original model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize model
model = KNOWN_MODELS["BiT-M-R50x1V2"](head_size=2, zero_head=True)
# Load weights
checkpoint = torch.load('/home/l/liny/ruofan/PhishIntention/src/credential_classifier/output/hybrid/hybrid_lr0.005/BiT-M-R50x1V2_0.005.pth.tar', 
                        map_location="cpu")["model"]

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)
model.to(device)
model.eval()

# BPDA approximation of gradients
defense_layer = BPDAWrapper(QuantizeRelu())

# # replace relu with defenselayer 
model.body.block1.unit01.relu = defense_layer
model.body.block1.unit02.relu = defense_layer
model.body.block1.unit03.relu = defense_layer

model.body.block2.unit01.relu = defense_layer
model.body.block2.unit02.relu = defense_layer
model.body.block2.unit03.relu = defense_layer
model.body.block2.unit04.relu = defense_layer

model.body.block3.unit01.relu = defense_layer
model.body.block3.unit02.relu = defense_layer
model.body.block3.unit03.relu = defense_layer
model.body.block3.unit04.relu = defense_layer
model.body.block3.unit05.relu = defense_layer
model.body.block3.unit06.relu = defense_layer

model.body.block4.unit01.relu = defense_layer
model.body.block4.unit02.relu = defense_layer
model.body.block4.unit03.relu = defense_layer

bpda_adversary = LinfPGDAttack(
    model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.05,
    nb_iter=100, eps_iter=0.005, rand_init=True, clip_min=0.0, clip_max=1.0,
    targeted=False)


In [None]:
perturb_correct = 0
total = 0

for cln_data, true_label in tqdm(valid_loader):
    cln_data, true_label = cln_data.to(device, dtype=torch.float), true_label.to(device)
    bpda_adv = bpda_adversary.perturb(cln_data, true_label)
    
    logits = model(bpda_adv)
    pred_cls = torch.argmax(logits, dim=1)
    perturb_correct += torch.sum(torch.eq(pred_cls, true_label)).item()
    total += len(true_label)
    
    print(perturb_correct/total)

  0%|          | 1/467 [00:17<2:12:18, 17.03s/it]

0.875


  0%|          | 2/467 [00:33<2:10:46, 16.87s/it]

0.9375


  1%|          | 3/467 [00:50<2:10:06, 16.82s/it]

0.9583333333333334


  1%|          | 4/467 [01:07<2:10:38, 16.93s/it]

0.96875


  1%|          | 5/467 [01:23<2:08:07, 16.64s/it]

0.975


  1%|▏         | 6/467 [01:39<2:07:44, 16.63s/it]

0.9791666666666666


  1%|▏         | 7/467 [01:55<2:05:17, 16.34s/it]

0.9642857142857143


  2%|▏         | 8/467 [02:12<2:05:37, 16.42s/it]

0.96875


  2%|▏         | 9/467 [02:27<2:03:33, 16.19s/it]

0.9583333333333334


  2%|▏         | 10/467 [02:44<2:03:55, 16.27s/it]

0.9625


  2%|▏         | 11/467 [02:59<2:02:11, 16.08s/it]

0.9659090909090909


  3%|▎         | 12/467 [03:16<2:02:34, 16.16s/it]

0.96875


  3%|▎         | 13/467 [03:32<2:02:07, 16.14s/it]

0.9615384615384616


  3%|▎         | 14/467 [03:49<2:03:09, 16.31s/it]

0.9464285714285714


  3%|▎         | 15/467 [04:06<2:04:11, 16.49s/it]

0.9333333333333333


  3%|▎         | 16/467 [04:22<2:03:23, 16.42s/it]

0.9296875


  4%|▎         | 17/467 [04:39<2:03:46, 16.50s/it]

0.9264705882352942


  4%|▍         | 18/467 [04:55<2:02:29, 16.37s/it]

0.9236111111111112


  4%|▍         | 19/467 [05:12<2:04:15, 16.64s/it]

0.9144736842105263


  4%|▍         | 20/467 [05:28<2:03:52, 16.63s/it]

0.91875


  4%|▍         | 21/467 [05:45<2:03:48, 16.66s/it]

0.9047619047619048


  5%|▍         | 22/467 [06:02<2:04:28, 16.78s/it]

0.9034090909090909


  5%|▍         | 23/467 [06:18<2:02:08, 16.51s/it]

0.8913043478260869


  5%|▌         | 24/467 [06:35<2:02:36, 16.61s/it]

0.8958333333333334


  5%|▌         | 25/467 [06:51<2:01:33, 16.50s/it]

0.89


  6%|▌         | 26/467 [07:08<2:01:35, 16.54s/it]

0.8942307692307693


  6%|▌         | 27/467 [07:23<1:57:15, 15.99s/it]

0.8842592592592593


  6%|▌         | 28/467 [07:41<2:01:21, 16.59s/it]

0.8794642857142857


  6%|▌         | 29/467 [07:56<1:59:43, 16.40s/it]

0.8793103448275862


  6%|▋         | 30/467 [08:11<1:55:30, 15.86s/it]

0.8833333333333333


  7%|▋         | 31/467 [08:28<1:57:16, 16.14s/it]

0.8830645161290323


  7%|▋         | 32/467 [08:45<1:58:06, 16.29s/it]

0.88671875


  7%|▋         | 33/467 [09:01<1:58:03, 16.32s/it]

0.8901515151515151


  7%|▋         | 34/467 [09:18<1:58:51, 16.47s/it]

0.8933823529411765


  7%|▋         | 35/467 [09:34<1:57:22, 16.30s/it]

0.8928571428571429


  8%|▊         | 36/467 [09:50<1:58:03, 16.44s/it]

0.8888888888888888


  8%|▊         | 37/467 [10:05<1:53:46, 15.88s/it]

0.8918918918918919


  8%|▊         | 38/467 [10:23<1:57:25, 16.42s/it]

0.8914473684210527


  8%|▊         | 39/467 [10:39<1:56:12, 16.29s/it]

0.8942307692307693


  9%|▊         | 40/467 [10:56<1:57:11, 16.47s/it]

0.890625


  9%|▉         | 41/467 [11:11<1:55:48, 16.31s/it]

0.8902439024390244


  9%|▉         | 42/467 [11:23<1:46:23, 15.02s/it]

0.8869047619047619


  9%|▉         | 43/467 [11:40<1:50:13, 15.60s/it]

0.8866279069767442


  9%|▉         | 44/467 [11:58<1:53:32, 16.11s/it]

0.8892045454545454


 10%|▉         | 45/467 [12:14<1:54:34, 16.29s/it]

0.8861111111111111


 10%|▉         | 46/467 [12:30<1:52:02, 15.97s/it]

0.8831521739130435


 10%|█         | 47/467 [12:50<2:00:14, 17.18s/it]

0.8856382978723404


 10%|█         | 48/467 [13:05<1:56:46, 16.72s/it]

0.8854166666666666


 10%|█         | 49/467 [13:22<1:56:54, 16.78s/it]

0.8877551020408163


 11%|█         | 50/467 [13:37<1:53:23, 16.32s/it]

0.885


 11%|█         | 51/467 [13:54<1:54:31, 16.52s/it]

0.8848039215686274


 11%|█         | 52/467 [14:10<1:52:13, 16.23s/it]

0.8846153846153846


 11%|█▏        | 53/467 [14:27<1:53:04, 16.39s/it]

0.8867924528301887


 12%|█▏        | 54/467 [14:43<1:52:43, 16.38s/it]

0.8888888888888888


 12%|█▏        | 55/467 [14:59<1:52:12, 16.34s/it]

0.8909090909090909


 12%|█▏        | 56/467 [15:16<1:52:40, 16.45s/it]

0.8928571428571429


 12%|█▏        | 57/467 [15:32<1:52:10, 16.42s/it]

0.8925438596491229


 12%|█▏        | 58/467 [15:49<1:52:25, 16.49s/it]

0.8900862068965517


 13%|█▎        | 59/467 [16:05<1:51:47, 16.44s/it]

0.8877118644067796


 13%|█▎        | 60/467 [16:23<1:52:59, 16.66s/it]

0.8854166666666666


 13%|█▎        | 61/467 [16:38<1:49:36, 16.20s/it]

0.8831967213114754


 13%|█▎        | 62/467 [16:54<1:50:21, 16.35s/it]

0.8810483870967742


 13%|█▎        | 63/467 [17:10<1:48:32, 16.12s/it]

0.878968253968254


 14%|█▎        | 64/467 [17:27<1:49:58, 16.37s/it]

0.876953125


 14%|█▍        | 65/467 [17:43<1:48:34, 16.21s/it]

0.8769230769230769


 14%|█▍        | 66/467 [18:00<1:49:31, 16.39s/it]

0.8768939393939394


 14%|█▍        | 67/467 [18:16<1:49:33, 16.43s/it]

0.8768656716417911


 15%|█▍        | 68/467 [18:28<1:40:02, 15.04s/it]

0.8786764705882353


 15%|█▍        | 69/467 [18:45<1:43:08, 15.55s/it]

0.8804347826086957


 15%|█▍        | 70/467 [19:03<1:48:38, 16.42s/it]

0.8821428571428571


 15%|█▌        | 71/467 [19:20<1:48:24, 16.43s/it]

0.8838028169014085


 15%|█▌        | 72/467 [19:36<1:48:35, 16.50s/it]

0.8854166666666666


 16%|█▌        | 73/467 [19:52<1:46:46, 16.26s/it]

0.8852739726027398


 16%|█▌        | 74/467 [20:08<1:46:50, 16.31s/it]

0.8868243243243243


 16%|█▌        | 75/467 [20:24<1:45:16, 16.11s/it]

0.8866666666666667


 16%|█▋        | 76/467 [20:41<1:45:51, 16.24s/it]

0.8865131578947368


 16%|█▋        | 77/467 [20:56<1:44:47, 16.12s/it]

0.8831168831168831


 17%|█▋        | 78/467 [21:13<1:45:14, 16.23s/it]

0.8846153846153846


 17%|█▋        | 79/467 [21:29<1:44:22, 16.14s/it]

0.8829113924050633


 17%|█▋        | 80/467 [21:45<1:45:08, 16.30s/it]

0.884375


 17%|█▋        | 81/467 [22:02<1:45:01, 16.33s/it]

0.8842592592592593


 18%|█▊        | 82/467 [22:18<1:45:19, 16.41s/it]

0.8826219512195121


 18%|█▊        | 83/467 [22:35<1:45:22, 16.47s/it]

0.8840361445783133


 18%|█▊        | 84/467 [22:51<1:44:57, 16.44s/it]

0.8824404761904762


 18%|█▊        | 85/467 [23:08<1:45:20, 16.55s/it]

0.8823529411764706


 18%|█▊        | 86/467 [23:24<1:43:25, 16.29s/it]

0.8808139534883721


 19%|█▊        | 87/467 [23:38<1:39:53, 15.77s/it]

0.8807471264367817


 19%|█▉        | 88/467 [23:58<1:47:40, 17.05s/it]

0.8792613636363636


 19%|█▉        | 89/467 [24:14<1:44:36, 16.60s/it]

0.8806179775280899


 19%|█▉        | 90/467 [24:27<1:37:15, 15.48s/it]

0.8805555555555555


 19%|█▉        | 91/467 [24:44<1:39:21, 15.85s/it]

0.8818681318681318


 20%|█▉        | 92/467 [25:00<1:39:21, 15.90s/it]

0.8817934782608695


 20%|█▉        | 93/467 [25:15<1:38:35, 15.82s/it]

0.8830645161290323


 20%|██        | 94/467 [25:32<1:39:28, 16.00s/it]

0.8829787234042553


 20%|██        | 95/467 [25:49<1:40:52, 16.27s/it]

0.881578947368421


 21%|██        | 96/467 [26:05<1:40:17, 16.22s/it]

0.8828125


 21%|██        | 97/467 [26:24<1:45:42, 17.14s/it]

0.8827319587628866


 21%|██        | 98/467 [26:40<1:43:31, 16.83s/it]

0.8839285714285714


 21%|██        | 99/467 [26:57<1:43:02, 16.80s/it]

0.8825757575757576


 21%|██▏       | 100/467 [27:12<1:39:58, 16.35s/it]

0.8825


 22%|██▏       | 101/467 [27:31<1:45:09, 17.24s/it]

0.8824257425742574


 22%|██▏       | 102/467 [27:49<1:46:02, 17.43s/it]

0.883578431372549


 22%|██▏       | 103/467 [28:06<1:44:25, 17.21s/it]

0.8847087378640777


 22%|██▏       | 104/467 [28:23<1:43:24, 17.09s/it]

0.8858173076923077


 22%|██▏       | 105/467 [28:48<1:57:59, 19.56s/it]

0.8857142857142857


 23%|██▎       | 106/467 [29:04<1:50:28, 18.36s/it]

0.8867924528301887


 23%|██▎       | 107/467 [29:24<1:53:21, 18.89s/it]

0.8866822429906542


 23%|██▎       | 108/467 [29:44<1:54:46, 19.18s/it]

0.8865740740740741


 23%|██▎       | 109/467 [29:53<1:37:09, 16.28s/it]

0.8876146788990825


 24%|██▎       | 110/467 [30:02<1:23:19, 14.00s/it]

0.8886363636363637


 24%|██▍       | 111/467 [30:11<1:13:56, 12.46s/it]

0.8873873873873874


 24%|██▍       | 112/467 [30:20<1:08:19, 11.55s/it]

0.8883928571428571


 24%|██▍       | 113/467 [30:29<1:03:11, 10.71s/it]

0.8871681415929203


 24%|██▍       | 114/467 [30:38<59:30, 10.12s/it]  

0.8859649122807017


 25%|██▍       | 115/467 [30:47<57:25,  9.79s/it]

0.8858695652173914


 25%|██▍       | 116/467 [30:56<55:45,  9.53s/it]

0.8868534482758621


 25%|██▌       | 117/467 [31:04<54:11,  9.29s/it]

0.8878205128205128


 25%|██▌       | 118/467 [31:14<54:49,  9.43s/it]

0.888771186440678


 25%|██▌       | 119/467 [31:24<54:49,  9.45s/it]

0.8897058823529411


 26%|██▌       | 120/467 [31:40<1:06:01, 11.42s/it]

0.8875


 26%|██▌       | 121/467 [31:56<1:13:52, 12.81s/it]

0.8853305785123967


 26%|██▌       | 122/467 [32:12<1:19:24, 13.81s/it]

0.8852459016393442


 26%|██▋       | 123/467 [32:28<1:23:11, 14.51s/it]

0.8851626016260162


 27%|██▋       | 124/467 [32:44<1:25:55, 15.03s/it]

0.8860887096774194


 27%|██▋       | 125/467 [33:01<1:28:08, 15.46s/it]

0.887


 27%|██▋       | 126/467 [33:17<1:30:01, 15.84s/it]

0.8878968253968254


 27%|██▋       | 127/467 [33:34<1:31:35, 16.16s/it]

0.8887795275590551


 27%|██▋       | 128/467 [33:51<1:32:49, 16.43s/it]

0.8876953125


 28%|██▊       | 129/467 [34:08<1:33:16, 16.56s/it]

0.8866279069767442


 28%|██▊       | 130/467 [34:28<1:37:52, 17.43s/it]

0.8875


 28%|██▊       | 131/467 [34:46<1:38:21, 17.56s/it]

0.8874045801526718


 28%|██▊       | 132/467 [35:03<1:38:40, 17.67s/it]

0.8863636363636364


 28%|██▊       | 133/467 [35:21<1:38:21, 17.67s/it]

0.8862781954887218


 29%|██▊       | 134/467 [35:39<1:38:20, 17.72s/it]

0.8861940298507462


 29%|██▉       | 135/467 [35:57<1:37:57, 17.70s/it]

0.8870370370370371


 29%|██▉       | 136/467 [36:15<1:38:12, 17.80s/it]

0.8878676470588235


 29%|██▉       | 137/467 [36:33<1:38:02, 17.83s/it]

0.8868613138686131


 30%|██▉       | 138/467 [36:50<1:37:26, 17.77s/it]

0.8858695652173914


 30%|██▉       | 139/467 [37:08<1:37:11, 17.78s/it]

0.8866906474820144


 30%|██▉       | 140/467 [37:26<1:36:34, 17.72s/it]

0.8875


 30%|███       | 141/467 [37:43<1:35:53, 17.65s/it]

0.8882978723404256


 30%|███       | 142/467 [38:00<1:35:13, 17.58s/it]

0.8890845070422535


 31%|███       | 143/467 [38:18<1:35:11, 17.63s/it]

0.8898601398601399


 31%|███       | 144/467 [38:35<1:34:15, 17.51s/it]

0.890625


 31%|███       | 145/467 [38:52<1:32:51, 17.30s/it]

0.8913793103448275


 31%|███▏      | 146/467 [39:09<1:32:05, 17.21s/it]

0.8886986301369864


 31%|███▏      | 147/467 [39:27<1:32:10, 17.28s/it]

0.8894557823129252


 32%|███▏      | 148/467 [39:44<1:32:30, 17.40s/it]

0.8902027027027027


 32%|███▏      | 149/467 [40:02<1:32:58, 17.54s/it]

0.8901006711409396


 32%|███▏      | 150/467 [40:20<1:32:37, 17.53s/it]

0.89


 32%|███▏      | 151/467 [40:39<1:35:05, 18.06s/it]

0.890728476821192


 33%|███▎      | 152/467 [40:57<1:34:11, 17.94s/it]

0.890625


 33%|███▎      | 153/467 [41:16<1:36:14, 18.39s/it]

0.8913398692810458


 33%|███▎      | 154/467 [41:36<1:37:30, 18.69s/it]

0.8920454545454546


 33%|███▎      | 155/467 [41:55<1:38:39, 18.97s/it]

0.8927419354838709


 33%|███▎      | 156/467 [42:14<1:38:39, 19.03s/it]

0.8910256410256411


 34%|███▎      | 157/467 [42:31<1:35:12, 18.43s/it]

0.89171974522293


 34%|███▍      | 158/467 [42:49<1:33:05, 18.08s/it]

0.8916139240506329


 34%|███▍      | 159/467 [43:06<1:31:13, 17.77s/it]

0.8907232704402516


 34%|███▍      | 160/467 [43:23<1:29:33, 17.50s/it]

0.89140625


 34%|███▍      | 161/467 [43:39<1:27:15, 17.11s/it]

0.8913043478260869


 35%|███▍      | 162/467 [43:58<1:29:43, 17.65s/it]

0.8912037037037037


 35%|███▍      | 163/467 [44:15<1:29:23, 17.64s/it]

0.8911042944785276


 35%|███▌      | 164/467 [44:33<1:29:25, 17.71s/it]

0.8894817073170732


 35%|███▌      | 165/467 [44:53<1:32:16, 18.33s/it]

0.8886363636363637


 36%|███▌      | 166/467 [45:11<1:31:34, 18.26s/it]

0.8885542168674698


 36%|███▌      | 167/467 [45:29<1:30:16, 18.06s/it]

0.8892215568862275


 36%|███▌      | 168/467 [45:46<1:29:17, 17.92s/it]

0.8898809523809523


 36%|███▌      | 169/467 [46:04<1:28:14, 17.77s/it]

0.8890532544378699


 36%|███▋      | 170/467 [46:21<1:27:36, 17.70s/it]

0.888235294117647


 37%|███▋      | 171/467 [46:39<1:26:47, 17.59s/it]

0.8888888888888888


 37%|███▋      | 172/467 [46:56<1:26:17, 17.55s/it]

0.8895348837209303


 37%|███▋      | 173/467 [47:13<1:25:26, 17.44s/it]

0.8901734104046243


 37%|███▋      | 174/467 [47:30<1:24:48, 17.37s/it]

0.8908045977011494


 37%|███▋      | 175/467 [47:47<1:23:54, 17.24s/it]

0.8907142857142857


 38%|███▊      | 176/467 [48:03<1:22:05, 16.93s/it]

0.890625


 38%|███▊      | 177/467 [48:20<1:20:40, 16.69s/it]

0.8905367231638418


 38%|███▊      | 178/467 [48:36<1:19:16, 16.46s/it]

0.889747191011236


 38%|███▊      | 179/467 [48:52<1:18:38, 16.38s/it]

0.8896648044692738


 39%|███▊      | 180/467 [49:08<1:18:04, 16.32s/it]

0.8902777777777777


 39%|███▉      | 181/467 [49:25<1:18:34, 16.48s/it]

0.8895027624309392


 39%|███▉      | 182/467 [49:41<1:18:06, 16.44s/it]

0.8901098901098901


 39%|███▉      | 183/467 [49:57<1:17:11, 16.31s/it]

0.8907103825136612


 39%|███▉      | 184/467 [50:13<1:16:48, 16.28s/it]

0.890625


 40%|███▉      | 185/467 [50:30<1:17:01, 16.39s/it]

0.8905405405405405


 40%|███▉      | 186/467 [50:46<1:16:20, 16.30s/it]

0.8911290322580645


 40%|████      | 187/467 [51:02<1:15:25, 16.16s/it]

0.891042780748663


 40%|████      | 188/467 [51:18<1:14:37, 16.05s/it]

0.8916223404255319


 40%|████      | 189/467 [51:34<1:14:29, 16.08s/it]

0.8921957671957672


 41%|████      | 190/467 [51:50<1:13:47, 15.98s/it]

0.8914473684210527


 41%|████      | 191/467 [52:06<1:13:36, 16.00s/it]

0.8913612565445026


 41%|████      | 192/467 [52:22<1:13:14, 15.98s/it]

0.8912760416666666


In [14]:
perturb_correct/total

0.0002680246582685607

In [24]:
perturb_correct/total

0.905337361530715

In [14]:
torch.randint(0, 277, (1,))

tensor([203])