# Reproducing NeurIPS2019 - Subspace Attack

As part of the NeurIPS2019 reproducibility challenge (project 2 of EPFL CS-433 2019) we chose to reproduce the paper [__Subspace Attack: Exploiting Promising Subspaces for Query-Efficient Black-box Attacks__](https://openreview.net/pdf?id=S1g-OVBl8r)

The algorithm is specified in: 

<img src="img/algo1.png" style="width:600px;"/>

We need to create the following functions:
- Load random reference model
- Loss function calculation
- Prior gradient calculation wrt dropout/layer ratio
- Attack

The pre-trained models are in (https://drive.google.com/file/d/1aXTmN2AyNLdZ8zOeyLzpVbRHZRZD0fW0/view).
The least demanding target model is the __GDAS__.


__Note!__ we start with 0 droupout ratio.

In [1]:
import torch

from models.cifar.gdas import load_gdas
from models.cifar import vgg

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

MODELS = 'pretrained/'

MODELS_DATA = {
    'gdas': {
        'folder': 'gdas/',
        'model_checkpoint': 'gdas-cifar10.pth',
        'config': 'gdas-cifar10.config'
    },
    'vgg16': {
        'folder': 'vgg16_bn/',
        'model_checkpoint': 'vgg16_bn.pth',
        'model': vgg.vgg16
    }
}

In [2]:
gdas_data = MODELS_DATA['gdas']

gdas = load_gdas(MODELS + gdas_data['folder'] + gdas_data['model_checkpoint'])

gdas.eval()

NetworkCIFAR(
  (stem): Sequential(
    (0): Conv2d(3, 108, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(108, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (cells): ModuleList(
    (0): Cell(
      (preprocess0): ReLUConvBN(
        (op): Sequential(
          (0): ReLU()
          (1): Conv2d(108, 36, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (preprocess1): ReLUConvBN(
        (op): Sequential(
          (0): ReLU()
          (1): Conv2d(108, 36, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (_ops): ModuleList(
        (0): Identity()
        (1): Identity()
        (2): Identity()
        (3): SepConv(
          (op): Sequential(
            (0): ReLU()
            (1): Conv2d(36, 3

In [3]:
vgg16_data = MODELS_DATA['vgg16']

vgg16 = vgg.vgg16_bn(num_classes=10)
vgg16_module_state_dict = torch.load(MODELS + vgg16_data['folder'] + vgg16_data['model_checkpoint'], map_location=device)['state_dict']


vgg16_state_dict = { key.replace('module.', ''): vgg16_module_state_dict[key] for key in vgg16_module_state_dict }
vgg16.load_state_dict(vgg16_state_dict)

vgg16.eval()

VGG(
  (features): Sequential(
    (0): DropoutConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): DropoutConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): DropoutConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): DropoutConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=Fals

In [5]:
def load_model(models_data, name, num_classes):
    model_data = models_data[name]
    model = model_data['model'](num_classes=num_classes)
    model_raw_state_dict = (
        torch
            .load(
                MODELS + model_data['folder'] + model_data['model_checkpoint'],
                map_location=device
            )['state_dict']
    )
    
    model_state_dict = { key.replace('module.', ''): model_raw_state_dict[key] for key in model_raw_state_dict }
    
    model.load_state_dict(model_state_dict)
    
    return model

In [6]:
vgg16 = load_model(MODELS_DATA, 'vgg16', 10)

vgg16.eval()

RuntimeError: Error(s) in loading state_dict for VGG:
	Missing key(s) in state_dict: "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.12.weight", "features.12.bias", "features.19.weight", "features.19.bias", "features.26.weight", "features.26.bias". 
	Unexpected key(s) in state_dict: "features.31.weight", "features.31.bias", "features.31.running_mean", "features.31.running_var", "features.31.num_batches_tracked", "features.34.weight", "features.34.bias", "features.35.weight", "features.35.bias", "features.35.running_mean", "features.35.running_var", "features.35.num_batches_tracked", "features.37.weight", "features.37.bias", "features.38.weight", "features.38.bias", "features.38.running_mean", "features.38.running_var", "features.38.num_batches_tracked", "features.40.weight", "features.40.bias", "features.41.weight", "features.41.bias", "features.41.running_mean", "features.41.running_var", "features.41.num_batches_tracked", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.1.num_batches_tracked", "features.3.weight", "features.3.bias", "features.4.weight", "features.4.bias", "features.4.running_mean", "features.4.running_var", "features.4.num_batches_tracked", "features.8.weight", "features.8.bias", "features.8.running_mean", "features.8.running_var", "features.8.num_batches_tracked", "features.11.weight", "features.11.bias", "features.11.running_mean", "features.11.running_var", "features.11.num_batches_tracked", "features.15.weight", "features.15.bias", "features.15.running_mean", "features.15.running_var", "features.15.num_batches_tracked", "features.18.weight", "features.18.bias", "features.18.running_mean", "features.18.running_var", "features.18.num_batches_tracked", "features.20.weight", "features.20.bias", "features.21.running_mean", "features.21.running_var", "features.21.num_batches_tracked", "features.25.weight", "features.25.bias", "features.25.running_mean", "features.25.running_var", "features.25.num_batches_tracked", "features.27.weight", "features.27.bias", "features.28.running_mean", "features.28.running_var", "features.28.num_batches_tracked", "features.30.weight", "features.30.bias". 
	size mismatch for features.7.weight: copying a param with shape torch.Size([128, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for features.10.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
	size mismatch for features.10.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for features.14.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for features.17.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 256, 3, 3]).
	size mismatch for features.17.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for features.21.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for features.21.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for features.24.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for features.28.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).