# CNN with Radial Basis Function (RBF) kernel

In [4]:
import os
import sys
sys.path.append('..')
import yaml
import shutil
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, random_split

from utils import *

## Parameter setting

In [5]:
gpu = '4'
dataset = 'mnist'
model_type = 'rbf-cnn'
checkpoint = './checkpoint/rbf_cnn/%s/%s' % (dataset, model_type)
num_classes = 10
lr = 0.001
batch_size = 32
total_epochs = 12
epsilon = 80/255
alpha = 20/255
num_repeats = 10

## RBF kernel module

In [6]:
class RBF(nn.Module):
    def __init__(self, num_features, betas=2.0, use_gpu=True):
        super(RBF, self).__init__()
        if use_gpu:
            self.betas = nn.Parameter(torch.randn(num_features).cuda())
            self.center = nn.Parameter(torch.randn(num_features, num_features).cuda())
            self.A = nn.Parameter(betas*torch.eye(num_features, num_features).cuda())
        
            self.weight = nn.Parameter(torch.randn(num_features, num_features).cuda())
            self.bias = nn.Parameter(torch.randn(num_features).cuda())
        else:
            self.betas = nn.Parameter(torch.randn(num_features))
            self.center = nn.Parameter(torch.randn(num_features, num_features))
            self.A = nn.Parameter(betas*torch.eye(num_features, num_features))
        
            self.weight = nn.Parameter(torch.randn(num_features, num_features))
            self.bias = nn.Parameter(torch.randn(num_features))

        ## Parameter initialization
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)
        nn.init.constant_(self.betas, val=2.0)
        nn.init.uniform_(self.center, a=0.0, b=1.0)
        
    def forward(self, x):
        expanded_center = self.center[None,None,:,:]
        A = self.A + sys.float_info.epsilon
        psi = A.t() * A
        s = x[:,:,:,None].repeat(1,1,1,x.size(2)) - expanded_center
        dist = torch.sqrt(torch.sum(torch.tensordot(s, psi, dims=1) * s, dim=-1))
        mahalanobis = torch.exp(-self.betas * dist)
        
        out = torch.tensordot(mahalanobis, self.weight, dims=1)
        g = out + self.bias
        return g

## Main network architecture

In [7]:
class Kernel_trick(nn.Module):
    def __init__(self, num_classes):
        super(Kernel_trick, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=8, stride=2, padding=3)
        size = 14*14
        self.rbf1 = RBF(size)
        
        self.conv2 = nn.Conv2d(16*2, 32, kernel_size=6, stride=2)
        size = 5*5
        self.rbf2 = RBF(size)
        
        self.conv3 = nn.Conv2d(32*2, 32, kernel_size=5, stride=1)
        size = 1*1
        self.rbf3 = RBF(size)
        
        self.classifier = nn.Linear(32*2, num_classes)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        b = x.size(0)
        h = self.relu(self.conv1(x))
        g = self.rbf1(h.flatten(start_dim=2)).view(b,-1,h.size(2),h.size(3))
        
        h_cat = torch.cat((h, g), dim=1)
        h = self.relu(self.conv2(h_cat))
        g = self.rbf2(h.flatten(start_dim=2)).view(b,-1,h.size(2),h.size(3))
        
        h_cat = torch.cat((h, g), dim=1)
        h = self.relu(self.conv3(h_cat))
        g = self.rbf3(h.flatten(start_dim=2)).view(b,-1,h.size(2),h.size(3))
        
        h_cat = torch.cat((h, g), dim=1)
        out = self.classifier(h_cat.flatten(start_dim=1))
        return out

## Training phase

In [8]:
def training(epoch, model, dataloader, optimizer, num_classes):
    model.train()
    total = 0
    total_loss = 0
    total_correct = 0
        
    xent = nn.CrossEntropyLoss()
    for idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        batch = inputs.size(0)
        logits = model(inputs)
        loss = xent(logits, targets)
                
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
            
        total += batch
        total_loss += loss.item()
        num_correct = torch.argmax(logits.data, dim=1).eq(targets.data).cpu().sum().item()
        total_correct += num_correct
        
        if idx % 100 == 0:
            print('Epoch %d [%d/%d] | loss: %.4f (avg: %.4f) | acc: %.4f (avg: %.4f) |'\
                  % (epoch, idx, len(dataloader), loss.item(), total_loss/len(dataloader),
                     num_correct/batch, total_correct/total))

