# Geometry-Awere Instance-Reweighted Adversarial Training (GAIRAT)

In [1]:
import os
import sys
sys.path.append('..')
import yaml
import shutil
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, random_split

from utils import *

## Parameter setting

In [2]:
gpu = '4'
dataset = 'cifar10'
model_type = 'wrn34-10'
checkpoint = './checkpoint/%s/%s' % (model_type, dataset)
num_classes = 10
lr = 0.01
momentum = 0.9
weight_decay = 0.0035
batch_size = 128
total_epochs = 100
lam = -1
epsilon = 8/255
alpha = 2/255
num_repeats = 10
warm_up = 60

## Inner maximization

In [3]:
def inner_max(model, xent, inputs, targets, epsilon, alpha, num_repeats):
    noise = torch.FloatTensor(inputs.shape).uniform_(-epsilon, epsilon).cuda()
    x = torch.clamp(inputs + noise, min=0, max=1)
    kappa = torch.zeros(inputs.size(0)).cuda()
    
    for _ in range(num_repeats):
        x.requires_grad_()
        logits = model(x)
        kappa += logits.softmax(dim=1).argmax(dim=1).eq(targets)
        loss = xent(logits, targets)
        loss.backward()
        grads = x.grad.data
        x = x.detach() + alpha*torch.sign(grads).detach()
        x = torch.min(torch.max(x, inputs-epsilon), inputs+epsilon).clamp(min=0, max=1)
    return x, kappa

## Training (Outer minimization)

In [4]:
def training(epoch, model, dataloader, optimizer, num_classes, warm_up, 
             lam=-1, epsilon=8/255, alpha=2/255, num_repeats=10):
    model.train()
    total = 0
    total_loss = 0
    total_correct = 0
        
    tanh = nn.Tanh()
    xent = nn.CrossEntropyLoss()
    for idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        batch = inputs.size(0)
        
        x, kappa = pgd(model, xent, inputs, targets, epsilon, alpha, num_repeats)
        logits = model(x)
        s = (1 + tanh(lam + 5*(1 - 2*kappa/num_repeats)))/2
        s = s/torch.sum(weights)
        
        if warm_up < epoch:
            class_index = torch.arange(logits.size(1))[None,:].repeat(batch,1).cuda()
            loss = -torch.sum(s * torch.log_softmax(logits, dim=1)[class_index==targets[:,None]])/batch
        else:
            loss = xent(logits, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
            
        total += batch
        total_loss += loss.item()
        num_correct = torch.argmax(logits.data, dim=1).eq(targets.data).cpu().sum().item()
        total_correct += num_correct
        
        if idx % 100 == 0:
            print('Epoch %d [%d/%d] | loss: %.4f (avg: %.4f) | acc: %.4f (avg: %.4f) |'\
                  % (epoch, idx, len(dataloader), loss.item(), total_loss/len(dataloader),
                     num_correct/batch, total_correct/total))

In [5]:
def evaluation(epoch, model, dataloader, alpha, epsilon, num_repeats):
    model.eval()
    total_correct_nat = 0
    total_correct_adv = 0
    
    xent = nn.CrossEntropyLoss()
    for samples in dataloader:
        inputs, targets = samples[0].cuda(), samples[1].cuda()
        batch = inputs.size(0)
        with torch.enable_grad():
            x = pgd(model, xent, inputs, targets, epsilon, alpha, num_repeats)
            
        with torch.no_grad():
            logits_nat = model(inputs)
            logits_adv = model(x)
        
        total_correct_nat += torch.argmax(logits_nat.data, dim=1).eq(targets.data).cpu().sum().item()
        total_correct_adv += torch.argmax(logits_adv.data, dim=1).eq(targets.data).cpu().sum().item()
        
    print('Validation | acc (nat): %.4f | acc (rob): %.4f |' % (total_correct_nat / len(dataloader.dataset),
                                                                total_correct_adv / len(dataloader.dataset)))
    return (total_correct_nat / len(dataloader.dataset)), (total_correct_adv / len(dataloader.dataset))

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
os.makedirs(checkpoint, exist_ok=True)

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()])
train_dataset, _ = get_dataloader(dataset, batch_size)
num_samples = len(train_dataset)
num_samples_for_train = int(num_samples * 0.98)
num_samples_for_valid = num_samples - num_samples_for_train
train_set, valid_set = random_split(train_dataset, [num_samples_for_train, num_samples_for_valid])
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False)
valid_dataloader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

model = nn.DataParallel(get_network(model_type, num_classes).cuda())
optimizer = optim.SGD(model.parameters(),lr=lr, momentum=momentum, weight_decay=weight_decay)

scheduler = [int(total_epochs*0.5), int(total_epochs*0.75)]
adjust_learning_rate = lr_scheduler.MultiStepLR(optimizer, scheduler, gamma=0.1)
best_acc_nat, best_acc_rob = 0, 0

for epoch in range(total_epochs):
    training(epoch, model, train_dataloader, optimizer, num_classes, warm_up, lam, epsilon, alpha, num_repeats)
    test_acc_nat, test_acc_rob = evaluation(epoch, model, valid_dataloader)
        
    is_best = best_acc_nat < test_acc_nat and best_acc_rob < test_acc_rob
    best_acc_nat = max(best_acc_nat, test_acc_nat)
    best_acc_rob = max(best_acc_rob, test_acc_rob)
    save_checkpoint = {'state_dict': model.state_dict(),
                       'best_acc': best_acc,
                       'test_acc': test_acc,
                       'optimizer': optimizer.state_dict(),
                       'model_type': model_type,
                       'dataset': dataset}
    torch.save(save_checkpoint, os.path.join(checkpoint, 'model'))
    if is_best:
        torch.save(save_checkpoint, os.path.join(checkpoint, 'best_model'))
    adjust_learning_rate.step()

Files already downloaded and verified
Files already downloaded and verified


TypeError: softmax() received an invalid combination of arguments - got (dim=int, ), but expected one of:
 * (Tensor input, name dim, *, torch.dtype dtype)
 * (Tensor input, int dim, torch.dtype dtype)
