## Self-supervised Adversarial Training

### Prepared Work
Training a self-supervised model or download an pretrained self-supervised model. 

### Proposed method
1. Obtain the pretrained self-supervised model
2. Generating adversarial examples by PGD-KNN
3. Maximize the mutual information between the representations of clean examples and advesarial examples.

Below is the implementation code of proposed method.
[ADDIM](https://github.com/Philip-Bachman/amdim-public) is selected as the self-supervised model.

In [1]:
import os
import argparse

import torch
from PIL import Image
import mixed_precision
from stats import StatTracker
from datasets_bn import Dataset, build_dataset, get_dataset, get_encoder_size
from model_grad import Model
from checkpoint import Checkpoint
from task_self_supervised import train_self_supervised
from task_classifiers import train_classifiers


parser = argparse.ArgumentParser(description='Infomax Representations -- Self-Supervised Training')
parser.add_argument("--verbosity", help="increase output verbosity")
# parameters for general training stuff
parser.add_argument('--dataset', type=str, default='C10')
parser.add_argument('--batch_size', type=int, default=10,
                    help='input batch size (default: 200)')
parser.add_argument('--learning_rate', type=float, default=0.0002,
                    help='learning rate')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--amp', action='store_true', default=False,
                    help='Enables automatic mixed precision')

# parameters for model and training objective
parser.add_argument('--classifiers', action='store_true', default=False,
                    help="Wether to run self-supervised encoder or"
                    "classifier training task")
parser.add_argument('--ndf', type=int, default=128,
                    help='feature width for encoder')
parser.add_argument('--n_rkhs', type=int, default=128,
                    help='number of dimensions in fake RKHS embeddings')
parser.add_argument('--tclip', type=float, default=20.0,
                    help='soft clipping range for NCE scores')
parser.add_argument('--n_depth', type=int, default=3)
parser.add_argument('--use_bn', type=int, default=1)

# parameters for output, logging, checkpointing, etc
parser.add_argument('--output_dir', type=str, default='./runs',
                    help='directory where tensorboard events and checkpoints will be stored')
parser.add_argument('--input_dir', type=str, default='/mnt/imagenet',
                    help="Input directory for the dataset. Not needed For C10,"
                    " C100 or STL10 as the data will be automatically downloaded.")
parser.add_argument('--cpt_load_path', type=str, default='abc.xyz',
                    help='path from which to load checkpoint (if available)')
parser.add_argument('--cpt_name', type=str, default='cifar_amdim_cpt.pth',
                    help='name to use for storing checkpoints during training')
parser.add_argument('--run_name', type=str, default='cifar_default_run',
                    help='name to use for the tensorbaord summary for this run')
# ...
args = parser.parse_args(args=[])

# create target output dir if it doesn't exist yet
if not os.path.isdir(args.output_dir):
    os.mkdir(args.output_dir)

# enable mixed-precision computation if desired
if args.amp:
    mixed_precision.enable_mixed_precision()

# set the RNG seeds (probably more hidden elsewhere...)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# get the dataset
dataset = get_dataset(args.dataset)
enc_size = get_encoder_size(dataset)

# get a helper object for tensorboard logging
log_dir = os.path.join(args.output_dir, args.run_name)
stat_tracker = StatTracker(log_dir=log_dir)

num_classes =10
torch_device = torch.device('cuda')
# create new model with random parameters
model = Model(ndf=args.ndf, n_classes=num_classes, n_rkhs=args.n_rkhs,
              tclip=args.tclip, n_depth=args.n_depth, enc_size=enc_size,
              use_bn=(args.use_bn == 1))
model.init_weights(init_scale=1.0)
# restore model parameters from a checkpoint if requested
checkpoint = Checkpoint(model, args.cpt_load_path, args.output_dir, args.cpt_name)
model = model.to(torch_device)

# select which type of training to do
task = train_classifiers if args.classifiers else train_self_supervised

# this the pretrained model.
ckpt=torch.load('./runs/amdim_cpt.pth') 
params = ckpt['model']
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in params.items():
    name = k.replace("module.", "")
    new_state_dict[name] = v
# print(new_state_dict)
model.load_state_dict(new_state_dict)



log_dir: ./runs/cifar_default_run
Using a 32x32 encoder
***** CHECKPOINTING ****************
No checkpoint found. Starting fresh.
************************************


IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [2]:

import logging
import os
import time

import numpy as np
import matplotlib.pyplot as plt
import foolbox
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim

