<a href="https://colab.research.google.com/github/lucasl02/Adversarial-Training/blob/main/IBP_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install bound-propagation


# Setting up Model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import matplotlib.pyplot as plt
from tqdm import tqdm

from torchvision import datasets, transforms
from bound_propagation import BoundModelFactory, HyperRectangle
from tqdm import trange

# from tensorboardX import SummaryWriter

use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 64

np.random.seed(42)
torch.manual_seed(42)


## Dataloaders
train_dataset = datasets.MNIST('mnist_data/', train=True, download=True, transform=transforms.Compose(
    [transforms.ToTensor()]
))
test_dataset = datasets.MNIST('mnist_data/', train=False, download=True, transform=transforms.Compose(
    [transforms.ToTensor()]
))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


class FashionMNISTNetwork(nn.Sequential):
    def __init__(self, *args):
        if args:
            # To support __get_index__ of nn.Sequential when slice indexing
            super().__init__(*args)
        else:
            img_size = 28 * 28
            classes = 10

            super().__init__(
                # nn.Flatten(),
                nn.Linear(img_size, 50),
                nn.ReLU(),
                nn.Linear(50, 50),
                nn.ReLU(),
                nn.Linear(50, classes)
            )
model = FashionMNISTNetwork().to(device)
model.train()
factory = BoundModelFactory()
model = factory.build(model)

#Normal Training

Training the Model without IBP

In [3]:
def train(net, num_epochs):
    print('')
    print('[TRAINING]')

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    for epoch in trange(num_epochs):

        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            images, labels = data
            # flatten images to expected dimensions
            images = images.view(-1, 28*28)
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.3f}')

def test(net):
    print('')
    print('[TEST]')
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data in test_loader:
            images, labels = data
            images = images.view(-1, 28*28)
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(f'Accuracy on images: {100 * correct / total}')




# Executing Normal Training

The clean accuracy of the model is 93.5

In [4]:
train(model, 10)


[TRAINING]


 10%|█         | 1/10 [00:07<01:06,  7.42s/it]

Epoch 1/10, Loss: 1.754


 20%|██        | 2/10 [00:14<00:58,  7.37s/it]

Epoch 2/10, Loss: 0.592


 30%|███       | 3/10 [00:21<00:51,  7.31s/it]

Epoch 3/10, Loss: 0.412


 40%|████      | 4/10 [00:29<00:43,  7.29s/it]

Epoch 4/10, Loss: 0.354


 50%|█████     | 5/10 [00:36<00:36,  7.27s/it]

Epoch 5/10, Loss: 0.322


 60%|██████    | 6/10 [00:43<00:28,  7.24s/it]

Epoch 6/10, Loss: 0.301


 70%|███████   | 7/10 [00:50<00:21,  7.23s/it]

Epoch 7/10, Loss: 0.284


 80%|████████  | 8/10 [00:58<00:14,  7.25s/it]

Epoch 8/10, Loss: 0.269


 90%|█████████ | 9/10 [01:05<00:07,  7.23s/it]

Epoch 9/10, Loss: 0.254


100%|██████████| 10/10 [01:12<00:00,  7.26s/it]

Epoch 10/10, Loss: 0.241





In [5]:
test(model)


[TEST]
Accuracy on images: 93.5


# IBP Training

In [6]:
def construct_transform():
    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Lambda(torch.flatten)
    ])

    # Identity transform - because cross entropy loss supports class indexing
    target_transform = transforms.Compose([])

    return transform, target_transform

def adversarial_logit(y_hat, y):
    batch_size = y.size(0)
    classes = torch.arange(10, device=y.device).unsqueeze(0).expand(batch_size, -1)
    mask = (classes == y.unsqueeze(-1)).to(dtype=y_hat.lower.dtype)

    # Take upper bound for logit of all but the correct class where you take the lower bound
    adversarial_logit = (1 - mask) * y_hat.upper + mask * y_hat.lower

    return adversarial_logit


def train_ibp(net):
    print('')
    print('[TRAINING]')

    transform, target_transform = construct_transform()

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

    k = 1.0
    running_eps = 0.0
    for epoch in trange(20):
        running_loss = 0.0
        running_cross_entropy = 0.0
        # e that starts from 0.0 and gradually increases to e_train = 0.1

        for i, (X, y) in enumerate(train_loader):
            X = X.view(-1, 28*28)
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad(set_to_none=True)

            y_hat = net(X)

            cross_entropy = criterion(y_hat, y)

            # z_k(e_train)
            bounds = net.ibp(HyperRectangle.from_eps(X, running_eps))
            # z^k(e_train)
            logit = adversarial_logit(bounds, y)

            loss = k * cross_entropy + (1 - k) * criterion(logit, y)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            running_cross_entropy += cross_entropy.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print(f'[{epoch + 1}, {i + 1:3d}] loss: {running_loss / 100:.3f}, cross entropy: {running_cross_entropy / 100:.3f}, k = {k}, eps = {running_eps}')
                running_loss = 0.0
                running_cross_entropy = 0.0

        k = max(k - 0.1, 0.5)
        # increase eps gradually to 0.1
        running_eps = min(running_eps + (0.1/20), 0.1)

