## Siamese

In [1]:
import os
os.chdir('../../src/')

In [2]:
import torch
import torch.utils.data as data
from PIL import Image, ImageOps
import pickle
import numpy as np
from torchvision import transforms
import json
from credential_classifier.bit_pytorch.dataloader import HybridLoader
from credential_classifier.bit_pytorch.models import KNOWN_MODELS
from tqdm import tqdm
os.environ["CUDA_VISIBLE_DEVICES"]="1"

- Define dataloader

In [3]:
val_set = HybridLoader(img_folder='../datasets/val_merge_imgs',
                         annot_path='../datasets/val_merge_coords.txt')

valid_loader = torch.utils.data.DataLoader(
  val_set, batch_size=8, shuffle=False, pin_memory=True, drop_last=False)

In [4]:
len(valid_loader)

467

- Accuracy function

In [15]:
def compute_acc(dataloader, model, device):
    correct = 0
    total = 0

    for b, (x, y) in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            x = x.to(device, non_blocking=True, dtype=torch.float)
            y = y.to(device, non_blocking=True)
            logits = model(x)
            pred_cls = torch.argmax(logits, dim=1)

            correct += torch.sum(torch.eq(pred_cls, y)).item()
            total += y.shape[0]
            
    print('Accuracy after changing relu function: {:.2f}'.format(correct/total))    
    return correct/total

- Load model (original)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize model
model = KNOWN_MODELS["BiT-M-R50x1V2"](head_size=2, zero_head=True)

# Load weights
checkpoint = torch.load('/home/l/liny/ruofan/PhishIntention/src/credential_classifier/output/hybrid/hybrid_lr0.005/BiT-M-R50x1V2_0.005.pth.tar', 
                        map_location="cpu")["model"]

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

model.to(device)
model.eval()

