# Misclassification Aware Adversarial Training (MART)

In [1]:
import os
import sys
sys.path.append('..')
import yaml
import shutil
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, random_split

from utils import *

## Parameter setting

In [2]:
gpu = '0,1,2,3'
dataset = 'cifar10'
model_type = 'wrn34-10'
checkpoint = './checkpoint/mart/%s/%s' % (model_type, dataset)
num_classes = 10
lr = 0.01
momentum = 0.9
weight_decay = 3.5e-3
batch_size = 128
total_epochs = 100
beta = 5.0
epsilon = 8/255
alpha = 2/255
num_repeats = 10

## Inner maximization

In [3]:
def inner_max(model, xent, inputs, targets, epsilon, alpha, num_repeats):
    noise = torch.FloatTensor(inputs.shape).uniform_(-1, 1).cuda()
    x = inputs + noise
    
    for _ in range(num_repeats):
        x.requires_grad_()
        logits = model(x)
        loss = xent(logits, targets)
        loss.backward()
        grads = x.grad.data
        x = x.detach() + alpha*torch.sign(grads).detach()
        x = torch.min(torch.max(x, inputs-epsilon), inputs+epsilon).clamp(min=0, max=1)
    return x

## Training (Outer minimization)

In [4]:
def training(epoch, model, dataloader, optimizer, num_classes, 
             beta=6.0, epsilon=8/255, alpha=2/255, num_repeats=10):
    model.train()
    total = 0
    total_loss = 0
    total_correct = 0
        
    kl = nn.KLDivLoss(reduction='none')
    xent = nn.CrossEntropyLoss()
    for idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        batch = inputs.size(0)
        
        x = inner_max(model, xent, inputs, targets, epsilon, alpha, num_repeats)
        logits_adv = model(x)
        logits_nat = model(inputs)
        
        classes = torch.arange(num_classes)[None,:].repeat(batch,1).cuda()
        log_softmax_gt = torch.log_softmax(logits_adv, dim=1)[classes==targets[:,None]].unsqueeze(1)
        false_probs = torch.log_softmax(logits_adv, dim=1)[classes!=targets[:,None]].view(batch, num_classes-1)
        top2_probs = torch.topk(false_probs, k=1, dim=1).values
        boosted_xent_loss = torch.sum(-log_softmax_gt - torch.log(1 - top2_probs))/batch
        kl_loss = torch.sum(kl(torch.log_softmax(logits_adv, dim=1),
                            torch.softmax(logits_nat, dim=1)), dim=1)
        loss = boosted_xent_loss + beta*torch.sum(kl_loss*(1-logits_nat.softmax(dim=1)[classes==targets[:,None]]))/batch
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
            
        total += batch
        total_loss += loss.item()
        num_correct = torch.argmax(logits_adv.data, dim=1).eq(targets.data).cpu().sum().item()
        total_correct += num_correct
        
        if idx % 100 == 0:
            print('Epoch %d [%d/%d] | loss: %.4f (avg: %.4f) | acc: %.4f (avg: %.4f) |'\
                  % (epoch, idx, len(dataloader), loss.item(), total_loss/len(dataloader),
                     num_correct/batch, total_correct/total))