@torch.no_grad()
def test_ibp(net):
    print('')
    print('[TEST]')

    transform, target_transform = construct_transform()

    correct = 0
    for i, (X, y) in enumerate(test_loader):
        X = X.view(-1, 28*28)
        X, y = X.to(device), y.to(device)

        y_hat = net(X)

        predicted = torch.argmax(y_hat, 1)
        correct += (predicted == y).sum().item()

    print(f'Accuracy: {correct / len(test_dataset):.3f}')


# Executing IBP Training

In [None]:
    net = FashionMNISTNetwork().to(device)

    factory = BoundModelFactory()
    net = factory.build(net)

    train_ibp(net)

In [8]:
test_ibp(net)


[TEST]
Accuracy: 0.935


# PGD Robustness Testing

In [9]:
def pgd_linf_untargeted(model, x, labels, k, eps, eps_step):
    model.eval()
    ce_loss = torch.nn.CrossEntropyLoss()
    adv_x = x.clone().detach()
    adv_x.requires_grad_(True)
    for _ in range(k):
        adv_x.requires_grad_(True)
        model.zero_grad()
        output = model(adv_x)
        # TODO: Calculate the loss
        loss = ce_loss(output, labels)
        loss.backward()
        # TODO: compute the adv_x
        grad = adv_x.grad
        adv_x = adv_x  + (eps_step * torch.sign(grad))
        # find delta, clamp with eps
        #linf so can check values individually
        # change has to be eps bubble at max: delta = [-eps, eps]
        delta = torch.clamp(adv_x - x, min = -eps, max= eps)
        # calmp to image domain: adv_x = [0,1]
        adv_x = torch.clamp(x + delta, min = 0.0, max = 1.0).detach().requires_grad_()

    return adv_x

def test_model_on_single_attack(model, attack='pgd_linf', eps=0.1, k = 10):
    model.eval()
    tot_test, tot_acc = 0.0, 0.0
    for batch_idx, (x_batch, y_batch) in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
        x_batch = x_batch.view(-1, 28*28)
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        #print('On batch: ', batch_idx)
        if attack == 'pgd_linf':
            # TODO: get x_adv untargeted pgd linf with eps, and eps_step=eps/4
            adv_x = pgd_linf_untargeted(model,x_batch,y_batch,k,eps = eps,eps_step= eps/4 )
            #
            with torch.no_grad():
                logits = model(adv_x)
                predicts = logits.argmax(dim =1 )
        else:
            pass

        # get the testing accuracy and update tot_test and tot_acc
        '''num got right = sum(predicts = y_batch)'''
        tot_acc += (predicts == y_batch).sum().item()
        '''num total'''
        tot_test += y_batch.size(0)
        #print('Accuracy So Far: %.5lf' % (tot_acc/tot_test), f'on {attack} attack with eps = {eps}')

    print('Robust accuracy %.5lf' % (tot_acc/tot_test), f'on {attack} attack with eps = {eps}')

## Normal Model

In [10]:
test_model_on_single_attack(model=model, attack = 'pgd_linf', eps = 8/225)


Evaluating: 100%|██████████| 157/157 [00:02<00:00, 59.43it/s]

Robust accuracy 0.80220 on pgd_linf attack with eps = 0.035555555555555556





## IBP Model

In [11]:
test_model_on_single_attack(model=net, attack = 'pgd_linf', eps = 8/225)


Evaluating: 100%|██████████| 157/157 [00:02<00:00, 60.45it/s]

Robust accuracy 0.90660 on pgd_linf attack with eps = 0.035555555555555556





# Box Verification

In [23]:
def get_bounds(eps):
    input_bounds = HyperRectangle.from_eps(x, eps)
    ibp_bounds = net.ibp(input_bounds)
    #print(ibp_bounds.lower)
    #print(ibp_bounds.upper)
    return ibp_bounds.lower, ibp_bounds.upper

def is_pass(lb, ub, true_labels):
    batch_size = lb.shape[0]
    passes = torch.ones(batch_size, dtype=torch.bool, device=device)

    for i in range(batch_size):
        # get logit of true class
        true_lower = lb[i, true_labels[i]]
        # Check all other classes
        for class_idx in range(10):
            if class_idx != true_labels[i]:
                # if upper bound of another class is greater than lowber bound of true class , False
                if ub[i, class_idx] >= true_lower:
                    passes[i] = False
                    break

    return passes

In [24]:
epsilons = np.linspace(0.01, 0.1, num=10)
for eps in epsilons:
    correct = 0
    for data, target in test_loader:
        #print(len(test_loader))
        x = data.view(-1, 28*28).to(device)
        y = target.to(device)
        # get bounds of output neurons after passing through starting bounds
        lb, ub = get_bounds(eps)

        # get model correctness of batch
        pass_batch = is_pass(lb, ub, y)
        correct += pass_batch.sum().item()

    print(f'Epsilon: {eps}, Accuracy: {correct/len(test_loader.dataset)}')

Epsilon: 0.01, Accuracy: 0.9231
Epsilon: 0.020000000000000004, Accuracy: 0.909
Epsilon: 0.030000000000000006, Accuracy: 0.8941
Epsilon: 0.04000000000000001, Accuracy: 0.873
Epsilon: 0.05000000000000001, Accuracy: 0.8507
Epsilon: 0.06000000000000001, Accuracy: 0.8256
Epsilon: 0.07, Accuracy: 0.7975
Epsilon: 0.08, Accuracy: 0.7639
Epsilon: 0.09000000000000001, Accuracy: 0.724
Epsilon: 0.1, Accuracy: 0.6775
