## Step 1: Dataset and DataLoader

In [1]:
# !pip install tensorboardX

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import matplotlib.pyplot as plt

from torchvision import datasets, transforms
# from tensorboardX import SummaryWriter

use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 64

np.random.seed(42)
torch.manual_seed(42)


## Dataloaders
train_dataset = datasets.MNIST('mnist_data/', train=True, download=False, transform=transforms.Compose(
    [transforms.ToTensor()]
))
test_dataset = datasets.MNIST('mnist_data/', train=False, download=False, transform=transforms.Compose(
    [transforms.ToTensor()]
))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Step 2: IBP-Modified Fully Connected Network

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 50)
        self.fc2 = nn.Linear(50,50)
        self.fc3 = nn.Linear(50,50)
        self.output = nn.Linear(50, 10)
    
    def _prop_affine(self, l_x, u_x, W, b):
        sum = (u_x + l_x) / 2
        dif = (u_x - l_x) / 2
        
        # print(sum.shape, W.T.shape, b.unsqueeze(0).shape)
        pos = torch.mm(sum, W.T) + b.unsqueeze(0)
        # print(pos.shape, '---------------------')
        neg = torch.mm(dif, torch.abs(W).T)
        
        l_x = pos - neg
        u_x = pos + neg
        return l_x, u_x
    
    def _prop_relu(self, l_x, u_x):
        l_x = torch.max(l_x, torch.zeros_like(l_x))
        u_x = torch.max(u_x, torch.zeros_like(u_x))
        return l_x, u_x

    def forward(self, x, l_x, u_x):
        x = x.view((-1, 28*28))
        l_x, u_x = l_x.view((-1, 28*28)), u_x.view((-1, 28*28))
        
        l_x, u_x = self._prop_affine(l_x, u_x, self.fc1.weight, self.fc1.bias)
        # print('layer1----------------\n', l_x, u_x)
        l_x, u_x = self._prop_relu(l_x, u_x)
        l_x, u_x = self._prop_affine(l_x, u_x, self.fc2.weight, self.fc2.bias)
        # print('layer2----------------\n', l_x, u_x)
        l_x, u_x = self._prop_relu(l_x, u_x)
        l_x, u_x = self._prop_affine(l_x, u_x, self.fc3.weight, self.fc3.bias)
        # print('layer3----------------\n', l_x, u_x)
        l_x, u_x = self._prop_relu(l_x, u_x)
        l_y, u_y = self._prop_affine(l_x, u_x, self.output.weight, self.output.bias)
        
        x1 = F.relu(self.fc1(x))
        x2 = F.relu(self.fc2(x1))
        x3 = F.relu(self.fc3(x2))
        y = self.output(x3)
        
        return y, l_y, u_y

model = Net()

## Step 3: Loss Function and Optimizer

In [3]:
lr = 0.001

def criterion(x, l_x, u_x, y, kappa):
    cel = nn.CrossEntropyLoss()
    l_fit = cel(x, y)
    
    z = u_x.clone()
    z[:, y] = l_x[:, y]
    l_spec = cel(z, y)
    return kappa*l_fit + (1-kappa)*l_spec

optimizer = optim.SGD(model.parameters(), lr=lr)

## Step 4: Training and Testing

In [4]:
import time

num_epochs = 35
kappa = 0.5
e_train = 0.1
warmup = 15
max_e_epoch = 23

model = model.to(device)

def test(model, eps):
    model.eval()
    num_correct = 0
    num_total = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            yhat, _, _ = model(x, x-eps, x+eps)
            _, yhat_label = torch.max(yhat, 1)
            
            num_total += y.shape[0]
            num_correct += (y == yhat_label).sum().item()
    print(f"Accuracy: {num_correct / num_total * 100}%")

start_time = time.time()
for epoch in range(num_epochs):
    model.train()
    sum_loss = 0
    
    now_e_epoch = (
        0 if epoch <= warmup
        else epoch - warmup if warmup < epoch <= max_e_epoch
        else max_e_epoch - warmup
    )
    gradual_epochs = max_e_epoch - warmup
    eps = e_train * (now_e_epoch / gradual_epochs)
    kap = 1 * (1 - now_e_epoch / gradual_epochs) + kappa * (now_e_epoch / gradual_epochs)
    print(eps, kap)
    
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        
        optimizer.zero_grad()
        
        yhat, l_yhat, u_yhat = model(x, x-eps, x+eps)
        loss = criterion(yhat, l_yhat, u_yhat, y, kap)
        
        loss.backward()
        optimizer.step()
        
        sum_loss += loss.item()
    print(f"Epoch {epoch}, loss {sum_loss}")
    test(model, eps)
torch.save(model, "model.pth")
end_time = time.time()

print(end_time - start_time)