from lib.dataset_utils import *
from lib.cifar_resnet import *
from lib.adv_model import *
from lib.dknn_attack import DKNNAttack

from lib.dknn_attack import DKNNAttack

from lib.cwl2_attack import CWL2Attack
from lib.dknnl2norm import DKNN, DKNNL2
from lib.utils import *
from lib.lip_model import *
from lib.knn import *
from lib.nin import *
from lib.cifar10_model import *

from lib.cifar10_dcgan import Discriminator, Generator

from NCE.NCEAverage import NCEAverage
from NCE.NCECriterion import NCECriterion


def load_cifar10_all_amd(data_dir='./data', val_size=0.1, shuffle=True, seed=1):
    """Load entire CIFAR-10 dataset into tensor"""

    transform = transforms.Compose([
        transforms.ToTensor(),
#         transforms.Lambda(lambda x: x.mul(255)),
        transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                             std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=False, transform=transform)
    validset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=False, transform=transform)
    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=False, transform=transform)

    # Random split train and validation sets
    num_train = len(trainset)
    indices = list(range(num_train))
    split = int(np.floor(val_size * num_train))

    if shuffle:
        np.random.seed(seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=(num_train - split), sampler=train_sampler)
    validloader = torch.utils.data.DataLoader(
        validset, batch_size=split, sampler=valid_sampler)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=len(testset), shuffle=False)

    x_train = next(iter(trainloader))
    x_valid = next(iter(validloader))
    x_test = next(iter(testloader))

    return x_train, x_valid, x_test


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count    



# Set all random seeds
seed = 2019
np.random.seed(seed)
torch.manual_seed(seed)
(x_train, y_train), (x_valid, y_valid), (x_test, y_test) = load_cifar10_all_amd(
    './data', val_size=0.1, seed=seed)




In [3]:
#Evaluate the performance of the pretrained model
layers=['encode']
num = 10000
dknn = DKNNL2(model, x_train, y_train, x_valid, y_valid, layers,
              k=75, num_classes=10)

with torch.no_grad():
    y_pred = dknn.classify(x_test)
    ind = np.where(y_pred.argmax(1) == y_test.numpy())[0]
    print((y_pred.argmax(1) == y_test.numpy()).sum() / y_test.size(0))
    

0.8464


In [5]:
# Generating adversarial example using PGD-kNN
from lib.pgd_norm4 import DKNNPGDAttack
import pickle
attack = DKNNPGDAttack()
layer='encoder'
def attack_batch(x, y, batch_size, layer):
    x_a = torch.zeros_like(x)
    total_num = x.size(0)
    num_batches = total_num // batch_size
    for i in range(num_batches):
        begin = i * batch_size
        end = (i + 1) * batch_size
        x_a[begin:end] = attack(
            dknn, x[begin:end], y[begin:end],
            guide_layer=layer, m=300, binary_search_steps=1000,
            max_iterations=1000, learning_rate=1e-2, initial_const=1e-3,
            abort_early=False, random_start=True, guide=2)
    return x_a



In [None]:
# Maximize the mutual information, finetue the model.
import sys
mods_inf = [m for m in model.info_modules]
optimizer = optim.Adam(
    [{'params': mod.parameters(), 'lr': 0.0001} for mod in mods_inf],
    betas=(0.8, 0.999), weight_decay=1e-5, eps=1e-8)


batch_size =100
batch_num = x_train.shape[0] // batch_size