In [9]:
def evaluation(epoch, model, dataloader, alpha, epsilon, num_repeats):
    model.eval()
    total_correct_nat = 0
    total_correct_adv = 0
    
    xent = nn.CrossEntropyLoss()
    for samples in dataloader:
        inputs, targets = samples[0].cuda(), samples[1].cuda()
        batch = inputs.size(0)
        with torch.enable_grad():
            noise = torch.cuda.FloatTensor(inputs.shape).uniform_(-epsilon, epsilon)
            x = torch.clamp(inputs+noise, min=0, max=1)
            
            for _ in range(num_repeats):
                x.requires_grad_()
                logits = model(x)
                loss = xent(logits, targets)
                loss.backward()
                grads = x.grad.data
                x = x.detach() + alpha*torch.sign(grads).detach()
                x = torch.min(torch.max(x - epsilon), x + epsilon).clamp(min=0, max=1)
            
        with torch.no_grad():
            logits_nat = model(inputs)
            logits_adv = model(x)
        
        total_correct_nat += torch.argmax(logits_nat.data, dim=1).eq(targets.data).cpu().sum().item()
        total_correct_adv += torch.argmax(logits_adv.data, dim=1).eq(targets.data).cpu().sum().item()
        
    print('Validation | acc (nat): %.4f | acc (rob): %.4f |' % (total_correct_nat / len(dataloader.dataset),
                                                                total_correct_adv / len(dataloader.dataset)))
    return (total_correct_nat / len(dataloader.dataset)), (total_correct_adv / len(dataloader.dataset))

In [10]:
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
os.makedirs(checkpoint, exist_ok=True)

train_dataset, _ = get_dataloader(dataset, batch_size, image_size=28)
num_samples = len(train_dataset)
num_samples_for_train = int(num_samples * 0.98)
num_samples_for_valid = num_samples - num_samples_for_train
train_set, valid_set = random_split(train_dataset, [num_samples_for_train, num_samples_for_valid])
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False)
valid_dataloader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

model = nn.DataParallel(Kernel_trick(num_classes).cuda())
optimizer = optim.Adam(model.parameters(), lr=lr)

best_acc_nat, best_acc_rob = 0, 0

for epoch in range(total_epochs):
    training(epoch, model, train_dataloader, optimizer, num_classes)
    test_acc_nat, test_acc_rob = evaluation(epoch, model, valid_dataloader, alpha, epsilon, num_repeats)
        
    is_best = best_acc_nat < test_acc_nat and best_acc_rob < test_acc_rob
    best_acc_nat = max(best_acc_nat, test_acc_nat)
    best_acc_rob = max(best_acc_rob, test_acc_rob)
    save_checkpoint = {'state_dict': model.state_dict(),
                       'best_acc_nat': best_acc_nat,
                       'best_acc_rob': best_acc_rob,
                       'optimizer': optimizer.state_dict(),
                       'model_type': model_type,
                       'dataset': dataset}
    torch.save(save_checkpoint, os.path.join(checkpoint, 'model'))
    if is_best:
        torch.save(save_checkpoint, os.path.join(checkpoint, 'best_model'))

Epoch 0 [0/1838] | loss: 2.3024 (avg: 0.0013) | acc: 0.0938 (avg: 0.0938) |
Epoch 0 [100/1838] | loss: 1.8402 (avg: 0.1156) | acc: 0.3125 (avg: 0.2203) |
Epoch 0 [200/1838] | loss: 1.0452 (avg: 0.1891) | acc: 0.5938 (avg: 0.3730) |
Epoch 0 [300/1838] | loss: 0.4394 (avg: 0.2338) | acc: 0.8750 (avg: 0.4949) |
Epoch 0 [400/1838] | loss: 0.3966 (avg: 0.2665) | acc: 0.9375 (avg: 0.5742) |
Epoch 0 [500/1838] | loss: 0.4528 (avg: 0.2949) | acc: 0.8438 (avg: 0.6266) |
Epoch 0 [600/1838] | loss: 0.4231 (avg: 0.3181) | acc: 0.8750 (avg: 0.6680) |
Epoch 0 [700/1838] | loss: 0.2976 (avg: 0.3379) | acc: 0.9375 (avg: 0.6995) |
Epoch 0 [800/1838] | loss: 0.4086 (avg: 0.3577) | acc: 0.8438 (avg: 0.7232) |
Epoch 0 [900/1838] | loss: 0.2564 (avg: 0.3750) | acc: 0.9375 (avg: 0.7429) |
Epoch 0 [1000/1838] | loss: 0.3460 (avg: 0.3917) | acc: 0.8750 (avg: 0.7596) |
Epoch 0 [1100/1838] | loss: 0.9061 (avg: 0.4063) | acc: 0.8438 (avg: 0.7738) |
Epoch 0 [1200/1838] | loss: 0.1152 (avg: 0.4201) | acc: 0.9688 (