0.0 1.0
Epoch 0, loss 2160.806093454361
Accuracy: 9.69%
0.0 1.0
Epoch 1, loss 2155.5478847026825
Accuracy: 12.690000000000001%
0.0 1.0
Epoch 2, loss 2149.757318019867
Accuracy: 22.39%
0.0 1.0
Epoch 3, loss 2142.8633856773376
Accuracy: 22.91%
0.0 1.0
Epoch 4, loss 2133.7242019176483
Accuracy: 25.89%
0.0 1.0
Epoch 5, loss 2120.673803329468
Accuracy: 31.81%
0.0 1.0
Epoch 6, loss 2100.911470413208
Accuracy: 38.26%
0.0 1.0
Epoch 7, loss 2069.4437205791473
Accuracy: 42.35%
0.0 1.0
Epoch 8, loss 2017.1880428791046
Accuracy: 43.11%
0.0 1.0
Epoch 9, loss 1931.023352265358
Accuracy: 48.08%
0.0 1.0
Epoch 10, loss 1798.6601880788803
Accuracy: 57.31%
0.0 1.0
Epoch 11, loss 1616.3839721679688
Accuracy: 61.69%
0.0 1.0
Epoch 12, loss 1390.3584356307983
Accuracy: 65.29%
0.0 1.0
Epoch 13, loss 1164.7271184921265
Accuracy: 70.37%
0.0 1.0
Epoch 14, loss 992.906395316124
Accuracy: 73.4%
0.0 1.0
Epoch 15, loss 872.9108471870422
Accuracy: 75.39%
0.0125 0.9375
Epoch 16, loss 825.9373168945312
Accuracy: 76.05%

In [8]:
model = torch.load('model.pth')
model.to('cpu')
W1 = model.fc1.weight.detach().numpy()
b1 = model.fc1.bias.detach().numpy()

W2 = model.fc2.weight.detach().numpy()
b2 = model.fc2.bias.detach().numpy()

W3 = model.fc3.weight.detach().numpy()
b3 = model.fc3.bias.detach().numpy()

Wout = model.output.weight.detach().numpy()
bout = model.output.bias.detach().numpy()

def get_interval(index:int, l_inf:float):
    image, label = test_dataset[index]
    image = image.view(-1, 28*28)
    l_image = image - l_inf
    u_image = image + l_inf
    return l_image.numpy()[0], u_image.numpy()[0], label

def single_analysis(w:list, b:float, lx:list, ux:list):
    n = len(w) # dim. of x
    ly, uy = 0, 0
    for i in range(n):
        ux_bar = (lx[i] if w[i]<0 else ux[i])
        lx_bar = (lx[i] if w[i]>=0 else ux[i])
        ly += lx_bar * w[i] + b
        uy += ux_bar * w[i] + b
    return ly, uy

def multiple_analysis(w:list, b:list, lx:list, ux:list):
    ly, uy = list(), list()
    n = len(w) # dim. of y
    for i in range(n):
        lyi, uyi = single_analysis(w[i], b[i], lx, ux)
        ly.append(lyi)
        uy.append(uyi)
    return ly, uy
        
def interval_analysis(index:int, l_inf:float, silent=True):
    l0, u0, label = get_interval(index, l_inf)
    if not silent: display_together(l0, u0, 20)
    
    l1, u1 = multiple_analysis(W1.tolist(), b1.tolist(), l0, u0)
    if not silent: display_together(l1, u1, 20)
    
    l2, u2 = multiple_analysis(W2.tolist(), b2.tolist(), l1, u1)
    if not silent: display_together(l2, u2, 20)
    
    l3, u3 = multiple_analysis(W3.tolist(), b3.tolist(), l2, u2)
    if not silent: display_together(l3, u3, 20)
    
    lout, uout = multiple_analysis(Wout.tolist(), bout.tolist(), l3, u3)
    if not silent: display_together(lout, uout, 20, indent=True)
    
    return lout, uout, label

def display_together(l, u, max_index=None, indent=False):
    # print(len(l), len(u))
    if max_index is None:
        max_index = len(l)
    max_index = min(max_index, len(l))
    res = [(l[i], u[i]) for i in range(max_index)]
    if not indent:
        print(res)
    else:
        print('===================================')
        for r in res:
            print(r)
        print('===================================')
    
def single_verification(lout, uout, label):
    min_label = lout[label]
    is_always_greater = 1
    for i in range(len(lout)):
        if i != label:
            if uout[i] > min_label:
                is_always_greater = 0
    return is_always_greater

def multiple_verification(begin_i, end_i, l_inf):
    sum, total = 0, end_i - begin_i
    for i in range(begin_i, end_i):
        lout, uout, label = interval_analysis(i, l_inf)
        is_verified = single_verification(lout, uout, label)
        # print(f"index: {i}, label: {label}, is_verified: {is_verified}")
        sum += is_verified
    acc = sum / total
    print(f"l_inf: {l_inf}, acc: {acc * 100}%")
    return acc

  model = torch.load('model.pth')


In [9]:
import multiprocessing as mp

params = [(0, len(test_dataset), l_inf) for l_inf in np.linspace(0.01, 0.1, 10)]
with mp.Pool(processes=10) as pool:
    results = pool.starmap(multiple_verification, params)