for i in range(400):

    print(i)
    begin = i * batch_size
    end = (i + 1) * batch_size

    begin_nat = (i+batch_num/2)%batch_num*batch_size
    end_nat = (i+1+batch_num/2)%batch_num*batch_size

    x_ori = x_train[begin:end].cuda()


    x_a = torch.zeros_like(x_ori).cuda()
    y_ori = y_train[begin:end].cuda()
    model.eval()

    dknn = DKNNL2(model, x_train, y_train, x_valid, y_valid, layers,
                  k=75, num_classes=10)

    x_a = attack(
        dknn, x_ori, y_ori,
        guide_layer=layer, m=300, binary_search_steps=10,
        max_iterations=10, learning_rate=1e-2, initial_const=1e-5,
        abort_early=False, random_start=True, guide=2)

    y_pred = dknn.classify(x_a)
    ind = np.where(y_pred.argmax(1) != y_ori.cpu().numpy())[0]
    index = torch.LongTensor(list(range(begin, end)))
    x_a = x_a[ind]
    x_ori = x_ori[ind]
    y_a= y_ori[ind]
    index = index[ind]

    model.train()

    bsz = x_ori.size(0)
    inputs = x_ori.float()
    inputs_adv = x_a.detach().float()
    if torch.cuda.is_available():
        index = index.cuda()
        inputs = inputs.cuda()
        inputs_adv = inputs_adv.cuda()

        # ===================forward=====================
    if inputs.shape[0] == 1:
        inputs = torch.cat((inputs, inputs), 0)
        inputs_adv = torch.cat((inputs_adv, inputs_adv), 0)
        res_dict = model(x1=inputs, x2=inputs_adv, class_only=False)
        lgt_glb_mlp, lgt_glb_lin = res_dict['class']
        # compute costs for all self-supervised tasks
        loss_g2l = (res_dict['g2l_1t5'] +
                    res_dict['g2l_1t7'] +
                    res_dict['g2l_5t5'])
        loss_inf = loss_g2l 
    else:
        res_dict = model(x1=inputs, x2=inputs_adv, class_only=False)
        lgt_glb_mlp, lgt_glb_lin = res_dict['class']
        # compute costs for all self-supervised tasks
        loss_g2l = (res_dict['g2l_1t5'] +
                    res_dict['g2l_1t7'] +
                    res_dict['g2l_5t5'])
        loss_inf = loss_g2l

    loss = loss_inf

    # ===================backward=====================
    optimizer.zero_grad()
    mixed_precision.backward(loss, optimizer)
    optimizer.step()

    # print info
    if (i + 1) % 1 == 0:
        print('loss: {}'.format(loss.item()))
        sys.stdout.flush()

    with torch.no_grad():
        y_pred = dknn.classify(x_test)
        ind = np.where(y_pred.argmax(1) == y_test.numpy())[0]
        print((y_pred.argmax(1) == y_test.numpy()).sum() / y_test.size(0))
        y_pred = dknn.classify(x_a)
        acc = (y_pred.argmax(1) == y_a.cpu().numpy()).sum() / len(y_pred)
        print(acc) 
                
        
    if i %5==0:
        print('==> Saving...')
        state = {
            'model': model.state_dict(),
        }
        torch.save(state, 'save_models/CIDAR_AMDIM_{epoch}.pth'.format(epoch=i))       
        
    with torch.no_grad():
        y_pred = dknn.classify(x_test)
        ind = np.where(y_pred.argmax(1) == y_test.numpy())[0]
        print((y_pred.argmax(1) == y_test.numpy()).sum() / y_test.size(0))


0
step 0 number of successful adv: 42/100
step 1 number of successful adv: 52/100
step 2 number of successful adv: 61/100
step 3 number of successful adv: 73/100
step 4 number of successful adv: 75/100
step 5 number of successful adv: 84/100
step 6 number of successful adv: 83/100
step 7 number of successful adv: 88/100
step 8 number of successful adv: 86/100
step 9 number of successful adv: 91/100
loss: 9.467467308044434
0.8445
0.21978021978021978
==> Saving...
0.8445
1
step 0 number of successful adv: 41/100
step 1 number of successful adv: 48/100
step 2 number of successful adv: 58/100
step 3 number of successful adv: 66/100
step 4 number of successful adv: 74/100
step 5 number of successful adv: 81/100
step 6 number of successful adv: 85/100
step 7 number of successful adv: 89/100
step 8 number of successful adv: 85/100
step 9 number of successful adv: 90/100
loss: 9.386272430419922
0.8442
0.2111111111111111
0.8442
2
step 0 number of successful adv: 44/100
step 1 number of successf

loss: 7.981569766998291
0.8191
0.19767441860465115
0.8191
18
step 0 number of successful adv: 30/100
step 1 number of successful adv: 43/100
step 2 number of successful adv: 56/100
step 3 number of successful adv: 64/100
step 4 number of successful adv: 74/100
step 5 number of successful adv: 77/100
step 6 number of successful adv: 80/100
step 7 number of successful adv: 81/100
step 8 number of successful adv: 83/100
step 9 number of successful adv: 85/100
loss: 7.679532527923584
0.8186
0.23529411764705882
0.8186
19
step 0 number of successful adv: 30/100
step 1 number of successful adv: 38/100
step 2 number of successful adv: 47/100
step 3 number of successful adv: 56/100
step 4 number of successful adv: 68/100
step 5 number of successful adv: 79/100
step 6 number of successful adv: 81/100
step 7 number of successful adv: 81/100
step 8 number of successful adv: 82/100
step 9 number of successful adv: 82/100
loss: 8.01947021484375
0.8187
0.1951219512195122
0.8187
20
step 0 number of su

