# CW + Diversity Regularization on CIFAR10

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import torchvision
import torchvision.transforms as transforms

import pickle
import datetime
import glob
import os
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
pd.set_option('display.max_rows', None)
pd.set_option('precision', 10)

# custom code imports
from neuron_coverage import *

import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

device = torch.device("cpu")
if torch.cuda.is_available():
    print('CUDA is available!')
    device = torch.device("cuda")
else:
    print('CUDA is not available...')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
CUDA is available!


# Loading a Pretrained ResNet56
https://github.com/akamaster/pytorch_resnet_cifar10

In [2]:
from resnet import *

models_dir = 'pretrained_models/cifar10/' 
model = resnet56().cuda()
state_dict = torch.load(models_dir + 'resnet56.th', map_location='cuda')['state_dict'] # best_prec1, state_dict

new_state_dict = {}

for k, v in state_dict.items():
    if 'module' in k:
        k = k.replace('module.', '')
    new_state_dict[k]=v
    
model.load_state_dict(new_state_dict)

print('model loaded!')
# print(model)

model loaded!


# Load Data

In [3]:
data_dir = 'C:\data\CIFAR10'

classes = ['plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck']

if not os.path.exists(data_dir):
    os.makedirs(data_dir)

batch_size = 100 # determines how many of each class we want

def get_same_index(targets, label):
    '''
    Returns indices corresponding to the target label
    which the dataloader uses to serve downstream.
    '''
    label_indices = []
    for i in range(len(targets)):
        if targets[i] == label:
            label_indices.append(i)
    return label_indices

dataset = torchvision.datasets.CIFAR10(root=data_dir, 
                                       train=False, 
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor()
                                       ]))

Files already downloaded and verified


## Generate per class batches for NC evaluation

In [4]:
# data = []
# labels = []

# for i in range(len(classes)):
    
#     target_indices = get_same_index(dataset.targets, i)
    
#     test_loader = torch.utils.data.DataLoader(dataset,
#         batch_size=batch_size, 
#         sampler=torch.utils.data.sampler.SubsetRandomSampler(target_indices),
#         shuffle=False,
#         num_workers=2, 
#         pin_memory=True)
    
#     inputs, targets = next(iter(test_loader))
    
#     data.append(inputs)
#     labels.append(targets)
    
# # torch.Size([10, batch_size, 3, 32, 32])
# inputs = torch.stack(data).to(device)

# # torch.Size([new_batch_size, 3, 32, 32])
# new_batch_size = len(classes) * batch_size
# all_inputs = inputs.view(new_batch_size,3,32,32)

# # torch.Size([10, batch_size])
# targets = torch.stack(labels).to(device)

# # torch.Size([new_batch_size])
# all_targets = targets.view(-1)

# # individual class performance
# for i in range(len(classes)):
#     # confirm that loading the weights actually worked
#     orig_output = model(inputs[i])
#     orig_pred = torch.argmax(orig_output, dim=1)
#     orig_correct = orig_pred.eq(targets[i].data).sum()
#     orig_acc = 100. * orig_correct / len(targets[i])
#     print(classes[i] + '\t accuracy: {}/{} ({:.0f}%)'.format(orig_correct, len(targets[i]), orig_acc))

# # all class performance
# orig_output = model(all_inputs)
# orig_pred = torch.argmax(orig_output, dim=1)
# orig_correct = orig_pred.eq(all_labels.data).sum()
# orig_acc = 100. * orig_correct / len(all_labels)
# print('total accuracy: {}/{} ({:.0f}%)'.format(orig_correct, len(all_labels), orig_acc))

In [5]:
# # confirm that all the images are still intact after reshaping the tensor view
# orig_inputs = all_inputs.clone().cpu().detach().numpy()
# for img in orig_inputs:
#     plt.imshow(np.transpose(np.squeeze(img), (1, 2, 0))) 
#     plt.show()

## Evaluate per class and overall accuracy

In [31]:
def get_acc(model, device, test_loader, sampler=False, class_idx=None, classes=None):
    # model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            
#             # peek at images to make sure they're of the right class
#             x = 0            
#             if x % 100 == 0:
#                 np_data = data.clone().cpu().detach().numpy()
#                 plt.imshow(np.transpose(np.squeeze(np_data[0]), (1, 2, 0))) 
#                 plt.show()               
#             x += 1
            
    if sampler:
        if class_idx is None:
            raise Exception('you must provide an integer class index if sampler=True')
        if classes is None:
            raise Exception('you must provide an iterable of class indices if sampler=True')
        values, counts = np.unique(test_loader.dataset.targets, return_counts=True) 
        divisor = counts[class_idx]
        acc = 100. * correct / divisor
        print(classes[class_idx] + '\t accuracy: {}/{} ({:.2f}%)'.format(correct, divisor, acc))
    else:
        divisor = len(test_loader.dataset)
        acc = 100. * correct / divisor
        print('accuracy: {}/{} ({:.2f}%)'.format(correct, divisor, acc))