ResNetV2Coord(
  (root): Sequential(
    (conv): StdConv2d(8, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (pad): ConstantPad2d(padding=(1, 1, 1, 1), value=0)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (body): Sequential(
    (block1): Sequential(
      (unit01): PreActBottleneck(
        (gn1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (gn2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (gn3): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv3): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu): ReLU(inplace=True)
        (downsample): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (unit02): PreActBottleneck(
        (gn1): GroupNorm(32, 256, eps=1e-05

In [16]:
compute_acc(valid_loader, model, device)

3731it [11:43,  5.31it/s]

Accuracy after changing relu function: 0.95





0.9471991423210936

## Load model (change activation function)

In [5]:
import torch
import torch.nn.functional as F
from torch import nn

In [11]:
class QuantizeRelu(nn.Module):
    def __init__(self, step_size = 0.01):
        super().__init__()
        self.step_size = step_size

    def forward(self, x):
        mask = torch.ge(x, 0).bool() # mask for positive values
        quantize = torch.ones_like(x) * self.step_size
        out = torch.mul(torch.floor(torch.div(x, quantize)), self.step_size) # quantize by step_size
        out = torch.mul(out, mask) # zero-out negative values
        out = torch.abs(out) # remove sign
        return out

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize model
model = KNOWN_MODELS["BiT-M-R50x1V2"](head_size=2, zero_head=True)
# Load weights
checkpoint = torch.load('/home/l/liny/ruofan/PhishIntention/src/credential_classifier/output/hybrid/hybrid_lr0.005/BiT-M-R50x1V2_0.005.pth.tar', 
                        map_location="cpu")["model"]

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

# replace relu with defenselayer 
# model.body.block1.unit01.relu = QuantizeRelu()
# model.body.block1.unit02.relu = QuantizeRelu()
# model.body.block1.unit03.relu = QuantizeRelu()

# model.body.block2.unit01.relu = QuantizeRelu()
# model.body.block2.unit02.relu = QuantizeRelu()
# model.body.block2.unit03.relu = QuantizeRelu()
# model.body.block2.unit04.relu = QuantizeRelu()

# model.body.block3.unit01.relu = QuantizeRelu()
# model.body.block3.unit02.relu = QuantizeRelu()
# model.body.block3.unit03.relu = QuantizeRelu()
# model.body.block3.unit04.relu = QuantizeRelu()
# model.body.block3.unit05.relu = QuantizeRelu()
# model.body.block3.unit06.relu = QuantizeRelu()

model.body.block4.unit01.relu = QuantizeRelu()
model.body.block4.unit02.relu = QuantizeRelu()
model.body.block4.unit03.relu = QuantizeRelu()

model.to(device)
model.eval()

ResNetV2Coord(
  (root): Sequential(
    (conv): StdConv2d(8, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (pad): ConstantPad2d(padding=(1, 1, 1, 1), value=0)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (body): Sequential(
    (block1): Sequential(
      (unit01): PreActBottleneck(
        (gn1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (gn2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (gn3): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv3): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu): ReLU(inplace=True)
        (downsample): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (unit02): PreActBottleneck(
        (gn1): GroupNorm(32, 256, eps=1e-05

In [19]:
# compute_acc(valid_loader, model, device)

Accuracy after changing relu function: 0.93


0.9345417925478349

## Attack

In [11]:
from adv_attack.attack.Attack import *

criterion = nn.CrossEntropyLoss()
check = adversarial_attack(method='fgsm', model=model, dataloader=valid_loader, 
                           device=device, num_classes=2, save_data=True)

acc, _ = check.batch_attack()

3558it [6:33:19,  6.94s/it]

3557
Test Accuracy = 0.9471613265879708


3559it [6:33:25,  6.84s/it]

3558
Test Accuracy = 0.9471761730823265


3560it [6:33:33,  6.94s/it]

3559
Test Accuracy = 0.9471910112359551


3561it [6:33:39,  6.92s/it]

3560
Test Accuracy = 0.9472058410558831


3562it [6:33:49,  7.63s/it]

3561
Test Accuracy = 0.9472206625491297


3563it [6:33:56,  7.37s/it]

3562
Test Accuracy = 0.9472354757227056


3564it [6:34:02,  7.17s/it]

3563
Test Accuracy = 0.9472502805836139


3565it [6:34:09,  7.03s/it]

3564
Test Accuracy = 0.9472650771388499


3566it [6:34:16,  7.15s/it]

3565
Test Accuracy = 0.947279865395401


3567it [6:34:23,  7.10s/it]

3566
Test Accuracy = 0.9472946453602467


3568it [6:34:31,  7.14s/it]

3567
Test Accuracy = 0.9473094170403588


3569it [6:34:39,  7.65s/it]

3568
Test Accuracy = 0.9473241804427011


3570it [6:34:46,  7.24s/it]

3569
Test Accuracy = 0.9473389355742297


3571it [6:34:53,  7.18s/it]

3570
Test Accuracy = 0.947353682441893


3572it [6:35:00,  7.10s/it]

3571
Test Accuracy = 0.9473684210526315


3573it [6:35:06,  6.92s/it]

3572
Test Accuracy = 0.9473831514133781


3574it [6:35:13,  6.83s/it]

3573
Test Accuracy = 0.9473978735310576


3575it [6:35:20,  6.81s/it]

3574
Test Accuracy = 0.9474125874125874


3576it [6:35:26,  6.80s/it]

3575
Test Accuracy = 0.9474272930648769


3577it [6:35:27,  4.83s/it]

3576


3578it [6:35:33,  5.35s/it]

3577
Test Accuracy = 0.9471771939631078


3579it [6:35:39,  5.64s/it]

3578
Test Accuracy = 0.9471919530595139


3580it [6:35:47,  6.24s/it]

3579
Test Accuracy = 0.9472067039106146


3581it [6:35:54,  6.54s/it]

3580
Test Accuracy = 0.9472214465233175


3582it [6:36:01,  6.56s/it]

3581
Test Accuracy = 0.9472361809045227


3583it [6:36:09,  6.88s/it]

3582
Test Accuracy = 0.9472509070611219


3584it [6:36:16,  7.02s/it]

3583
Test Accuracy = 0.947265625


3585it [6:36:22,  6.88s/it]

3584
Test Accuracy = 0.9472803347280335


3586it [6:36:29,  6.93s/it]

3585
Test Accuracy = 0.9472950362520914


3587it [6:36:36,  6.78s/it]

3586
Test Accuracy = 0.9473097295790354


3588it [6:36:43,  6.87s/it]

3587
Test Accuracy = 0.947324414715719


3589it [6:36:51,  7.19s/it]

3588
Test Accuracy = 0.9473390916689886


3590it [6:36:58,  7.03s/it]

3589
Test Accuracy = 0.9473537604456824


3591it [6:37:05,  6.99s/it]

3590
Test Accuracy = 0.9473684210526315


3592it [6:37:11,  6.89s/it]

3591
Test Accuracy = 0.9473830734966593


3593it [6:37:21,  7.73s/it]

3592
Test Accuracy = 0.9473977177845812


3594it [6:37:28,  7.43s/it]

3593
Test Accuracy = 0.9474123539232053


3595it [6:37:35,  7.45s/it]

3594
Test Accuracy = 0.9474269819193324


3596it [6:37:42,  7.19s/it]

3595
Test Accuracy = 0.9474416017797553


3597it [6:37:48,  6.98s/it]

3596
Test Accuracy = 0.9474562135112594


3598it [6:37:54,  6.77s/it]

3597
Test Accuracy = 0.9474708171206225


3599it [6:38:01,  6.69s/it]

3598
Test Accuracy = 0.9474854126146152


3600it [6:38:08,  6.89s/it]

3599
Test Accuracy = 0.9475


3601it [6:38:15,  6.84s/it]

3600
Test Accuracy = 0.9475145792835323


3602it [6:38:22,  6.94s/it]

3601
Test Accuracy = 0.9475291504719601


3603it [6:38:29,  6.91s/it]

3602
Test Accuracy = 0.9475437135720233


3604it [6:38:29,  4.90s/it]

3603


3605it [6:38:29,  3.49s/it]

3604


3606it [6:38:36,  4.51s/it]

3605
Test Accuracy = 0.9470327232390461


3607it [6:38:43,  5.23s/it]

3606
Test Accuracy = 0.9470474078181315


3608it [6:38:50,  5.65s/it]

3607
Test Accuracy = 0.9470620842572062


3609it [6:38:57,  6.06s/it]

3608
Test Accuracy = 0.9470767525630368


3610it [6:39:04,  6.43s/it]

3609
Test Accuracy = 0.9470914127423823


3611it [6:39:11,  6.48s/it]

3610
Test Accuracy = 0.9471060648019939


3612it [6:39:18,  6.67s/it]

3611
Test Accuracy = 0.9471207087486158


3613it [6:39:25,  6.84s/it]

3612
Test Accuracy = 0.9471353445889842


3614it [6:39:32,  6.79s/it]

3613
Test Accuracy = 0.9471499723298284


3616it [6:39:39,  4.95s/it]

3614
Test Accuracy = 0.94716459197787
3615


3617it [6:39:47,  5.59s/it]

3616
Test Accuracy = 0.9469173348078518


3618it [6:39:53,  5.94s/it]

3617
Test Accuracy = 0.9469320066334992


3619it [6:40:00,  6.31s/it]

3618
Test Accuracy = 0.9469466703509257


3620it [6:40:07,  6.49s/it]

3619
Test Accuracy = 0.9469613259668508


3621it [6:40:14,  6.53s/it]

3620
Test Accuracy = 0.9469759734879868


3622it [6:40:21,  6.72s/it]

3621
Test Accuracy = 0.9469906129210381


3623it [6:40:28,  6.72s/it]

3622
Test Accuracy = 0.9470052442727022


3624it [6:40:35,  6.76s/it]

3623
Test Accuracy = 0.9470198675496688


3625it [6:40:42,  6.97s/it]

3624
Test Accuracy = 0.9470344827586207


3626it [6:40:49,  6.90s/it]

3625
Test Accuracy = 0.9470490899062327


3627it [6:40:56,  6.81s/it]

3626
Test Accuracy = 0.9470636889991728


3628it [6:41:02,  6.75s/it]

3627
Test Accuracy = 0.9470782800441014


3629it [6:41:09,  6.69s/it]

3628
Test Accuracy = 0.9470928630476715


3630it [6:41:15,  6.64s/it]

3629
Test Accuracy = 0.947107438016529


3631it [6:41:22,  6.68s/it]

3630
Test Accuracy = 0.947122004957312


3632it [6:41:29,  6.67s/it]

3631
Test Accuracy = 0.947136563876652


3633it [6:41:35,  6.68s/it]

3632
Test Accuracy = 0.9471511147811725


3634it [6:41:44,  7.34s/it]

3633
Test Accuracy = 0.9471656576774904


3635it [6:41:51,  7.06s/it]

3634
Test Accuracy = 0.9471801925722145


3636it [6:42:00,  7.61s/it]

3635
Test Accuracy = 0.9471947194719472


3637it [6:42:05,  7.08s/it]

3636
Test Accuracy = 0.947209238383283


3638it [6:42:12,  7.06s/it]

3637
Test Accuracy = 0.9472237493128093


3639it [6:42:19,  6.98s/it]

3638
Test Accuracy = 0.9472382522671063


3640it [6:42:26,  6.96s/it]

3639
Test Accuracy = 0.9472527472527472


3641it [6:42:33,  6.81s/it]

3640
Test Accuracy = 0.9472672342762977


3642it [6:42:33,  4.83s/it]

3641


3643it [6:42:40,  5.41s/it]

3642
Test Accuracy = 0.947021685424101


3644it [6:42:46,  5.63s/it]

3643
Test Accuracy = 0.9470362239297475


3645it [6:42:52,  5.79s/it]

3644
Test Accuracy = 0.9470507544581619


3646it [6:42:58,  5.92s/it]

3645
Test Accuracy = 0.9470652770159078


3647it [6:43:04,  5.88s/it]

3646
Test Accuracy = 0.9470797916095421


3648it [6:43:11,  6.21s/it]

3647
Test Accuracy = 0.9470942982456141


3649it [6:43:18,  6.46s/it]

3648
Test Accuracy = 0.947108796930666


3650it [6:43:25,  6.54s/it]

3649
Test Accuracy = 0.9471232876712329


3651it [6:43:25,  4.65s/it]

3650


3652it [6:43:31,  5.24s/it]

3651
Test Accuracy = 0.9468784227820373


3653it [6:43:38,  5.66s/it]

3652
Test Accuracy = 0.946892964686559


3654it [6:43:45,  5.93s/it]

3653
Test Accuracy = 0.9469074986316366


3655it [6:43:52,  6.28s/it]

3654
Test Accuracy = 0.946922024623803


3656it [6:43:58,  6.38s/it]

3655
Test Accuracy = 0.9469365426695843


3657it [6:44:05,  6.43s/it]

3656
Test Accuracy = 0.9469510527754991


3658it [6:44:11,  6.35s/it]

3657
Test Accuracy = 0.946965554948059


3659it [6:44:17,  6.37s/it]

3658
Test Accuracy = 0.9469800491937688


3660it [6:44:24,  6.31s/it]

3659
Test Accuracy = 0.9469945355191257


3661it [6:44:31,  6.71s/it]

3660
Test Accuracy = 0.9470090139306201


3662it [6:44:38,  6.87s/it]

3661
Test Accuracy = 0.9470234844347352


3663it [6:44:46,  6.95s/it]

3662
Test Accuracy = 0.9470379470379471


3664it [6:44:52,  6.84s/it]

3663
Test Accuracy = 0.9470524017467249


3665it [6:44:59,  6.89s/it]

3664
Test Accuracy = 0.9470668485675307


3666it [6:45:06,  6.98s/it]

3665
Test Accuracy = 0.9470812875068194


3667it [6:45:15,  7.32s/it]

3666
Test Accuracy = 0.947095718571039


3668it [6:45:26,  8.52s/it]

3667
Test Accuracy = 0.9471101417666303


3669it [6:45:33,  8.03s/it]

3668
Test Accuracy = 0.9471245571000273


3670it [6:45:39,  7.58s/it]

3669
Test Accuracy = 0.9471389645776567


3671it [6:45:47,  7.49s/it]

3670
Test Accuracy = 0.9471533642059384


3672it [6:45:54,  7.57s/it]

3671
Test Accuracy = 0.9471677559912854


3673it [6:46:01,  7.44s/it]

3672
Test Accuracy = 0.9471821399401035


3674it [6:46:08,  7.20s/it]

3673
Test Accuracy = 0.9471965160587915


3675it [6:46:15,  7.07s/it]

3674
Test Accuracy = 0.9472108843537415


3676it [6:46:22,  6.98s/it]

3675
Test Accuracy = 0.9472252448313384


3677it [6:46:29,  7.11s/it]

3676
Test Accuracy = 0.9472395974979603


3678it [6:46:36,  6.94s/it]

3677
Test Accuracy = 0.9472539423599783


3679it [6:46:42,  6.87s/it]

3678
Test Accuracy = 0.9472682794237565


3680it [6:46:49,  6.93s/it]

3679
Test Accuracy = 0.9472826086956522


3681it [6:46:57,  7.17s/it]

3680
Test Accuracy = 0.9472969301820158


3682it [6:47:04,  7.13s/it]

3681
Test Accuracy = 0.9473112438891906


3683it [6:47:11,  6.99s/it]

3682
Test Accuracy = 0.9473255498235135


3684it [6:47:18,  6.99s/it]

3683
Test Accuracy = 0.9473398479913138


3685it [6:47:25,  7.02s/it]

3684
Test Accuracy = 0.9473541383989145


3686it [6:47:31,  6.88s/it]

3685
Test Accuracy = 0.9473684210526315


3687it [6:47:38,  6.74s/it]

3686
Test Accuracy = 0.9473826959587741


3688it [6:47:48,  7.84s/it]

3687
Test Accuracy = 0.9473969631236443


3689it [6:47:56,  7.75s/it]

3688
Test Accuracy = 0.9474112225535375


3690it [6:48:07,  8.84s/it]

3689
Test Accuracy = 0.9474254742547426


3691it [6:48:14,  8.32s/it]

3690
Test Accuracy = 0.947439718233541


3692it [6:48:21,  7.88s/it]

3691
Test Accuracy = 0.9474539544962081


3693it [6:48:28,  7.68s/it]

3692
Test Accuracy = 0.9474681830490116


3694it [6:48:35,  7.44s/it]

3693
Test Accuracy = 0.9474824038982134


3695it [6:48:35,  5.27s/it]

3694


3696it [6:48:36,  3.76s/it]

3695


3697it [6:48:42,  4.58s/it]

3696
Test Accuracy = 0.9469840411144171


3698it [6:48:49,  5.27s/it]

3697
Test Accuracy = 0.9469983775013521


3699it [6:48:57,  5.97s/it]

3698
Test Accuracy = 0.9470127061367938


3700it [6:49:05,  6.59s/it]

3699
Test Accuracy = 0.947027027027027


3701it [6:49:11,  6.50s/it]

3700
Test Accuracy = 0.9470413401783302


3702it [6:49:18,  6.53s/it]

3701
Test Accuracy = 0.9470556455969746


3703it [6:49:24,  6.61s/it]

3702
Test Accuracy = 0.947069943289225


3704it [6:49:33,  7.20s/it]

3703
Test Accuracy = 0.9470842332613391


3705it [6:49:40,  7.02s/it]

3704
Test Accuracy = 0.9470985155195681


3706it [6:49:50,  8.19s/it]

3705
Test Accuracy = 0.9471127900701565


3707it [6:49:58,  7.88s/it]

3706
Test Accuracy = 0.9471270569193417


3708it [6:50:05,  7.63s/it]

3707
Test Accuracy = 0.9471413160733549


3709it [6:50:11,  7.37s/it]

3708
Test Accuracy = 0.94715556753842


3710it [6:50:18,  7.21s/it]

3709
Test Accuracy = 0.9471698113207547


3711it [6:50:25,  6.98s/it]

3710
Test Accuracy = 0.9471840474265697


3712it [6:50:31,  6.65s/it]

3711
Test Accuracy = 0.947198275862069


3713it [6:50:37,  6.72s/it]

3712
Test Accuracy = 0.9472124966334501


3714it [6:50:44,  6.72s/it]

3713
Test Accuracy = 0.9472267097469036


3715it [6:50:51,  6.69s/it]

3714
Test Accuracy = 0.9472409152086138


3716it [6:50:57,  6.61s/it]

3715
Test Accuracy = 0.9472551130247578


3717it [6:51:04,  6.63s/it]

3716
Test Accuracy = 0.9472693032015066


3718it [6:51:11,  6.81s/it]

3717
Test Accuracy = 0.9472834857450242


3719it [6:51:18,  6.83s/it]

3718
Test Accuracy = 0.9472976606614681


3720it [6:51:26,  7.10s/it]

3719
Test Accuracy = 0.9473118279569892


3721it [6:51:32,  6.97s/it]

3720
Test Accuracy = 0.9473259876377318


3722it [6:51:39,  6.87s/it]

3721
Test Accuracy = 0.9473401397098334


3723it [6:51:46,  6.80s/it]

3722
Test Accuracy = 0.9473542841794252


3724it [6:51:52,  6.81s/it]

3723
Test Accuracy = 0.9473684210526315


3725it [6:51:59,  6.72s/it]

3724
Test Accuracy = 0.9473825503355705


3726it [6:52:06,  6.71s/it]

3725
Test Accuracy = 0.9473966720343532


3727it [6:52:13,  6.79s/it]

3726
Test Accuracy = 0.9474107861550846


3728it [6:52:20,  6.83s/it]

3727
Test Accuracy = 0.9474248927038627


3729it [6:52:20,  4.85s/it]

3728


3730it [6:52:27,  5.61s/it]

3729
Test Accuracy = 0.9471849865951742


3731it [6:52:35,  6.64s/it]

3730
Test Accuracy = 0.9471991423210936
Test Accuracy = 3534 / 3731 = 0.9471991423210936





In [12]:
acc

0.9471991423210936

## BPDA

- With defense

In [6]:
from advertorch.bpda import BPDAWrapper
from advertorch.attacks import LinfPGDAttack

In [12]:
# original model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize model
model = KNOWN_MODELS["BiT-M-R50x1V2"](head_size=2, zero_head=True)
# Load weights
checkpoint = torch.load('/home/l/liny/ruofan/PhishIntention/src/credential_classifier/output/hybrid/hybrid_lr0.005/BiT-M-R50x1V2_0.005.pth.tar', 
                        map_location="cpu")["model"]

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)
model.to(device)
model.eval()

# BPDA approximation of gradients
defense_layer = BPDAWrapper(QuantizeRelu())

# # replace relu with defenselayer 
# model.body.block1.unit01.relu = defense_layer
# model.body.block1.unit02.relu = defense_layer
# model.body.block1.unit03.relu = defense_layer

# model.body.block2.unit01.relu = defense_layer
# model.body.block2.unit02.relu = defense_layer
# model.body.block2.unit03.relu = defense_layer
# model.body.block2.unit04.relu = defense_layer

# model.body.block3.unit01.relu = defense_layer
# model.body.block3.unit02.relu = defense_layer
# model.body.block3.unit03.relu = defense_layer
# model.body.block3.unit04.relu = defense_layer
# model.body.block3.unit05.relu = defense_layer
# model.body.block3.unit06.relu = defense_layer

model.body.block4.unit01.relu = defense_layer
model.body.block4.unit02.relu = defense_layer
model.body.block4.unit03.relu = defense_layer

bpda_adversary = LinfPGDAttack(
    model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.05,
    nb_iter=100, eps_iter=0.005, rand_init=True, clip_min=0.0, clip_max=1.0,
    targeted=False)


In [None]:
perturb_correct = 0
total = 0

for cln_data, true_label in tqdm(valid_loader):
    cln_data, true_label = cln_data.to(device, dtype=torch.float), true_label.to(device)
    bpda_adv = bpda_adversary.perturb(cln_data, true_label)
    
    logits = model(bpda_adv)
    pred_cls = torch.argmax(logits, dim=1)
    perturb_correct += torch.sum(torch.eq(pred_cls, true_label)).item()
    total += len(true_label)
    
    print(perturb_correct/total)

  0%|          | 1/467 [00:08<1:04:09,  8.26s/it]

0.0


  0%|          | 2/467 [00:16<1:04:06,  8.27s/it]

0.0


  1%|          | 3/467 [00:24<1:03:03,  8.15s/it]

0.0


  1%|          | 4/467 [00:32<1:02:16,  8.07s/it]

0.0


  1%|          | 5/467 [00:41<1:04:15,  8.35s/it]

0.0


  1%|▏         | 6/467 [00:49<1:03:33,  8.27s/it]

0.0


  1%|▏         | 7/467 [00:57<1:03:31,  8.29s/it]

0.0


  2%|▏         | 8/467 [01:06<1:04:35,  8.44s/it]

0.0


  2%|▏         | 9/467 [01:14<1:04:11,  8.41s/it]

0.0


  2%|▏         | 10/467 [01:22<1:03:03,  8.28s/it]

0.0


  2%|▏         | 11/467 [01:30<1:02:08,  8.18s/it]

0.0


  3%|▎         | 12/467 [01:39<1:02:10,  8.20s/it]

0.0


  3%|▎         | 13/467 [01:47<1:01:51,  8.17s/it]

0.0


  3%|▎         | 14/467 [01:54<1:00:56,  8.07s/it]

0.0


  3%|▎         | 15/467 [02:02<1:00:42,  8.06s/it]

0.0


  3%|▎         | 16/467 [02:11<1:00:30,  8.05s/it]

0.0


  4%|▎         | 17/467 [02:18<1:00:07,  8.02s/it]

0.0


  4%|▍         | 18/467 [02:27<1:00:39,  8.11s/it]

0.0


  4%|▍         | 19/467 [02:35<1:00:50,  8.15s/it]

0.0


  4%|▍         | 20/467 [02:43<1:00:51,  8.17s/it]

0.0


  4%|▍         | 21/467 [02:51<1:00:25,  8.13s/it]

0.0


  5%|▍         | 22/467 [02:59<1:00:07,  8.11s/it]

0.0


  5%|▍         | 23/467 [03:07<59:46,  8.08s/it]  

0.0


  5%|▌         | 24/467 [03:15<59:18,  8.03s/it]

0.0


  5%|▌         | 25/467 [03:24<1:00:02,  8.15s/it]

0.0


  6%|▌         | 26/467 [03:32<59:47,  8.13s/it]  

0.0


  6%|▌         | 27/467 [03:40<59:26,  8.11s/it]

0.0


  6%|▌         | 28/467 [03:48<59:15,  8.10s/it]

0.0


  6%|▌         | 29/467 [03:56<59:51,  8.20s/it]

0.0


  6%|▋         | 30/467 [04:05<59:46,  8.21s/it]

0.0


  7%|▋         | 31/467 [04:13<59:15,  8.16s/it]

0.0


  7%|▋         | 32/467 [04:21<58:33,  8.08s/it]

0.0


  7%|▋         | 33/467 [04:28<58:03,  8.03s/it]

0.0


  7%|▋         | 34/467 [04:36<57:47,  8.01s/it]

0.0


  7%|▋         | 35/467 [04:45<58:29,  8.12s/it]

0.0


  8%|▊         | 36/467 [04:53<57:50,  8.05s/it]

0.0


  8%|▊         | 37/467 [05:01<57:47,  8.06s/it]

0.0


  8%|▊         | 38/467 [05:09<57:19,  8.02s/it]

0.0


  8%|▊         | 39/467 [05:17<57:19,  8.04s/it]

0.0


  9%|▊         | 40/467 [05:25<57:11,  8.04s/it]

0.0


  9%|▉         | 41/467 [05:33<57:32,  8.11s/it]

0.0


  9%|▉         | 42/467 [05:41<57:20,  8.10s/it]

0.0


  9%|▉         | 43/467 [05:53<1:05:40,  9.29s/it]

0.0


  9%|▉         | 44/467 [06:02<1:03:35,  9.02s/it]

0.0


 10%|▉         | 45/467 [06:10<1:01:49,  8.79s/it]

0.0


 10%|▉         | 46/467 [06:18<1:00:11,  8.58s/it]

0.0


 10%|█         | 47/467 [06:26<58:35,  8.37s/it]  

0.0


 10%|█         | 48/467 [06:34<58:01,  8.31s/it]

0.0


 10%|█         | 49/467 [06:42<57:01,  8.19s/it]

0.0


 11%|█         | 50/467 [06:50<56:24,  8.12s/it]

0.0


 11%|█         | 51/467 [06:58<55:59,  8.08s/it]

0.0


 11%|█         | 52/467 [07:06<55:24,  8.01s/it]

0.0


 11%|█▏        | 53/467 [07:14<55:45,  8.08s/it]

0.0


 12%|█▏        | 54/467 [07:22<55:50,  8.11s/it]

0.0


 12%|█▏        | 55/467 [07:30<55:33,  8.09s/it]

0.0


 12%|█▏        | 56/467 [07:39<56:03,  8.18s/it]

0.0


 12%|█▏        | 57/467 [07:47<55:42,  8.15s/it]

0.0


 12%|█▏        | 58/467 [07:55<55:47,  8.19s/it]

0.0


 13%|█▎        | 59/467 [08:04<56:35,  8.32s/it]

0.0


 13%|█▎        | 60/467 [08:12<56:15,  8.29s/it]

0.0


 13%|█▎        | 61/467 [08:20<55:36,  8.22s/it]

0.0


 13%|█▎        | 62/467 [08:28<55:32,  8.23s/it]

0.0


 13%|█▎        | 63/467 [08:36<55:09,  8.19s/it]

0.0


 14%|█▎        | 64/467 [08:44<54:38,  8.13s/it]

0.0


 14%|█▍        | 65/467 [08:52<54:20,  8.11s/it]

0.0


 14%|█▍        | 66/467 [09:02<57:26,  8.59s/it]

0.0


 14%|█▍        | 67/467 [09:10<56:35,  8.49s/it]

0.0


 15%|█▍        | 68/467 [09:18<55:23,  8.33s/it]

0.0


 15%|█▍        | 69/467 [09:27<55:22,  8.35s/it]

0.0


 15%|█▍        | 70/467 [09:35<55:05,  8.33s/it]

0.0


 15%|█▌        | 71/467 [09:43<54:29,  8.26s/it]

0.0


 15%|█▌        | 72/467 [09:51<54:58,  8.35s/it]

0.0


 16%|█▌        | 73/467 [10:00<54:57,  8.37s/it]

0.0


 16%|█▌        | 74/467 [10:12<1:03:03,  9.63s/it]

0.0


 16%|█▌        | 75/467 [10:20<59:35,  9.12s/it]  

0.0


 16%|█▋        | 76/467 [10:28<56:55,  8.74s/it]

0.0


 16%|█▋        | 77/467 [10:41<1:04:50,  9.98s/it]

0.0


 17%|█▋        | 78/467 [10:49<1:01:09,  9.43s/it]

0.0


 17%|█▋        | 79/467 [11:00<1:03:31,  9.82s/it]

0.0


 17%|█▋        | 80/467 [11:08<59:41,  9.26s/it]  

0.0


 17%|█▋        | 81/467 [11:16<57:22,  8.92s/it]

0.0


 18%|█▊        | 82/467 [11:24<55:14,  8.61s/it]

0.0


 18%|█▊        | 83/467 [11:32<53:42,  8.39s/it]

0.0


 18%|█▊        | 84/467 [11:40<52:41,  8.25s/it]

0.0


 18%|█▊        | 85/467 [11:48<51:59,  8.17s/it]

0.0


 18%|█▊        | 86/467 [11:55<51:03,  8.04s/it]

0.0


 19%|█▊        | 87/467 [12:03<50:37,  7.99s/it]

0.0


 19%|█▉        | 88/467 [12:11<50:09,  7.94s/it]

0.0


 19%|█▉        | 89/467 [12:19<50:00,  7.94s/it]

0.0


 19%|█▉        | 90/467 [12:27<49:55,  7.95s/it]

0.0


 19%|█▉        | 91/467 [12:36<50:58,  8.14s/it]

0.0


 20%|█▉        | 92/467 [12:44<51:04,  8.17s/it]

0.0


 20%|█▉        | 93/467 [12:52<50:27,  8.10s/it]

0.0


 20%|██        | 94/467 [13:00<49:43,  8.00s/it]

0.0


 20%|██        | 95/467 [13:10<53:40,  8.66s/it]

0.0


 21%|██        | 96/467 [13:18<52:09,  8.43s/it]

0.0


 21%|██        | 97/467 [13:27<54:01,  8.76s/it]

0.0


 21%|██        | 98/467 [13:35<52:13,  8.49s/it]

0.0


 21%|██        | 99/467 [13:43<51:39,  8.42s/it]

0.0


 21%|██▏       | 100/467 [13:51<50:42,  8.29s/it]

0.0


 22%|██▏       | 101/467 [14:02<54:26,  8.92s/it]

0.0


 22%|██▏       | 102/467 [14:10<52:56,  8.70s/it]

0.0


 22%|██▏       | 103/467 [14:20<54:51,  9.04s/it]

0.0


 22%|██▏       | 104/467 [14:28<52:54,  8.74s/it]

0.0


 22%|██▏       | 105/467 [14:36<51:34,  8.55s/it]

0.0


 23%|██▎       | 106/467 [14:47<55:19,  9.19s/it]

0.0


 23%|██▎       | 107/467 [14:55<52:58,  8.83s/it]

0.0


 23%|██▎       | 108/467 [15:05<55:25,  9.26s/it]

0.0


 23%|██▎       | 109/467 [15:13<52:54,  8.87s/it]

0.0


 24%|██▎       | 110/467 [15:22<53:45,  9.03s/it]

0.0


 24%|██▍       | 111/467 [15:31<52:42,  8.88s/it]

0.0


 24%|██▍       | 112/467 [15:40<53:19,  9.01s/it]

0.0


 24%|██▍       | 113/467 [15:48<51:32,  8.74s/it]

0.0


 24%|██▍       | 114/467 [15:56<50:16,  8.55s/it]

0.0


 25%|██▍       | 115/467 [16:04<49:09,  8.38s/it]

0.0


 25%|██▍       | 116/467 [16:12<48:24,  8.28s/it]

0.0


 25%|██▌       | 117/467 [16:20<47:52,  8.21s/it]

0.0


 25%|██▌       | 118/467 [16:32<54:12,  9.32s/it]

0.0


 25%|██▌       | 119/467 [16:40<51:43,  8.92s/it]

0.0


 26%|██▌       | 120/467 [16:50<52:39,  9.11s/it]

0.0


 26%|██▌       | 121/467 [16:58<50:37,  8.78s/it]

0.0


 26%|██▌       | 122/467 [17:06<48:55,  8.51s/it]

0.0


 26%|██▋       | 123/467 [17:14<48:01,  8.38s/it]

0.0


 27%|██▋       | 124/467 [17:22<47:28,  8.30s/it]

0.0


 27%|██▋       | 125/467 [17:30<46:42,  8.19s/it]

0.0


 27%|██▋       | 126/467 [17:38<45:47,  8.06s/it]

0.0


 27%|██▋       | 127/467 [17:45<45:16,  7.99s/it]

0.0


 27%|██▋       | 128/467 [17:53<44:50,  7.94s/it]

0.0


 28%|██▊       | 129/467 [18:01<44:53,  7.97s/it]

0.0


 28%|██▊       | 130/467 [18:09<44:41,  7.96s/it]

0.0


 28%|██▊       | 131/467 [18:17<45:00,  8.04s/it]

0.0


 28%|██▊       | 132/467 [18:25<44:38,  8.00s/it]

0.0


 28%|██▊       | 133/467 [18:33<44:34,  8.01s/it]

0.0


 29%|██▊       | 134/467 [18:41<43:57,  7.92s/it]

0.0


 29%|██▉       | 135/467 [18:49<43:37,  7.88s/it]

0.0


 29%|██▉       | 136/467 [18:57<43:29,  7.88s/it]

0.0


 29%|██▉       | 137/467 [19:05<43:32,  7.92s/it]

0.0


 30%|██▉       | 138/467 [19:13<43:43,  7.98s/it]

0.0


 30%|██▉       | 139/467 [19:21<43:38,  7.98s/it]

0.0


 30%|██▉       | 140/467 [19:29<43:30,  7.98s/it]

0.0


 30%|███       | 141/467 [19:37<43:07,  7.94s/it]

0.0


 30%|███       | 142/467 [19:45<43:01,  7.94s/it]

0.0


 31%|███       | 143/467 [19:53<43:06,  7.98s/it]

0.0


 31%|███       | 144/467 [20:01<42:50,  7.96s/it]

0.0


 31%|███       | 145/467 [20:09<42:51,  7.99s/it]

0.0


 31%|███▏      | 146/467 [20:16<42:29,  7.94s/it]

0.0


 31%|███▏      | 147/467 [20:24<42:07,  7.90s/it]

0.0


 32%|███▏      | 148/467 [20:32<42:19,  7.96s/it]

0.0


 32%|███▏      | 149/467 [20:40<41:49,  7.89s/it]

0.0


 32%|███▏      | 150/467 [20:48<41:39,  7.88s/it]

0.0


 32%|███▏      | 151/467 [20:56<41:26,  7.87s/it]

0.0


 33%|███▎      | 152/467 [21:04<41:21,  7.88s/it]

0.0


 33%|███▎      | 153/467 [21:12<41:20,  7.90s/it]

0.0


 33%|███▎      | 154/467 [21:19<41:04,  7.87s/it]

0.0


 33%|███▎      | 155/467 [21:28<41:50,  8.05s/it]

0.0


 33%|███▎      | 156/467 [21:37<43:00,  8.30s/it]

0.0


 34%|███▎      | 157/467 [21:45<42:29,  8.22s/it]

0.0


 34%|███▍      | 158/467 [21:53<41:43,  8.10s/it]

0.0


 34%|███▍      | 159/467 [22:00<41:13,  8.03s/it]

0.0


 34%|███▍      | 160/467 [22:09<41:04,  8.03s/it]

0.0


 34%|███▍      | 161/467 [22:16<40:45,  7.99s/it]

0.0


 35%|███▍      | 162/467 [22:24<40:29,  7.97s/it]

0.0


 35%|███▍      | 163/467 [22:32<40:12,  7.94s/it]

0.0


 35%|███▌      | 164/467 [22:40<39:57,  7.91s/it]

0.0


 35%|███▌      | 165/467 [22:48<39:57,  7.94s/it]

0.0


 36%|███▌      | 166/467 [22:56<39:42,  7.91s/it]

0.0


 36%|███▌      | 167/467 [23:04<39:32,  7.91s/it]

0.0


 36%|███▌      | 168/467 [23:12<39:36,  7.95s/it]

0.0


 36%|███▌      | 169/467 [23:20<39:45,  8.00s/it]

0.0


 36%|███▋      | 170/467 [23:28<39:59,  8.08s/it]

0.0


 37%|███▋      | 171/467 [23:36<39:46,  8.06s/it]

0.0


 37%|███▋      | 172/467 [23:44<39:28,  8.03s/it]

0.0


 37%|███▋      | 173/467 [23:52<39:02,  7.97s/it]

0.0


 37%|███▋      | 174/467 [24:00<38:40,  7.92s/it]

0.0


 37%|███▋      | 175/467 [24:09<39:55,  8.20s/it]

0.0


 38%|███▊      | 176/467 [24:17<39:26,  8.13s/it]

0.0


 38%|███▊      | 177/467 [24:25<39:05,  8.09s/it]

0.0


 38%|███▊      | 178/467 [24:33<38:48,  8.06s/it]

0.0


 38%|███▊      | 179/467 [24:40<38:20,  7.99s/it]

0.0


 39%|███▊      | 180/467 [24:48<37:58,  7.94s/it]

0.0


 39%|███▉      | 181/467 [24:56<37:53,  7.95s/it]

0.0


 39%|███▉      | 182/467 [25:04<37:51,  7.97s/it]

0.0


 39%|███▉      | 183/467 [25:12<37:42,  7.97s/it]

0.0


 39%|███▉      | 184/467 [25:20<37:50,  8.02s/it]

0.0


 40%|███▉      | 185/467 [25:28<37:33,  7.99s/it]

0.0


 40%|███▉      | 186/467 [25:36<37:25,  7.99s/it]

0.0


 40%|████      | 187/467 [25:44<37:13,  7.98s/it]

0.0


 40%|████      | 188/467 [25:52<36:59,  7.96s/it]

0.0


 40%|████      | 189/467 [26:00<36:49,  7.95s/it]

0.0


 41%|████      | 190/467 [26:12<42:09,  9.13s/it]

0.0


 41%|████      | 191/467 [26:20<40:16,  8.76s/it]

0.0


 41%|████      | 192/467 [26:28<38:41,  8.44s/it]

0.0


 41%|████▏     | 193/467 [26:41<44:57,  9.84s/it]

0.0


 42%|████▏     | 194/467 [26:48<41:54,  9.21s/it]

0.0


 42%|████▏     | 195/467 [27:01<46:00, 10.15s/it]

0.0


 42%|████▏     | 196/467 [27:09<42:44,  9.46s/it]

0.0


 42%|████▏     | 197/467 [27:16<40:23,  8.98s/it]

0.0


 42%|████▏     | 198/467 [27:24<38:57,  8.69s/it]

0.0


 43%|████▎     | 199/467 [27:32<37:45,  8.45s/it]

0.0


 43%|████▎     | 200/467 [27:40<37:06,  8.34s/it]

0.0


 43%|████▎     | 201/467 [27:49<37:14,  8.40s/it]

0.0


 43%|████▎     | 202/467 [27:57<36:48,  8.33s/it]

0.0


 43%|████▎     | 203/467 [28:05<35:52,  8.15s/it]

0.0


 44%|████▎     | 204/467 [28:13<35:44,  8.16s/it]

0.0


 44%|████▍     | 205/467 [28:21<35:20,  8.09s/it]

0.0


 44%|████▍     | 206/467 [28:29<34:58,  8.04s/it]

0.0


 44%|████▍     | 207/467 [28:37<34:41,  8.01s/it]

0.0


 45%|████▍     | 208/467 [28:45<34:13,  7.93s/it]

0.0


 45%|████▍     | 209/467 [28:52<34:01,  7.91s/it]

0.0


 45%|████▍     | 210/467 [29:01<34:07,  7.97s/it]

0.0


 45%|████▌     | 211/467 [29:10<35:45,  8.38s/it]

0.0


 45%|████▌     | 212/467 [29:18<34:57,  8.23s/it]

0.0


 46%|████▌     | 213/467 [29:28<37:26,  8.84s/it]

0.0


 46%|████▌     | 214/467 [29:36<36:11,  8.58s/it]

0.0


 46%|████▌     | 215/467 [29:44<35:03,  8.35s/it]

0.0


 46%|████▋     | 216/467 [29:52<35:05,  8.39s/it]

0.0


 46%|████▋     | 217/467 [30:00<34:19,  8.24s/it]

0.0


 47%|████▋     | 218/467 [30:08<34:07,  8.22s/it]

0.0


 47%|████▋     | 219/467 [30:16<33:34,  8.12s/it]

0.0


 47%|████▋     | 220/467 [30:25<34:22,  8.35s/it]

0.0


 47%|████▋     | 221/467 [30:33<34:08,  8.33s/it]

0.0


 48%|████▊     | 222/467 [30:42<33:58,  8.32s/it]

0.0


In [9]:
perturb_correct/total

0.0

In [24]:
perturb_correct/total

0.905337361530715

In [14]:
torch.randint(0, 277, (1,))

tensor([203])