In [5]:
def evaluation(epoch, model, dataloader, alpha, epsilon, num_repeats):
    model.eval()
    total_correct_nat = 0
    total_correct_adv = 0
    
    xent = nn.CrossEntropyLoss()
    for samples in dataloader:
        inputs, targets = samples[0].cuda(), samples[1].cuda()
        batch = inputs.size(0)
        with torch.enable_grad():
            x = inner_max(model, xent, inputs, targets, epsilon, alpha, num_repeats)
            
        with torch.no_grad():
            logits_nat = model(inputs)
            logits_adv = model(x)
        
        total_correct_nat += torch.argmax(logits_nat.data, dim=1).eq(targets.data).cpu().sum().item()
        total_correct_adv += torch.argmax(logits_adv.data, dim=1).eq(targets.data).cpu().sum().item()
        
    print('Validation | acc (nat): %.4f | acc (rob): %.4f |' % (total_correct_nat / len(dataloader.dataset),
                                                                total_correct_adv / len(dataloader.dataset)))
    return (total_correct_nat / len(dataloader.dataset)), (total_correct_adv / len(dataloader.dataset))

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
os.makedirs(checkpoint, exist_ok=True)

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()])
train_dataset, _ = get_dataloader(dataset, batch_size)
num_samples = len(train_dataset)
num_samples_for_train = int(num_samples * 0.98)
num_samples_for_valid = num_samples - num_samples_for_train
train_set, valid_set = random_split(train_dataset, [num_samples_for_train, num_samples_for_valid])
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False)
valid_dataloader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

model = nn.DataParallel(get_network(model_type, num_classes).cuda())
optimizer = optim.SGD(model.parameters(),lr=lr, momentum=momentum, weight_decay=weight_decay)

scheduler = [75, 90]
adjust_learning_rate = lr_scheduler.MultiStepLR(optimizer, scheduler, gamma=0.1)
best_acc_nat, best_acc_rob = 0, 0

for epoch in range(total_epochs):
    training(epoch, model, train_dataloader, optimizer, num_classes, beta, epsilon, alpha, num_repeats)
    test_acc_nat, test_acc_rob = evaluation(epoch, model, valid_dataloader, alpha, epsilon, num_repeats)
        
    is_best = best_acc_nat < test_acc_nat and best_acc_rob < test_acc_rob
    best_acc_nat = max(best_acc_nat, test_acc_nat)
    best_acc_rob = max(best_acc_rob, test_acc_rob)
    save_checkpoint = {'state_dict': model.state_dict(),
                       'best_acc_nat': best_acc_nat,
                       'best_acc_rob': best_acc_rob,
                       'optimizer': optimizer.state_dict(),
                       'model_type': model_type,
                       'dataset': dataset}
    torch.save(save_checkpoint, os.path.join(checkpoint, 'model'))
    if is_best:
        torch.save(save_checkpoint, os.path.join(checkpoint, 'best_model'))
    adjust_learning_rate.step()

Files already downloaded and verified
Files already downloaded and verified
Epoch 0 [0/383] | loss: 2.1897 (avg: 0.0057) | acc: 0.0000 (avg: 0.0000) |
Epoch 0 [100/383] | loss: 1.1616 (avg: 0.3332) | acc: 0.0781 (avg: 0.0726) |
Epoch 0 [200/383] | loss: 1.1334 (avg: 0.6308) | acc: 0.1562 (avg: 0.1088) |
Epoch 0 [300/383] | loss: 1.0833 (avg: 0.9230) | acc: 0.2109 (avg: 0.1272) |
Validation | acc (nat): 0.2200 | acc (rob): 0.1840 |
Epoch 1 [0/383] | loss: 1.1087 (avg: 0.0029) | acc: 0.1719 (avg: 0.1719) |
Epoch 1 [100/383] | loss: 1.0971 (avg: 0.2894) | acc: 0.1484 (avg: 0.1866) |
Epoch 1 [200/383] | loss: 1.0590 (avg: 0.5744) | acc: 0.2422 (avg: 0.1947) |
Epoch 1 [300/383] | loss: 1.0661 (avg: 0.8577) | acc: 0.1484 (avg: 0.1975) |
Validation | acc (nat): 0.2510 | acc (rob): 0.1960 |
Epoch 2 [0/383] | loss: 1.0677 (avg: 0.0028) | acc: 0.2344 (avg: 0.2344) |
Epoch 2 [100/383] | loss: 1.0950 (avg: 0.2844) | acc: 0.2344 (avg: 0.2251) |
Epoch 2 [200/383] | loss: 1.0581 (avg: 0.5645) | acc: 