In [10]:
results

[0.1028, 0.1028, 0.1028, 0.1028, 0.1028, 0.1028, 0.1028, 0.1028, 0.0, 0.0]

In [11]:
class PGD_optimizer():
    # ref: https://www.geeksforgeeks.org/custom-optimizers-in-pytorch/
    def __init__(self, params, original_params, epsilon, lr=1e-2):
        self.params = list(params)
        self.original_params = original_params
        self.epsilon = epsilon
        self.lr = lr
                
    def step(self):
        with torch.no_grad():
            for param in self.params:
                if param.grad is not None:
                    if torch.max(torch.abs(param.grad.data)) == 0:
                        print("warning: all gradients are 0")
                    param.data = param.data - self.lr * param.grad.data
                else:
                    print("warning: param.grad is None")
                param.data = torch.clamp(
                    param.data, 
                    min = (self.original_params.data - self.epsilon), 
                    max = (self.original_params.data + self.epsilon)
                )
    
    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()

In [25]:
from torch.utils.data import Subset

num_samples = 10000
mini_dataset = Subset(test_dataset, range(num_samples))
mini_loader = torch.utils.data.DataLoader(mini_dataset, batch_size=num_samples, shuffle=False)
images_labels = next(iter(mini_loader))

target_label = torch.LongTensor([1]*num_samples).to(device)
num_epochs = 500

model.eval()
model.to(device)

def pgd_attack(epsilon):
    print(f'epsilon: {epsilon}')
    
    images, labels = images_labels
    images = images.clone().detach().to(device)
    labels = labels.to(device)
    images_original = images.clone().detach()
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(num_epochs):
        optimizer = PGD_optimizer([images], images_original, epsilon)
        images.requires_grad = True
        optimizer.zero_grad()
        yhat, _, _ = model(images, images, images)
        loss = criterion(yhat, target_label)
        # print(epoch, loss.item())
        loss.backward()
        # print(images.grad.data)
        optimizer.step()
        images = images.detach().clone()
    
    assert torch.max(torch.abs(images_original - images)) != 0
    assert torch.max(torch.abs(images_original - images)) <= 1.1 * epsilon
    
    yhat = torch.max(model(images, images, images)[0], 1)[1]
    robust_acc = torch.sum(yhat==labels) / labels.shape[0]
    print(yhat.cpu().detach().numpy().tolist())
    print(labels.cpu().detach().numpy().tolist())
    print(robust_acc)
    return images

a = pgd_attack(1)

epsilon: 1
[9, 6, 1, 0, 6, 1, 9, 6, 6, 9, 0, 7, 6, 0, 1, 9, 6, 9, 7, 6, 9, 6, 6, 6, 6, 0, 7, 6, 0, 1, 1, 1, 7, 0, 9, 6, 7, 1, 7, 1, 1, 9, 7, 1, 7, 6, 1, 6, 6, 6, 6, 2, 6, 9, 6, 6, 6, 1, 6, 9, 7, 6, 6, 7, 2, 6, 6, 6, 6, 0, 9, 0, 0, 9, 1, 1, 1, 9, 1, 9, 6, 6, 6, 7, 6, 6, 9, 6, 6, 1, 9, 6, 2, 9, 1, 6, 1, 1, 0, 6, 6, 0, 9, 6, 6, 6, 0, 1, 6, 6, 7, 1, 1, 6, 1, 6, 7, 6, 7, 6, 6, 6, 9, 6, 9, 6, 0, 9, 9, 7, 6, 6, 6, 6, 9, 1, 0, 1, 6, 6, 6, 9, 9, 1, 7, 1, 7, 6, 0, 1, 6, 2, 2, 9, 1, 6, 6, 0, 7, 6, 7, 0, 9, 6, 6, 7, 6, 6, 1, 6, 6, 9, 9, 9, 6, 1, 1, 6, 1, 9, 1, 6, 9, 0, 6, 6, 6, 9, 0, 1, 1, 1, 0, 6, 0, 7, 1, 6, 6, 6, 6, 0, 1, 1, 1, 1, 2, 7, 6, 6, 6, 9, 6, 1, 6, 0, 7, 6, 9, 9, 7, 6, 6, 9, 1, 9, 9, 6, 1, 1, 6, 1, 6, 9, 9, 6, 6, 6, 6, 1, 9, 6, 6, 7, 1, 6, 0, 6, 6, 1, 6, 1, 6, 2, 9, 9, 1, 1, 6, 0, 2, 1, 9, 9, 9, 1, 6, 1, 6, 0, 1, 0, 1, 7, 6, 6, 1, 2, 6, 1, 1, 6, 1, 9, 6, 6, 6, 6, 1, 7, 6, 7, 6, 6, 0, 6, 0, 0, 6, 2, 1, 6, 1, 6, 9, 0, 1, 9, 7, 9, 2, 0, 0, 9, 1, 6, 6, 1, 6, 9, 1, 1, 6, 1, 6, 7, 6, 0, 1, 1