Running Adversarial Neuron Pruning

In [1]:
import torch
import torchvision

from torchsummary import summary

import numpy as np

In [2]:
torch.cuda.is_available()

True

In [3]:
#check gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
import torchvision.transforms as transforms

cifar10_train = torchvision.datasets.CIFAR10('datasets/cifar_10', download=True, transform=transforms.ToTensor())

Files already downloaded and verified


In [5]:
# create a subset of the dataset
l = len(cifar10_train)
indices = np.arange(l)

np.random.seed(78125)
np.random.shuffle(indices)
subset_indices = indices[:l // 10]
subset = torch.utils.data.Subset(cifar10_train, subset_indices)

len(subset)

5000

In [6]:
train_loader = torch.utils.data.DataLoader(subset, batch_size=125, shuffle=True, num_workers=4)

### Create and load the ResNet18

In [7]:
res18 = torchvision.models.resnet18().cuda() if torch.cuda.is_available() else torchvision.models.resnet18()

In [8]:
# check if the parameters are in cuda
next(res18.parameters()).is_cuda

True

In [9]:
res18.load_state_dict(torch.load(f'saved_models/ResNet18-CIFAR10-Epoch-100.pth'))

  res18.load_state_dict(torch.load(f'saved_models/ResNet18-CIFAR10-Epoch-100.pth'))


<All keys matched successfully>

### Create ANP wrapper

In [10]:
from ANP import *

In [11]:
anp_system = ANPWrapper(res18, 0.2, 0.2, 0.4)

In [12]:
for epoch in range(50):
    i = 0
    for inputs, label in train_loader:
        inputs, label = inputs.to(device), label.to(device)
        # perform perturb step
        weight_masks_loss = anp_system.perturb_step(inputs, label)
        print(f'epoch: {epoch} | iteration: {i} | weight_mask_loss: {weight_masks_loss}')
        i += 1

epoch: 0 | iteration: 0 | weight_mask_loss: 7.308289937674999
epoch: 0 | iteration: 1 | weight_mask_loss: 7.488302744925022
epoch: 0 | iteration: 2 | weight_mask_loss: 7.485943049192429
epoch: 0 | iteration: 3 | weight_mask_loss: 7.627250820398331
epoch: 0 | iteration: 4 | weight_mask_loss: 7.4386937618255615
epoch: 0 | iteration: 5 | weight_mask_loss: 6.833263494074345
epoch: 0 | iteration: 6 | weight_mask_loss: 7.152169294655323
epoch: 0 | iteration: 7 | weight_mask_loss: 7.416731666773558
epoch: 0 | iteration: 8 | weight_mask_loss: 5.642515629529953
epoch: 0 | iteration: 9 | weight_mask_loss: 6.890158586204052
epoch: 0 | iteration: 10 | weight_mask_loss: 7.149699196219444
epoch: 0 | iteration: 11 | weight_mask_loss: 5.4636852368712425
epoch: 0 | iteration: 12 | weight_mask_loss: 8.035139553248882
epoch: 0 | iteration: 13 | weight_mask_loss: 6.65930313616991
epoch: 0 | iteration: 14 | weight_mask_loss: 7.840177871286869
epoch: 0 | iteration: 15 | weight_mask_loss: 7.205662623047829
e