In [1]:
import sys

# Mount Google Drive so the project folder stored on it becomes reachable.
from google.colab import drive
drive.mount('/content/drive')

# Project folder inside "My Drive". Keep the trailing slash: later cells build
# file paths as sys.path[-1] + '<relative path>'.
FOLDERNAME = 'multiclass_polyact/'

# Make the project's packages (models, torch_solvers, cvx_scripts, ...) importable.
sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))

Mounted at /content/drive


In [2]:
import torch
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy
import matplotlib.pyplot as plt

from models.one_vs_all import MultiClassPolyAct

from torch_solvers.robust_polyact_solver import adversarial_train_poly_act
from torch_solvers.robust_polyact_solver_batch import batch_adversarial_polyact_train, Spliced_PolyAct
from torch_solvers.alternate_solver import alt_solver

from models.praresnet import PreActResNet18
from models.one_vs_all import MultiClassPolyAct
from models.spliced import Spliced, Spliced_PolyAct

from utils import FeatureDataset

from cvx_scripts.losses import *
from cvx_scripts.cvx_nn import *
from cvx_scripts.cvx_training import *

from attacks.fgsm import eval_fgsm

%load_ext autoreload
%autoreload 2

In [3]:
# Expected width of each feature vector from pr18.truncated_forward;
# extract_features uses it for the shape of an empty result.
embedding_size = 512

