The architecture follows the solution proposed in Appendix D.1 of https://arxiv.org/pdf/1911.00937.pdf

In [None]:
pip install foolbox

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import foolbox as fb
import time
import matplotlib.pyplot as plt
import torch.nn.functional as F
from utils import *
from multiClassHinge import multiClassHingeLoss

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Per-channel mean / std used to standardise the inputs. NOTE: these are the
# ImageNet channel statistics, not values computed on CIFAR-100. They must stay
# in sync with the defaults of the `Normalisation` module used for the attacks.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: random 32x32 crop of the 4-pixel padded image and a random
# horizontal flip, followed by tensor conversion and standardisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Test pipeline: identical standardisation, no augmentation.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

In [None]:
batch_size = 128

# CIFAR-100 train / test splits with the augmentation pipelines defined above.
trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                         download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                        download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Number of target classes (100 for CIFAR-100), read from the dataset metadata
# so it cannot drift from the data actually loaded.
n_classes = len(trainset.classes)


In [None]:
class CNNBlock(nn.Module):
    """Single convolution layer that also exposes an orthogonality penalty.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        kernel_size: convolution kernel size.
        padding: convolution padding.
        shape: spatial size; ``shape[0]`` and ``shape[1]`` size ``storedEigVect``.
    """

    def __init__(self, in_chan, out_chan, kernel_size, padding, shape):
        super().__init__()
        self.shape = shape
        self.conv = nn.Conv2d(
            in_chan,
            out_chan,
            kernel_size=kernel_size,
            padding=padding,
        )
        # Random tensor of size (in_chan, shape[0], shape[1]). It is a plain
        # attribute, not a registered buffer, so .to(device) does not move it
        # and state_dict() does not include it.
        rows, cols = shape[0], shape[1]
        self.storedEigVect = torch.rand(in_chan, rows, cols)

    def reg(self):
        """Return ``deconv_orth_dist`` (from utils) of the kernel with its first two axes swapped."""
        swapped_kernel = self.conv.weight.transpose(0, 1)
        return deconv_orth_dist(swapped_kernel)

    def forward(self, x):
        """Apply the convolution."""
        return self.conv(x)

In [None]:
# Step size of the residual update in CNNBlock2.forward; dt = 1 gives a
# standard ResNet-style skip connection.
dt = 1  # standard ResNET


class CNNBlock2(nn.Module):
    """One network stage: input conv block, ``n_layers`` residual updates, 2x2 max-pool.

    Each residual update is ``x <- 0.5 * (x + dt * B(relu(A(x))))``, where A
    (with bias) and B (without bias) are 3x3 convolutions initialised with
    ``makeDeltaOrthogonal`` from utils.

    Args:
        in_chan: number of input channels.
        nf: number of feature channels used throughout the stage.
        n_layers: number of residual updates (A/B convolution pairs).
    """

    def __init__(self, in_chan, nf, n_layers):
        super().__init__()

        self.nlayers = n_layers
        self.nf = nf
        self.chans = in_chan

        # Number of A/B convolution pairs; always equal to ``nlayers``.
        self.matrices = self.nlayers

        # conv_block comes from utils; assumes pool=False disables its own pooling (TODO confirm).
        self.conv1 = conv_block(self.chans, self.nf, pool=False)
        self.convs = nn.ModuleList(
            [nn.Conv2d(self.nf, self.nf, 3, 1, 1, bias=True) for _ in range(self.matrices)]
        )
        self.convsO = nn.ModuleList(
            [nn.Conv2d(self.nf, self.nf, 3, 1, 1, bias=False) for _ in range(self.matrices)]
        )

        # leaky_relu with negative slope 0 is plain ReLU, i.e. gain sqrt(2).
        relu_gain = nn.init.calculate_gain('leaky_relu', 0)
        for conv_a, conv_b in zip(self.convs, self.convsO):
            makeDeltaOrthogonal(conv_a.weight.data, relu_gain)
            makeDeltaOrthogonal(conv_b.weight.data, relu_gain)

        self.mp = nn.MaxPool2d(2, 2)

    def getReg(self):
        """Return the summed ``deconv_orth_dist`` (from utils) of every A and B kernel."""
        reg = 0
        for conv_a, conv_b in zip(self.convs, self.convsO):
            reg += deconv_orth_dist(conv_a.weight) + deconv_orth_dist(conv_b.weight)
        return reg

    def forward(self, x):
        """Apply the input conv block, the residual updates, then max-pooling."""
        x = self.conv1(x)
        for conv_a, conv_b in zip(self.convs, self.convsO):
            # Averaging the skip and residual branches keeps the update a convex combination.
            x = 0.5 * (x + dt * conv_b(torch.relu(conv_a(x))))
        return self.mp(x)

