In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
"""
Testing stat defense vs. blackbox attack

Steps:
Load data
Train defense model on data
Use defense model as 'oracle' for substitute (replace targets in data with defense model's predictions)
Train substitute on data with targets replaced by oracle (optionally using Jacobian-based data augmentation)
Generate attack data against substitute model
Test defense model against attack data
"""

"\nTesting stat defense vs. blackbox attack\n\nSteps:\nLoad data\nTrain defense model on data\nUse defense model as 'oracle' for substitute (replace targets in data with defense model's predictions)\nTrain substitute on data with targets replaced by oracle (optionally using Jacobian-based data augmentation)\nGenerate attack data against substitute model\nTest defense model against attack data\n"

In [4]:
import os.path
import matplotlib.pyplot as plt
from IPython.display import Image 
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
from torch.autograd.functional import jacobian
from torch.autograd import grad
from sklearn.model_selection import train_test_split

import team36
from team36.mnist.data_loading import MNIST_Loader
from team36.defenses.stat_def import VGG
from team36.mnist.cnn import CNN
from team36.attacks.fast_gradient_attack_data_set import FastSignGradientAttackDataSet
from team36.defenses.fast_gradient_sign_method_loss import FastGradientSignMethodLoss
from team36.training import train, validate, accuracy, predict, predict_from_loader, do_training, load_or_train, train_val_split, train_batch

DIR = '.'
DATA_DIR = f'{DIR}/data'

In [6]:
"""
Set up the datasets
It is interesting to compare using the same dataset for the substitute and oracle vs. using different datasets
"""
ORACLE_DATASET = 'MNIST' # 'MNIST' or 'CIFAR10'
SUB_DATASET = 'MNIST' 
DATASET_NAMES = [ORACLE_DATASET, SUB_DATASET]
MODEL_NAMES = ['oracle', 'sub']
ORACLE = 0
SUB = 1
datasets = [] # two-element list to store the oracle dataset and the substitute dataset
dataset_image_sizes = []
dataset_channels = []
test_datasets = []

for idx, dataset_name in enumerate(DATASET_NAMES):
    transform_seq = [transforms.ToTensor()]
    if dataset_name == 'MNIST':
        image_size = 28
        in_channels = 1
        if idx == SUB and ORACLE_DATASET != 'MNIST': # if substitute uses mnist, but oracle uses other dataset add padding to match image size
            padding = (dataset_image_sizes[ORACLE] - image_size) // 2
            transform_seq.append(transforms.Pad(padding, fill=0))
            image_size += padding * 2
            transform_seq.append(transforms.Lambda(lambda x: x.repeat(3, 1, 1)))
            in_channels = 3 # also need to copy to 3 channels to match rgb cifar
            # NOTE: above assumes oracle_dataset has larger images and that there is an even number difference in the image sizes
        dataset = torchvision.datasets.MNIST(root=DATA_DIR, train=True, download=True, 
                                              transform=transforms.Compose(transform_seq))
        test_dataset = torchvision.datasets.MNIST(root=DATA_DIR, train=False, download=True, 
                                              transform=transforms.Compose(transform_seq))
    elif dataset_name == 'CIFAR10':
        image_size = 32
        in_channels = 3
        dataset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True,
                                                transform=transforms.Compose(transform_seq))
        test_dataset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False, download=True, 
                                              transform=transforms.Compose(transform_seq))
    else:
        print('ERROR: invalid dataset name', dataset_name)
    datasets.append(dataset)
    dataset_image_sizes.append(image_size)
    dataset_channels.append(in_channels)
    test_datasets.append(test_dataset)

# optionally, use the test set to train the substitute, so that it trains on different data than the oracle:
# datasets[SUB] = test_datasets[SUB]
    
datasets

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


[Dataset MNIST
     Number of datapoints: 60000
     Root location: ./data
     Split: Train
     StandardTransform
 Transform: Compose(
                ToTensor()
            ), Dataset MNIST
     Number of datapoints: 60000
     Root location: ./data
     Split: Train
     StandardTransform
 Transform: Compose(
                ToTensor()
            )]

In [7]:
"""Initialize the defense model (aka the oracle)"""
if DATASET_NAMES[ORACLE] == 'MNIST':
    defense_model = VGG() 