In [4]:
def extract_features(dummy_loader, model, shuffle=True, device=None):
    """Run every batch of `dummy_loader` through `model.truncated_forward` and
    collect the outputs as a CPU feature matrix.

    Args:
        dummy_loader: iterable of ``(img, label)`` batches.
        model: network exposing ``truncated_forward(img)``; assumed to be a
            ``torch.nn.Module`` (it is switched to eval mode during extraction
            and its previous train/eval mode is restored afterwards).
        shuffle: if True, apply one random permutation (NumPy global RNG) to
            the rows of X and y together.
        device: device the images are moved to before the forward pass.
            Defaults to the device of the model's first parameter (CPU if the
            model has no parameters).

    Returns:
        (X, y): X is a float32 tensor of shape (n, d), one flattened feature
        vector per row; y is a float32 tensor of the n labels (callers cast it
        with ``.long()`` / ``.type(torch.LongTensor)``). An empty loader yields
        shapes (0, embedding_size) and (0,).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    feature_chunks = []
    label_chunks = []
    was_training = model.training
    # Extract in eval mode. If the backbone has normalization layers (presumably
    # BatchNorm in PreActResNet — confirm), train mode would make features depend
    # on batch composition and would overwrite the running statistics of a model
    # that is reused later in the notebook.
    model.eval()
    try:
        # No autograd graph is needed for feature extraction; this saves GPU memory.
        with torch.no_grad():
            for img, label in dummy_loader:
                out = model.truncated_forward(img.to(device)).cpu()
                feature_chunks.append(out.reshape(out.shape[0], -1))
                label_chunks.append(label)
    finally:
        model.train(was_training)
        torch.cuda.empty_cache()

    if not feature_chunks:
        return torch.zeros(0, embedding_size), torch.zeros(0)

    # Concatenate once at the end instead of re-stacking after every batch,
    # which copied the growing tensor each time (O(n^2) total work).
    X = torch.cat(feature_chunks).float()
    # Labels stay float32, matching the previous torch.zeros(0)-seeded accumulator.
    y = torch.cat(label_chunks).float()

    if shuffle:
        perm = torch.from_numpy(np.random.permutation(X.shape[0]))
        X = X[perm]
        y = y[perm]
    return X, y

In [5]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda:0


In [6]:
# Load CIFAR-10 data
print('==> Preparing data..')
# Per-channel normalization statistics, shared by both transforms and by the
# `mean`/`std` tensors passed to the FGSM evaluation later on.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True)

testset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

mean = torch.tensor(CIFAR10_MEAN).to(device)
std = torch.tensor(CIFAR10_STD).to(device)

# Test loader (batches of 1000 images) used for the fast gradient sign method evaluation.
testloader_fgsm = torch.utils.data.DataLoader(
    testset, batch_size=1000, shuffle=False)


def _load_praresnet(checkpoint_name):
    """Build a 10-class PreActResNet18, load `checkpoint_name` (relative to the
    project folder sys.path[-1] appended in the first cell) and move it to `device`."""
    net = PreActResNet18(10)
    # map_location lets GPU-saved checkpoints load on a CPU-only runtime.
    # weights_only=True restricts unpickling to tensors and plain containers
    # (all a state dict needs). It also avoids the torch.load warning printed by
    # these calls in the saved output.
    state = torch.load(sys.path[-1] + checkpoint_name, map_location=device, weights_only=True)
    net.load_state_dict(state)
    return net.to(device)


# Load pre-trained Pre-Activation ResNet-18 models trained via sharpness-aware
# minimization (SAM) and via standard training.
pr18_sam = _load_praresnet('pretrained_models/praresnet.pth')
pr18 = _load_praresnet('pretrained_models/praresnet_nonsam.pth')

# Unshuffled loaders (batches of 1000) used only to extract pr18.truncated_forward features.
# NOTE(review): trainset uses transform_train, so the training features come from
# randomly cropped/flipped images; confirm this augmentation is intended here.
dummy_train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=1000, shuffle=False, pin_memory=True)
dummy_test_loader = torch.utils.data.DataLoader(
    testset, batch_size=1000, shuffle=False, pin_memory=True)

X_train, y_train = extract_features(dummy_train_loader, pr18)
X_test, y_test = extract_features(dummy_test_loader, pr18)
_, trunc_d = X_train.shape  # width of the extracted feature vectors

==> Preparing data..
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:18<00:00, 9189551.36it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


  pr18_sam.load_state_dict(torch.load(sys.path[-1] + 'pretrained_models/praresnet.pth'))
  pr18.load_state_dict(torch.load(sys.path[-1] + 'pretrained_models/praresnet_nonsam.pth'))


Build the feature dataloaders from a subset of the extracted features (the first 100 training and 500 test vectors). Warning: batch sizes larger than ~200 are EXTREMELY memory-intensive.

In [7]:
batch_size = 100
# Number of extracted feature vectors (first rows of X_train / X_test) used below.
n_train_features = 100
n_test_features = 500

# The trailing 10 is passed straight to FeatureDataset (presumably the number of
# CIFAR-10 classes; confirm against utils.FeatureDataset).
train_dataset = FeatureDataset(X_train[:n_train_features].cpu(),
                               y_train[:n_train_features].cpu().long(), 10)
test_dataset = FeatureDataset(X_test[:n_test_features].cpu(),
                              y_test[:n_test_features].cpu().long(), 10)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps validation runs reproducible.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


Adversarially train a polynomial-activation network on the extracted features.

In [8]:
# Hyperparameters for batch_adversarial_polyact_train.
# NOTE(review): the meanings below are inferred from names; confirm against the solver.
beta = 0.01   # presumably a regularization weight
r = 1.5       # presumably the adversarial perturbation radius
lr = 0.01     # learning rate
rho = 2
epochs = 30
# trunc_d is the extracted feature width (512 here). It is used instead of repeating
# the literal so the head stays in sync with the features (presumably the input dim).
model = MultiClassPolyAct(10, trunc_d, device=device, init='zero')
best_model, losses, train_accs, val_accs, best_robust = batch_adversarial_polyact_train(
    model,
    train_loader,
    test_loader,
    r,
    beta,
    device,
    lr=lr,
    epochs=epochs,
    rho=rho,
    batch_size=batch_size,
    verbose=True,
    base_model=pr18,
    robust_eval_loader=testloader_fgsm,
)

Splice the robust polynomial-activation network onto the base image classifier (`pr18`) and evaluate it under FGSM attacks with ε = 0/255, 1/255, …, 8/255.

In [None]:
spliced = Spliced_PolyAct(pr18, best_model)

# FGSM strengths eps/255 for eps = 0, 1, ..., 8.
robusts = []
for eps in range(9):
    # Presumably selects the polynomial-activation path; see Spliced_PolyAct.
    spliced.robust = True
    robust = eval_fgsm(spliced, device, testloader_fgsm, eps / 255, mean, std, is_polyact=True)
    robusts.append(robust)
print(robusts)

Load a pre-trained convex two-layer ReLU network.

In [None]:
# Load a pre-trained convex two-layer ReLU network: the checkpoint is read onto
# the CPU, then the module is moved to `device`.
cvx = custom_cvx_layer(512, 500)
cvx.load_state_dict(
    torch.load(sys.path[-1] + 'praresnet_nonsam_500_inf_5.pth',
               map_location=torch.device('cpu'))
)
cvx.to(device)

# Companion vector saved as a NumPy array; converted to a float32 tensor on `device`.
uvec = (
    torch.from_numpy(torch.load(sys.path[-1] + 'u_vec_praresnet_nonsam_500.pth'))
    .to(device)
    .float()
)

In [None]:
from prepare_data import *

Evaluate the FGSM robustness of the base (standard), SAM-trained, and robustified (spliced polynomial-activation) models.

In [14]:
spliced = Spliced_PolyAct(pr18, best_model)

# Freeze every parameter of the spliced model and of the SAM model; the attack
# presumably only needs gradients with respect to the inputs.
for frozen_net in (spliced, pr18_sam):
    for param in frozen_net.parameters():
        param.requires_grad = False

robust_poly, standards, sams = [], [], []
for eps in range(9):
    eps_scaled = eps / 255

    # Spliced model with the robust (polynomial-activation) path enabled.
    spliced.robust = True
    robust = eval_fgsm(spliced, device, testloader_fgsm, eps_scaled, mean, std, is_polyact=True)
    robust_poly.append(robust)

    # Same spliced model with the robust path disabled (the "standard" baseline).
    spliced.robust = False
    standard = eval_fgsm(spliced, device, testloader_fgsm, eps_scaled, mean, std)
    standards.append(standard)

    # SAM-trained ResNet for comparison.
    sam = eval_fgsm(pr18_sam, device, testloader_fgsm, eps_scaled, mean, std)
    sams.append(sam)

In [None]:
# Entry i of each curve was computed at FGSM epsilon = i/255 in the evaluation cell above.
# NOTE(review): the y label assumes eval_fgsm returns accuracy; confirm.
eps_values = list(range(len(sams)))
fig, ax = plt.subplots()
ax.plot(eps_values, sams, label='SAM')
ax.plot(eps_values, standards, label='Standard')
ax.plot(eps_values, robust_poly, label='Poly')
ax.set(xlabel='FGSM epsilon (x 1/255)',
       ylabel='Accuracy under FGSM',
       title='FGSM robustness on the CIFAR-10 test set')
# The curves carry labels, but they are only displayed once a legend is drawn.
ax.legend()
plt.show()