class CNN(nn.Module):
    """Three ``CNNBlock2`` stages, a flatten, and a linear classifier.

    Args:
        in_chan: number of input image channels.
        nf1, nf2, nf3: feature channels of the three stages.
        n_l1, n_l2, n_l3: number of residual updates in each stage.
        n_classes: number of output logits (default 100, for CIFAR-100).
        n_features: length of the flattened feature vector fed to the
            classifier. The default 2048 equals nf3 * 4 * 4 for nf3 = 128 on
            32x32 inputs. That assumes each stage halves the spatial size exactly
            once (one MaxPool2d per CNNBlock2). Adjust it when changing nf3 or
            the input resolution.
    """

    def __init__(self, in_chan, nf1, nf2, nf3, n_l1, n_l2, n_l3, n_classes=100, n_features=2048):
        super().__init__()

        self.input = in_chan
        self.nf1 = nf1
        self.nf2 = nf2
        self.nf3 = nf3
        self.n_l1 = n_l1
        self.n_l2 = n_l2
        self.n_l3 = n_l3

        self.seq = nn.Sequential(
            CNNBlock2(self.input, self.nf1, self.n_l1),
            CNNBlock2(self.nf1, self.nf2, self.n_l2),
            CNNBlock2(self.nf2, self.nf3, self.n_l3),
            nn.Flatten(),
        )
        self.lin = nn.Linear(n_features, n_classes)

    def getReg(self):
        """Return the sum of the three stages' ``getReg`` penalties."""
        return self.seq[0].getReg() + self.seq[1].getReg() + self.seq[2].getReg()

    def forward(self, x):
        """Map a batch of images to class logits."""
        x = self.seq(x)
        return self.lin(x)

# Network with 3 input channels, 32/64/128 filters and 4 layers per stage,
# placed on the selected device (nn.Module.to returns the module itself).
model = CNN(3, 32, 64, 128, 4, 4, 4).to(device)

In [None]:
# Total number of scalar weights the optimiser will update.
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"The network has {params} trainable parameters")

In [None]:
# Optimisation hyper-parameters.
lr, momentum, weight_decay = 1e-3, 0.9, 0
step_size = 30  # intended StepLR step size (epochs between learning-rate decays)
EPOCHS = 100

# Descriptive labels only; these strings are not used to build the actual
# optimiser / scheduler objects.
optimizer = 'stochastic gradient descent'
scheduler = 'stepLR'

In [None]:
class Normalisation(torch.nn.Module):
    """Per-channel standardisation layer computing ``(x - mean) / std``.

    Prepending it to a network lets an attack work on raw images in [0, 1]
    while the network still receives standardised inputs.

    Args:
        means: per-channel means. The defaults are the ImageNet channel
            statistics, matching the torchvision ``Normalize`` transforms used
            for training in this notebook.
        stds: per-channel standard deviations; same length as ``means``.

    Raises:
        ValueError: if ``means`` and ``stds`` have different lengths.
    """

    def __init__(self, means=(0.485, 0.456, 0.406), stds=(0.229, 0.224, 0.225)):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(means) != len(stds):
            raise ValueError(
                f"means and stds must have the same length, got {len(means)} and {len(stds)}"
            )
        self.means = means
        self.stds = stds

    def forward(self, x):
        """Standardise ``x``; assumes channels are on dim 1 (NCHW) — TODO confirm for other layouts."""
        # Constants are built on x's device at call time, so the layer also works
        # when it is added after the wrapped model was moved to a device.
        shape = (1, len(self.means), 1, 1)
        mean = torch.tensor(self.means, device=x.device).view(shape)
        std = torch.tensor(self.stds, device=x.device).view(shape)
        return (x - mean) / std

In [None]:
import torch.optim as optim

# Hinge-loss margins; one model is trained per margin.
marginList = [0.07, 0.15, 0.3]
# L2 perturbation budgets for the PGD attack, in [0, 1] pixel units.
epsilons = [0.0, 8 / 255, 16 / 255, 36 / 255, 0.3, 0.5, 0.6, 0.8, 1.0]

# Results table: one row per margin, one column per epsilon.
robust_accuracy = np.zeros(shape=(len(marginList), len(epsilons)), dtype=float)

In [None]:
# Weight of the orthogonality regulariser (model.getReg()) in the training loss.
gamma = .01

# Un-normalised CIFAR-100 test images in [0, 1] for the foolbox attack; the
# model wrapper applies normalisation itself. Built once and shared by all margins.
robust_batch_size = 1024  # only the first batch of this size is attacked
transform_test_rob = transforms.Compose([
    transforms.ToTensor()
])
testset_rob = torchvision.datasets.CIFAR100(root='./data', train=False,
                                            download=True, transform=transform_test_rob)
testloader_rob = torch.utils.data.DataLoader(testset_rob, batch_size=robust_batch_size,
                                             shuffle=False, num_workers=2)