In [35]:
# evaluate on total test set
test_loader = DataLoader(dataset,
    batch_size=batch_size, 
    shuffle=False,
    pin_memory=True)

get_acc(model, device, test_loader)

accuracy: 9107/10000 (91.07%)


In [36]:
# evaluate on each class separately
for i in range(len(classes)):
    
    target_indices = get_same_index(dataset.targets, i)
    
    test_loader = DataLoader(dataset,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(target_indices),
        shuffle=False,
        pin_memory=True)
    
    get_acc(model, device, test_loader, sampler=True, class_idx=i, classes=classes)

plane	 accuracy: 215/1000 (21.50%)
car	 accuracy: 304/1000 (30.40%)
bird	 accuracy: 243/1000 (24.30%)
cat	 accuracy: 244/1000 (24.40%)
deer	 accuracy: 238/1000 (23.80%)
dog	 accuracy: 259/1000 (25.90%)
frog	 accuracy: 188/1000 (18.80%)
horse	 accuracy: 281/1000 (28.10%)
ship	 accuracy: 210/1000 (21.00%)
truck	 accuracy: 262/1000 (26.20%)


# Evaluating Neuron Coverage Per Class

In [16]:
nc_thresholds = np.arange(0,1,0.1)

neuron_coverages_by_class = {}

for i in range(len(classes)):
    
    threshold_results = {}
    
    for t in nc_thresholds:
        
        t = round(t,2)

        covered_neurons, total_neurons, neuron_coverage = eval_nc(model, inputs[i], t)
        print(classes[i] + ' neuron_coverage_' + str(t), neuron_coverage)

        threshold_results[str(t)] = neuron_coverage
    
    neuron_coverages_by_class[classes[i]] = threshold_results

HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

plane neuron_coverage_0.0 0.8185937764089466


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

plane neuron_coverage_0.1 0.3672275535690811


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

plane neuron_coverage_0.2 0.20617100790625176


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

plane neuron_coverage_0.3 0.11765103570020094


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

plane neuron_coverage_0.4 0.06495521042648689


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

plane neuron_coverage_0.5 0.03480816541155703


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

plane neuron_coverage_0.6 0.018911153261094105


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

plane neuron_coverage_0.7 0.009525061503502413


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

plane neuron_coverage_0.8 0.00339912486619467


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

plane neuron_coverage_0.9 0.0008751338053296776


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

car neuron_coverage_0.0 0.8338410110987999


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

car neuron_coverage_0.1 0.39302522113091326


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

car neuron_coverage_0.2 0.22464271629514168


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

car neuron_coverage_0.3 0.12897331405284607


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

car neuron_coverage_0.4 0.07169524310315686


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

car neuron_coverage_0.5 0.03917632256004808


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

car neuron_coverage_0.6 0.0200792503145599


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

car neuron_coverage_0.7 0.0088133110480948


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

car neuron_coverage_0.8 0.002937144359518489


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

car neuron_coverage_0.9 0.0006666791864635955


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

bird neuron_coverage_0.0 0.8354973802324926


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

bird neuron_coverage_0.1 0.3964224680275686


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

bird neuron_coverage_0.2 0.21960224605156906


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

bird neuron_coverage_0.3 0.12486807263986179


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

bird neuron_coverage_0.4 0.06970271742192341


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

bird neuron_coverage_0.5 0.039031718905519354


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

bird neuron_coverage_0.6 0.021420120565644424


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

bird neuron_coverage_0.7 0.009517549625345076


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

bird neuron_coverage_0.8 0.0030216529887885217


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

bird neuron_coverage_0.9 0.0007380420289582903


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

cat neuron_coverage_0.0 0.8376664350504235


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

cat neuron_coverage_0.1 0.40234558395462827


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

cat neuron_coverage_0.2 0.22606997314503557


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

cat neuron_coverage_0.3 0.12997990572592913


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

cat neuron_coverage_0.4 0.07498920167514883


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

cat neuron_coverage_0.5 0.042368870776916


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

cat neuron_coverage_0.6 0.02293939792296569


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

cat neuron_coverage_0.7 0.011406786981915154


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

cat neuron_coverage_0.8 0.003891152885500197


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

cat neuron_coverage_0.9 0.0008957914702623523


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

deer neuron_coverage_0.0 0.8410768277338542


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

deer neuron_coverage_0.1 0.42013746737027924


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

deer neuron_coverage_0.2 0.24404026366692333


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

