# Boilerplate

Package installation, loading, and dataloaders.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import torchvision
import copy

from torchvision import datasets, transforms
from tensorboardX import SummaryWriter
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm

# Run configuration: computation stays on the CPU unless use_cuda is switched on.
use_cuda = False
device = torch.device("cpu" if not use_cuda else "cuda")
batch_size = 64

# Fix the NumPy and PyTorch seeds so repeated runs start from the same state.
np.random.seed(42)
torch.manual_seed(42)


## Dataloaders
# MNIST train/test splits, stored under (and downloaded to) 'mnist_data/';
# each sample is passed through ToTensor.
train_dataset = datasets.MNIST(
    root='mnist_data/',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
test_dataset = datasets.MNIST(
    root='mnist_data/',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Only the training loader reshuffles; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

## Simple NN. You can change this if you want.
class Net(nn.Module):
    """Fully connected classifier: 784 -> 50 -> 50 -> 50 -> 10 with ReLU between layers.

    The four linear layers are reachable as ``fc``, ``fc2``, ``fc3`` and
    ``fc4``; the very same module objects are chained, after a flatten step,
    inside ``self.nn``, which is what ``forward`` runs.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28 * 28, 50)
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 50)
        self.fc4 = nn.Linear(50, 10)
        self.nn = nn.Sequential(
            nn.Flatten(),
            self.fc,
            nn.ReLU(),
            self.fc2,
            nn.ReLU(),
            self.fc3,
            nn.ReLU(),
            self.fc4,
        )

    def forward(self, x):
        """Return the 10 raw (pre-softmax) class scores for every sample in ``x``.

        ``x`` is flattened from dimension 1 onwards, so each sample must hold
        28*28 values.
        """
        # Call the container itself rather than its .forward() so that any
        # hooks registered on self.nn are executed.
        return self.nn(x)

class Normalize(nn.Module):
    """Standardise an input as ``(x - mean) / std``.

    Args:
        mean: Value subtracted from every element. Defaults to 0.1307.
        std: Value every element is divided by afterwards. Defaults to 0.3081.

    Both values are kept as plain Python attributes, so the module has no
    parameters or buffers and contributes nothing to a ``state_dict``.
    """

    def __init__(self, mean=0.1307, std=0.3081):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, x):
        """Return ``x`` shifted by ``mean`` and scaled by ``1 / std``."""
        return (x - self.mean) / self.std

# Prepend the data normalization to the classifier as its first "layer".
# This lets us search for adversarial examples around the real image
# rather than around the normalized image.
model = nn.Sequential(Normalize(), Net()).to(device)
model.train()

# Implement Training

In [None]:
# Hyper-parameters for the training and verification cells.
# NOTE(review): k, eps and eps_step are not referenced anywhere in the visible
# code -- presumably left over from an iterative-attack setup; confirm before
# removing them.
k = 40
eps = 0.1
eps_step = 0.01
# Number of passes over the training set handed to train_model(...).
num_epochs = 5
# Number of output classes; used to size the margin specification matrices.
num_class = 10

def _ibp_robust_loss(model, ce_loss, x_batch, y_batch, epsilon):
    """Return the cross-entropy loss on IBP lower bounds of the class margins."""
    ptb = PerturbationLpNorm(norm=np.inf, eps=epsilon)
    bounded_x = BoundedTensor(x_batch, ptb)

    # Specification matrix: row j of c[i] is e_{y_i} - e_j, i.e. the margin
    # logit[y_i] - logit[j]. The all-zero row j == y_i is dropped, which leaves
    # num_class - 1 margins per sample.
    eye = torch.eye(num_class).type_as(bounded_x)
    c = eye[y_batch].unsqueeze(1) - eye.unsqueeze(0)
    other_class = y_batch.unsqueeze(1) != torch.arange(num_class).type_as(y_batch).unsqueeze(0)
    c = c[other_class].view(bounded_x.size(0), num_class - 1, num_class)

    lb, _ = model.compute_bounds(x=(bounded_x,), IBP=True, method=None, C=c)

    # Put a zero column in front of the margins and use column 0 as the target:
    # the resulting loss is log(1 + sum_j exp(-lb_j)), which shrinks as every
    # margin lower bound grows. Both helper tensors are created on lb's device
    # so this also works when training on a GPU.
    zero_col = torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device)
    lb_padded = torch.cat((zero_col, lb), dim=1)
    fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
    # A 50/50 mix with the clean cross-entropy loss is a possible variant; it
    # is not used here.
    return ce_loss(-lb_padded, fake_labels)


def _test_accuracy(model):
    """Return the fraction of test_loader samples that model classifies correctly."""
    correct, seen = 0, 0
    was_training = model.training
    model.eval()
    try:
        # Evaluation needs no gradients; skipping them avoids building graphs.
        with torch.no_grad():
            for x_batch, y_batch in test_loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                out = model(x_batch)
                pred = torch.max(out, dim=1)[1]
                correct += pred.eq(y_batch).sum().item()
                seen += x_batch.size()[0]
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)
    return correct / seen if seen else float('nan')


def train_model(model, num_epochs, enable_defense=True, epsilon=0.01, learning_rate=0.0001):
    """Train ``model`` in place with Adam and print test accuracy after each epoch.

    Batches come from the module-level ``train_loader`` and are moved to the
    module-level ``device``; accuracy is measured on ``test_loader``.

    Args:
        model: Network to optimise. With ``enable_defense=True`` it must
            provide ``compute_bounds`` (e.g. an auto_LiRPA ``BoundedModule``).
        num_epochs: Number of passes over ``train_loader``.
        enable_defense: If True, minimise the cross-entropy of IBP worst-case
            margins inside an L-infinity ball of radius ``epsilon`` around each
            input; if False, minimise the ordinary cross-entropy.
        epsilon: Perturbation radius; only used when ``enable_defense`` is True.
        learning_rate: Adam step size.

    Returns:
        None. The reported time per epoch covers training plus evaluation.
    """
    opt = optim.Adam(params=model.parameters(), lr=learning_rate)
    ce_loss = torch.nn.CrossEntropyLoss()

    for epoch in range(1, num_epochs + 1):
        t1 = time.time()
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            opt.zero_grad()

            if enable_defense:
                batch_loss = _ibp_robust_loss(model, ce_loss, x_batch, y_batch, epsilon)
            else:
                batch_loss = ce_loss(model(x_batch), y_batch)
            batch_loss.backward()
            opt.step()

        accuracy = _test_accuracy(model)
        t2 = time.time()

        print('Epoch %d: Accuracy %.5lf [%.2lf seconds]' % (epoch, accuracy, t2-t1))

# Baseline run: train with the ordinary cross-entropy loss (defense disabled);
# train_model prints the test accuracy after every epoch.
print("Training original model:")
train_model(model, num_epochs, enable_defense=False)
# Save the whole module object (not just a state_dict) to the file
# "normal_model"; the robust-training loop reloads it for each epsilon.
torch.save(model, "normal_model")

# Interval Analysis

In [None]:
# Example input for BoundedModule: one raw test image, reshaped to
# (1, 1, 28, 28) and scaled from the 0..255 range by 1/255. Only an
# uninitialised tensor of the same shape/dtype is handed over, and it is
# created on the configured device.
image = test_dataset.data[:1].view(1, 1, 28, 28)
image = image.to(torch.float32) / 255.0
model = BoundedModule(model, torch.empty_like(image, device=device))
model.eval()
# Interval Analysis of the network
# Verification only reads the bounds, so autograd is switched off for the
# whole sweep.
with torch.no_grad():
    for epsilon in [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]:
        tot_test, tot_acc = 0.0, 0.0
        t1 = time.time()
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            ptb = PerturbationLpNorm(norm=np.inf, eps=epsilon)
            x_batch = BoundedTensor(x_batch, ptb)
            # Row j of c[i] is e_{y_i} - e_j (margin of the true class over
            # class j); the all-zero row j == y_i is removed.
            eye = torch.eye(num_class).type_as(x_batch)
            c = eye[y_batch].unsqueeze(1) - eye.unsqueeze(0)
            I = y_batch.unsqueeze(1) != torch.arange(num_class).type_as(y_batch).unsqueeze(0)
            c = c[I].view(x_batch.size(0), num_class - 1, num_class)
            lb, ub = model.compute_bounds(x=(x_batch,), IBP=True, method=None, C=c)
            # A sample counts as robust when its smallest margin lower bound
            # is non-negative.
            acc = (lb.min(1)[0] >= 0).sum().item()
            tot_acc += acc
            tot_test += x_batch.size()[0]
        t2 = time.time()
        print('Interval Analysis on original model: Robustness %.5lf [%.2lf seconds]' % (tot_acc/tot_test, t2-t1))

# Provably Robust Training

In [None]:
# Example input for BoundedModule (one test image's shape/dtype, on the
# configured device). It does not depend on epsilon, so it is built once.
image = test_dataset.data[:1].view(1, 1, 28, 28)
image = image.to(torch.float32) / 255.0
dummy_input = torch.empty_like(image, device=device)

for epsilon in [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]:
    # Start every run from the saved baseline. The file holds a fully pickled
    # module, which PyTorch >= 2.6 refuses to load unless weights_only=False
    # is passed; older releases reject that keyword, hence the fallback.
    # NOTE: unpickling executes code -- only load a file this notebook wrote.
    try:
        model = torch.load("normal_model", map_location=device, weights_only=False)
    except TypeError:
        model = torch.load("normal_model", map_location=device)
    model = BoundedModule(model, dummy_input)
    model = model.to(device)
    model.train()
    print("Training robust model wit epsilon = " + str(epsilon) + ":")
    train_model(model, num_epochs, enable_defense=True, epsilon=epsilon)
    model.eval()
    # Interval Analysis of the network
    tot_test, tot_acc = 0.0, 0.0
    t1 = time.time()
    # Verification only reads the bounds, so autograd is switched off.
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            ptb = PerturbationLpNorm(norm=np.inf, eps=epsilon)
            x_batch = BoundedTensor(x_batch, ptb)
            # Row j of c[i] is e_{y_i} - e_j (margin of the true class over
            # class j); the all-zero row j == y_i is removed.
            eye = torch.eye(num_class).type_as(x_batch)
            c = eye[y_batch].unsqueeze(1) - eye.unsqueeze(0)
            I = y_batch.unsqueeze(1) != torch.arange(num_class).type_as(y_batch).unsqueeze(0)
            c = c[I].view(x_batch.size(0), num_class - 1, num_class)
            lb, ub = model.compute_bounds(x=(x_batch,), IBP=True, method=None, C=c)
            # A sample counts as robust when its smallest margin lower bound
            # is non-negative.
            acc = (lb.min(1)[0] >= 0).sum().item()
            tot_acc += acc
            tot_test += x_batch.size()[0]
    t2 = time.time()
    print('Interval Analysis on Robust model: Robustness %.5lf [%.2lf seconds]' % (tot_acc/tot_test, t2-t1))