elif DATASET_NAMES[ORACLE] == 'CIFAR10':
    defense_model = VGG(image_size=32, in_channels=3) 
oracle = defense_model
    
defense_model

VGG(
  (convolution_layers): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (linear_l

In [24]:
"""Train the defense model"""
torch.manual_seed(0)
#TODO: train defense_model using datasets[ORACLE]
# insert your own training code, or use load_or_train() from training.py if it helps
# note that load_or_train() generates a checkpoint with the given name and you need to delete it if you want to re-train from scratch

In [43]:
%%time
"""Relabel the substitute dataset with the oracle's predictions"""
loader = torch.utils.data.DataLoader(datasets[SUB], batch_size=100, shuffle=False, num_workers=0)
oracle_preds = predict_from_loader(oracle, loader)

CPU times: user 9min 6s, sys: 3.63 s, total: 9min 9s
Wall time: 9min 9s


In [44]:
%%time
oracle_preds = oracle_preds.argmax(axis=1)

CPU times: user 15.9 ms, sys: 0 ns, total: 15.9 ms
Wall time: 16.3 ms


In [45]:

# sanity check: oracle preds should mostly match true labels (i.e. close to 60,000 on MNIST or 50,000 on CIFAR10)
len(oracle_preds[oracle_preds == datasets[SUB].targets])

60000

In [46]:

# replace true labels with oracle's predictions
datasets[SUB].targets = oracle_preds

In [52]:
%%time
torch.manual_seed(0)
sub = CNN(image_size = dataset_image_sizes[SUB], in_channels = dataset_channels[SUB])
sub_checkpoint = f'{DATASET_NAMES[SUB]}-substitute.pth' # e.g. 'MNIST-substitute.pth'

loader = torch.utils.data.DataLoader(datasets[SUB], batch_size=100, shuffle=False, num_workers=0)
# NOTE: due to a bug with lambda transforms on Windows, num_workers needs to be 0 or there will be an error
# https://github.com/belskikh/kekas/issues/26

epochs = 30
#TODO: test with train_size 1800 and 10,000, and tune hyperparameters if necessary
train_size = 10000
# train_size = 1800 
learning_rate = 0.05
momentum = 0.5
train = Subset(datasets[SUB], range(train_size))
val = Subset(datasets[SUB], range(train_size, train_size*2)) # val split same size as train split
load_or_train(sub, sub_checkpoint, train_split=train, val_split=val, epochs=epochs, learning_rate=learning_rate, momentum=momentum)

CPU times: user 2.93 ms, sys: 994 µs, total: 3.92 ms
Wall time: 3.94 ms


In [53]:
target = oracle
criterion = nn.CrossEntropyLoss()
models = [(oracle, 'oracle')]

In [54]:

%%time
"""Regular Test"""
test_set = test_datasets[ORACLE] 

for model, name in models:
    print(name)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=0)

    test_accuracy, _, test_loss = validate(None, test_loader, model, criterion)

    print(f"{name} Regular Test Accuracy is {test_accuracy}")
    print(f"{name} Regular Test Loss is {test_loss}")

oracle
oracle Regular Test Accuracy is 0.12080000340938568
oracle Regular Test Loss is 2.472759425640106
CPU times: user 1min 28s, sys: 577 ms, total: 1min 29s
Wall time: 1min 28s


In [55]:

%%time
"""Attack Test"""
epsilon = 0.05 #TODO: record results for epsilon=0.25, 0.1, 0.05, and 0.01
test_set = test_datasets[ORACLE] # since the substitute tries to mimic the oracle, we may as well test on the oracle's test set
attack_test_set = FastSignGradientAttackDataSet(test_set, sub, criterion, epsilon=epsilon) # generate attack against substitute model

for model, name in models:
    print(name)
    test_loader = torch.utils.data.DataLoader(attack_test_set, batch_size=100, shuffle=False, num_workers=0)

    test_accuracy, _, test_loss = validate(None, test_loader, model, criterion)

    print(f"{name} Attack Test Accuracy is {test_accuracy}")
    print(f"{name} Attack Test Loss is {test_loss}")

oracle
oracle Attack Test Accuracy is 0.10980000346899033
oracle Attack Test Loss is 2.5156938576698304
CPU times: user 1min 44s, sys: 725 ms, total: 1min 45s
Wall time: 1min 45s