step 8 number of successful adv: 80/100
step 9 number of successful adv: 77/100
loss: 6.863848686218262
0.8123
0.1038961038961039
==> Saving...
0.8123
36
step 0 number of successful adv: 38/100
step 1 number of successful adv: 45/100
step 2 number of successful adv: 52/100
step 3 number of successful adv: 60/100
step 4 number of successful adv: 66/100
step 5 number of successful adv: 75/100
step 6 number of successful adv: 77/100
step 7 number of successful adv: 79/100
step 8 number of successful adv: 80/100
step 9 number of successful adv: 81/100
loss: 7.478770732879639
0.8103
0.19753086419753085
0.8103
37
step 0 number of successful adv: 28/100
step 1 number of successful adv: 38/100
step 2 number of successful adv: 51/100
step 3 number of successful adv: 57/100
step 4 number of successful adv: 65/100
step 5 number of successful adv: 74/100
step 6 number of successful adv: 75/100
step 7 number of successful adv: 81/100
step 8 number of successful adv: 78/100
step 9 number of successf

step 6 number of successful adv: 60/100
step 7 number of successful adv: 62/100
step 8 number of successful adv: 62/100
step 9 number of successful adv: 68/100
loss: 6.546164035797119
0.7973
0.22058823529411764
0.7973
54
step 0 number of successful adv: 22/100
step 1 number of successful adv: 33/100
step 2 number of successful adv: 41/100
step 3 number of successful adv: 56/100
step 4 number of successful adv: 59/100
step 5 number of successful adv: 66/100
step 6 number of successful adv: 68/100
step 7 number of successful adv: 70/100
step 8 number of successful adv: 72/100
step 9 number of successful adv: 72/100
loss: 7.0228471755981445
0.7973
0.2222222222222222
0.7973
55
step 0 number of successful adv: 25/100
step 1 number of successful adv: 30/100
step 2 number of successful adv: 38/100
step 3 number of successful adv: 48/100
step 4 number of successful adv: 57/100
step 5 number of successful adv: 62/100
step 6 number of successful adv: 65/100
step 7 number of successful adv: 69/10

step 4 number of successful adv: 62/100
step 5 number of successful adv: 68/100
step 6 number of successful adv: 69/100
step 7 number of successful adv: 69/100
step 8 number of successful adv: 70/100
step 9 number of successful adv: 71/100
loss: 6.823429107666016
0.7881
0.1267605633802817
0.7881
72
step 0 number of successful adv: 32/100
step 1 number of successful adv: 41/100
step 2 number of successful adv: 50/100
step 3 number of successful adv: 61/100
step 4 number of successful adv: 64/100
step 5 number of successful adv: 73/100
step 6 number of successful adv: 74/100
step 7 number of successful adv: 79/100
step 8 number of successful adv: 74/100
step 9 number of successful adv: 80/100
loss: 7.210648536682129
0.7878
0.1375
0.7878
73
step 0 number of successful adv: 37/100
step 1 number of successful adv: 41/100
step 2 number of successful adv: 47/100
step 3 number of successful adv: 57/100
step 4 number of successful adv: 64/100
step 5 number of successful adv: 69/100
step 6 numbe

step 2 number of successful adv: 45/100
step 3 number of successful adv: 55/100
step 4 number of successful adv: 63/100
step 5 number of successful adv: 66/100
step 6 number of successful adv: 70/100
step 7 number of successful adv: 72/100
step 8 number of successful adv: 73/100
step 9 number of successful adv: 71/100
loss: 6.17366886138916
0.7806
0.16901408450704225
0.7806
90
step 0 number of successful adv: 24/100
step 1 number of successful adv: 33/100
step 2 number of successful adv: 40/100
step 3 number of successful adv: 46/100
step 4 number of successful adv: 56/100
step 5 number of successful adv: 59/100
step 6 number of successful adv: 60/100
step 7 number of successful adv: 62/100
step 8 number of successful adv: 62/100
step 9 number of successful adv: 64/100
loss: 6.6740217208862305
0.7794
0.140625
==> Saving...
0.7794
91
step 0 number of successful adv: 29/100
step 1 number of successful adv: 36/100
step 2 number of successful adv: 42/100
step 3 number of successful adv: 50