deer neuron_coverage_0.3 0.14266558996413078


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

deer neuron_coverage_0.4 0.08242596105091175


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

deer neuron_coverage_0.5 0.04726285939642059


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

deer neuron_coverage_0.6 0.025470900861988017


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

deer neuron_coverage_0.7 0.01129410880955511


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

deer neuron_coverage_0.8 0.0036169693327574226


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

deer neuron_coverage_0.9 0.0007962590846776465


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

dog neuron_coverage_0.0 0.8376964825630528


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

dog neuron_coverage_0.1 0.3934947135157468


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

dog neuron_coverage_0.2 0.21995154838588518


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

dog neuron_coverage_0.3 0.12715356156923135


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

dog neuron_coverage_0.4 0.07310184228811809


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

dog neuron_coverage_0.5 0.03904298672275536


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

dog neuron_coverage_0.6 0.020606959755112772


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

dog neuron_coverage_0.7 0.009649007493098462


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

dog neuron_coverage_0.8 0.0031155514657552254


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

dog neuron_coverage_0.9 0.0007117504554076132


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

frog neuron_coverage_0.0 0.8356720313996507


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

frog neuron_coverage_0.1 0.40694285338691805


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

frog neuron_coverage_0.2 0.23220717759957935


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

frog neuron_coverage_0.3 0.13496779282240043


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

frog neuron_coverage_0.4 0.07763150481699187


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

frog neuron_coverage_0.5 0.042182951792521926


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

frog neuron_coverage_0.6 0.021545944524779807


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

frog neuron_coverage_0.7 0.009756051756840503


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

frog neuron_coverage_0.8 0.0032507652725872788


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

frog neuron_coverage_0.9 0.0007117504554076132


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

horse neuron_coverage_0.0 0.8336964074442712


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

horse neuron_coverage_0.1 0.4045578320719638


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

horse neuron_coverage_0.2 0.229664406843321


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

horse neuron_coverage_0.3 0.13203816034103927


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

horse neuron_coverage_0.4 0.07367837893669364


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

horse neuron_coverage_0.5 0.03994816804071438


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

horse neuron_coverage_0.6 0.020640763206820786


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

horse neuron_coverage_0.7 0.009427407087457042


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

horse neuron_coverage_0.8 0.003100527709440553


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

horse neuron_coverage_0.9 0.0007624556329696332


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

ship neuron_coverage_0.0 0.826280305733441


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

ship neuron_coverage_0.1 0.37701928674716895


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

ship neuron_coverage_0.2 0.2104828259685628


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

ship neuron_coverage_0.3 0.11996093823358185


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

ship neuron_coverage_0.4 0.06637871133730211


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

ship neuron_coverage_0.5 0.03523446449698586


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

ship neuron_coverage_0.6 0.01820315874476516


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

ship neuron_coverage_0.7 0.009232098255366299


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

ship neuron_coverage_0.8 0.0037371593832748033


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

ship neuron_coverage_0.9 0.0008488422317790005


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

truck neuron_coverage_0.0 0.830496347349246


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

truck neuron_coverage_0.1 0.3855490243948243


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

truck neuron_coverage_0.2 0.21524911265939267


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

truck neuron_coverage_0.3 0.12043043061841537


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

truck neuron_coverage_0.4 0.06625852128678472


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

truck neuron_coverage_0.5 0.03556686510544799


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

truck neuron_coverage_0.6 0.01858063062217131


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

truck neuron_coverage_0.7 0.008452740896542658


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

truck neuron_coverage_0.8 0.002916486694585814


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))

truck neuron_coverage_0.9 0.0007474318766549607


In [17]:
save_file_path = "assets/neuron_coverages_cifar10_10.pkl"
pickle.dump(neuron_coverages_by_class, open(save_file_path, "wb")) 

In [21]:
df = pd.DataFrame.from_dict(neuron_coverages_by_class)
df.to_clipboard(excel=True)

## Create Extreme Case Inputs

In [41]:
input_shape = (batch_size, 3, 32, 32)
all_blacks = torch.zeros(input_shape).to(device)
all_whites = torch.ones(input_shape).to(device)

In [42]:
covered_neurons, total_neurons, neuron_coverage = eval_nc(model, all_blacks, 0.0)
print('all_blacks neuron_coverage', neuron_coverage)

covered_neurons, total_neurons, neuron_coverage = eval_nc(model, all_whites, 0.0)
print('all_whites neuron_coverage', neuron_coverage)

HBox(children=(IntProgress(value=0, max=56), HTML(value='')))


all_blacks neuron_coverage 0.42952919303648895


HBox(children=(IntProgress(value=0, max=56), HTML(value='')))


all_whites neuron_coverage 0.439799808447107
