# Probabilistically compact loss with logits constraints

In [1]:
import os
import sys
sys.path.append('..')
import yaml
import shutil
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, random_split

from utils import *

## Parameter setting

In [2]:
gpu = '0,1,2,3'
dataset = 'cifar10'
model_type = 'wrn34-10'
checkpoint = './checkpoint/pc_with_logits_const_at/%s/%s' % (model_type, dataset)
num_classes = 10
lr = 0.01
batch_size = 256
total_epochs = 300
xi = 0.995
lam = 0.05
epsilon = 8/255
alpha = 2/255
num_repeats = 10
warm_up = total_epochs//2

## Inner maximization

In [3]:
def inner_max(model, xent, inputs, targets, epsilon, alpha, num_repeats):
    noise = torch.FloatTensor(inputs.shape).uniform_(-epsilon, epsilon).cuda()
    x = torch.clamp(inputs + noise, min=0, max=1)
    
    for _ in range(num_repeats):
        x.requires_grad_()
        logits = model(x)
        loss = xent(logits, targets)
        loss.backward()
        grads = x.grad.data
        x = x.detach() + alpha*torch.sign(grads).detach()
        x = torch.min(torch.max(x, inputs-epsilon), inputs+epsilon).clamp(min=0, max=1)
    return x

## Training (Outer minimization)

In [4]:
def training(epoch, model, dataloader, optimizer, num_classes, xi, lam, warm_up, 
             epsilon=8/255, alpha=2/255, num_repeats=10, use_at=False):
    model.train()
    total = 0
    total_loss = 0
    total_correct = 0
        
    hinge = nn.ReLU()
    xent = nn.CrossEntropyLoss()
    for idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        batch = inputs.size(0)
        
        if use_at:
            x = inner_max(model, xent, inputs, targets, epsilon, alpha, num_repeats)
        else:
            x = inputs.clone()
        logits = model(x)
        
        if warm_up < epoch:
            classes = torch.arange(num_classes)[None,:].repeat(batch,1).cuda()
            labels = targets[:,None].repeat(1,num_classes)
            mask = logits.softmax(dim=1).eq(labels)
            gt_probs = torch.sum(mask * logits.softmax(dim=1), dim=1, keepdim=True).repeat(1,num_classes)
            diff = false_probs + xi - gt_probs
            zeros = torch.zeros_like(gt_probs)
            pc_loss = (torch.sum(torch.max(zeros, diff))/batch - xi) / (num_classes-1)
            #false_probs = logits.softmax(dim=1)[classes!=targets[:,None]].view(batch, logits.size(1)-1)
            #gt_probs = logits.softmax(dim=1)[classes==targets[:,None]].unsqueeze(1).repeat(1, num_classes-1)
            # (torch.max(zeros, diff).sum() / batch_size - self.margin)  / (self.num_classes - 1)
            
            true_logits = mask * logits.softmax(dim=1)
            gt_mask = -1e3*mask
            false_logits = logits - true_logits + gt_mask
            
            false_logits_max = F.softmax(false_logits, dim=1)
            
            diff = false_logits_max * (gt_probs - logits)
            lc_loss = torch.max(zeros, diff).sum() / batch
            #gt_logits = logits[classes == targets[:,None]]
            #false_logits = logits[classes != targets[:,None]].view(batch,num_classes-1)
            #top2_logits = torch.topk(false_logits, k=2)[0][:,0]
            #const = torch.sum(hinge(gt_logits - top2_logits)) / batch
            
            loss = pc_loss + lam*lc_loss
        else:
            loss = xent(logits, targets)
                
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
            
        total += batch
        total_loss += loss.item()
        num_correct = torch.argmax(logits.data, dim=1).eq(targets.data).cpu().sum().item()
        total_correct += num_correct
        
        if idx % 100 == 0:
            print('Epoch %d [%d/%d] | loss: %.4f (avg: %.4f) | acc: %.4f (avg: %.4f) |'\
                  % (epoch, idx, len(dataloader), loss.item(), total_loss/len(dataloader),
                     num_correct/batch, total_correct/total))