images_rob, labels_rob = next(iter(testloader_rob))
images_rob, labels_rob = images_rob.to(device), labels_rob.to(device)


def _train_one_epoch(model, criterion, optimizer, epoch, reg_weight):
    """Run one epoch over `trainloader`, printing the mean loss of every 100 mini-batches."""
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        orth_reg = reg_weight * model.getReg()
        loss = criterion(outputs, labels) + orth_reg
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report after every 100 *complete* mini-batches so the average is over exactly 100.
        if (i + 1) % 100 == 0:
            print(f'Loss [{epoch+1}, {i + 1}](epoch, minibatch): ', running_loss / 100)
            running_loss = 0.0


def _evaluate(model):
    """Print per-stage orthogonality violation and clean test accuracy; leaves `model` in eval mode."""
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        regTerms = [model.seq[k].getReg().item() for k in range(3)]
        print(f"Orthogonality violation: {regTerms}")

        for images, labels in testloader:
            labels = labels.to(device)
            outputs = model(images.to(device))
            # The class with the highest logit is the prediction.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Current accuracy on %d test images: %d %%' % (total, 100 * correct / total))


def _attack(wrapped_model, images, labels):
    """Return (clean accuracy, robust accuracy per epsilon) under a 10-step L2 PGD attack."""
    fmodel = fb.PyTorchModel(wrapped_model, bounds=(0, 1))
    acc = fb.utils.accuracy(fmodel, images, labels)
    attack = fb.attacks.L2PGD(steps=10)
    _, _, success = attack(fmodel, images, labels, epsilons=epsilons)
    # One row of `success` per epsilon: robust accuracy is the fraction NOT fooled.
    rob = torch.mean((1 - 1. * success), axis=1).detach().cpu().numpy()
    return acc, rob


def _plot_margin(margin, rob, acc):
    """Plot robust accuracy vs. epsilon for one margin and save it as a PNG."""
    plt.figure(figsize=(20, 10))
    plt.plot(epsilons, rob, 'r-*', label="Experimental")
    plt.xlabel(r"$\varepsilon$", fontsize=20)
    plt.ylabel("Robust accuracy", fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.legend(fontsize=20, loc=1)
    # NOTE: `acc` is measured on the attacked batch only, not the full test set.
    plt.title(f"L2 robustness. Trained with margin = {margin}. Test accuracy = {round(acc * 100,2)}%", fontsize=20)
    plt.savefig(f'Cifar100_L2margin_{margin}.png')


for iterate, margin in enumerate(marginList):

    model = CNN(3, 32, 64, 128, 4, 4, 4)
    model.to(device)

    # NOTE(review): this loads a CIFAR-10 checkpoint into a CIFAR-100 model.
    # strict=False tolerates missing/unexpected keys (shape mismatches still
    # raise), so report them instead of silently training from scratch.
    pretrained_dict = torch.load(f"trained_model_{margin}_cifar10.pt", map_location=device)
    load_result = model.load_state_dict(pretrained_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Checkpoint for margin {margin}: missing keys {load_result.missing_keys}, "
              f"unexpected keys {load_result.unexpected_keys}")

    criterion = multiClassHingeLoss(margin=margin)
    # Fresh optimiser/scheduler per margin, always starting from the configured `lr`.
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)

    for epoch in range(EPOCHS):
        _train_one_epoch(model, criterion, optimizer, epoch, gamma)
        _evaluate(model)
        scheduler.step()
    print('Training Done')

    # The attack works on raw [0, 1] images, so prepend the normalisation layer.
    wrapped = nn.Sequential(Normalisation(), model).eval()
    acc, rob = _attack(wrapped, images_rob, labels_rob)
    robust_accuracy[iterate] = rob

    _plot_margin(margin, robust_accuracy[iterate], acc)

    np.savetxt(f"Cifar100_updateMargin_{margin}.txt", robust_accuracy[iterate].reshape(-1, 1))

    # NOTE: this saves the wrapped model, so every key is prefixed with "1."
    # (the CNN's index in the Sequential). Strip that prefix before loading it
    # into a bare CNN.
    torch.save(wrapped.state_dict(), f"trained_model_margin_{margin}.pt")

In [None]:
# Overlay the robust-accuracy curves of all margins. This reuses `marginList`,
# `epsilons` and `robust_accuracy` from the cells above, so the legend labels
# always match the rows that were actually computed.
fig = plt.figure(figsize=(20, 10))
for margin, curve in zip(marginList, robust_accuracy):
    plt.plot(epsilons, curve, '-*', label=f"Margin = {margin}")
plt.xlabel(r"$\varepsilon$", fontsize=20)
plt.ylabel("Robust accuracy", fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.legend(fontsize=20, loc=1)
plt.title("L2 robustness comparison CIFAR100", fontsize=20)
plt.savefig("Cifar100_RobustnessCNN.png")
plt.show()