In [5]:
def evaluation(epoch, model, dataloader, alpha, epsilon, num_repeats):
    model.eval()
    total_correct_nat = 0
    total_correct_adv = 0
    
    xent = nn.CrossEntropyLoss()
    for samples in dataloader:
        inputs, targets = samples[0].cuda(), samples[1].cuda()
        batch = inputs.size(0)
        with torch.enable_grad():
            x = inner_max(model, xent, inputs, targets, epsilon, alpha, num_repeats)
            
        with torch.no_grad():
            logits_nat = model(inputs)
            logits_adv = model(x)
        
        total_correct_nat += torch.argmax(logits_nat.data, dim=1).eq(targets.data).cpu().sum().item()
        total_correct_adv += torch.argmax(logits_adv.data, dim=1).eq(targets.data).cpu().sum().item()
        
    print('Validation | acc (nat): %.4f | acc (rob): %.4f |' % (total_correct_nat / len(dataloader.dataset),
                                                                total_correct_adv / len(dataloader.dataset)))
    return (total_correct_nat / len(dataloader.dataset)), (total_correct_adv / len(dataloader.dataset))

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
os.makedirs(checkpoint, exist_ok=True)

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()])
train_dataset, _ = get_dataloader(dataset, batch_size)
num_samples = len(train_dataset)
num_samples_for_train = int(num_samples * 0.98)
num_samples_for_valid = num_samples - num_samples_for_train
train_set, valid_set = random_split(train_dataset, [num_samples_for_train, num_samples_for_valid])
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False)
valid_dataloader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

model = nn.DataParallel(get_network(model_type, num_classes).cuda())
optimizer = optim.Adam(model.parameters(), lr=lr)

#adjust_learning_rate = lr_scheduler.MultiStepLR(optimizer, scheduler, gamma=0.1)
best_acc_nat, best_acc_rob = 0, 0

for epoch in range(total_epochs):
    training(epoch, model, train_dataloader, optimizer, num_classes, xi, lam, warm_up, epsilon, alpha, num_repeats, use_at=True)
    test_acc_nat, test_acc_rob = evaluation(epoch, model, valid_dataloader, alpha, epsilon, num_repeats)
        
    is_best = best_acc_nat < test_acc_nat and best_acc_rob < test_acc_rob
    best_acc_nat = max(best_acc_nat, test_acc_nat)
    best_acc_rob = max(best_acc_rob, test_acc_rob)
    save_checkpoint = {'state_dict': model.state_dict(),
                       'best_acc_nat': best_acc_nat,
                       'best_acc_rob': best_acc_rob,
                       'optimizer': optimizer.state_dict(),
                       'model_type': model_type,
                       'dataset': dataset}
    torch.save(save_checkpoint, os.path.join(checkpoint, 'model'))
    if is_best:
        torch.save(save_checkpoint, os.path.join(checkpoint, 'best_model'))
    #adjust_learning_rate.step()

Files already downloaded and verified
Files already downloaded and verified
Epoch 0 [0/192] | loss: 3.1064 (avg: 0.0162) | acc: 0.0000 (avg: 0.0000) |
Epoch 0 [100/192] | loss: 2.1563 (avg: 1.2056) | acc: 0.2422 (avg: 0.1538) |
Validation | acc (nat): 0.2460 | acc (rob): 0.2000 |
Epoch 1 [0/192] | loss: 2.1259 (avg: 0.0111) | acc: 0.2031 (avg: 0.2031) |
Epoch 1 [100/192] | loss: 2.1147 (avg: 1.1124) | acc: 0.2070 (avg: 0.2141) |
Validation | acc (nat): 0.3030 | acc (rob): 0.2240 |
Epoch 2 [0/192] | loss: 2.0152 (avg: 0.0105) | acc: 0.2383 (avg: 0.2383) |
Epoch 2 [100/192] | loss: 1.9911 (avg: 1.0537) | acc: 0.2461 (avg: 0.2451) |
Validation | acc (nat): 0.3730 | acc (rob): 0.2460 |
Epoch 3 [0/192] | loss: 1.9719 (avg: 0.0103) | acc: 0.2383 (avg: 0.2383) |
Epoch 3 [100/192] | loss: 1.8483 (avg: 1.0183) | acc: 0.3242 (avg: 0.2747) |
Validation | acc (nat): 0.4000 | acc (rob): 0.2780 |
Epoch 4 [0/192] | loss: 1.9239 (avg: 0.0100) | acc: 0.2539 (avg: 0.2539) |
Epoch 4 [100/192] | loss: 1.8