In [None]:
%reload_ext autoreload
%autoreload 2
import os, sys, argparse, time
sys.path.append('..')

import pickle
import torch
import seaborn as sns
import torch.nn as nnll
import numpy as np
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.datasets as dset
import torch.nn.functional as F
from tqdm import tqdm_notebook, tqdm
from skimage.filters import gaussian as gblur
from PIL import Image as PILImage
import seaborn as sns

from CIFAR.models.wrn import WideResNet 
from utils.display_results import show_performance, get_measures, print_measures, print_measures_with_std
import utils.svhn_loader as svhn
import utils.lsun_loader as lsun_loader

In [None]:
 args = {
        # Evaluation configuration, read by key throughout the notebook.
        'test_bs': 200,          # batch size for every evaluation DataLoader
        'num_to_avg': 1,         # 'Average measures across num_to_avg runs.'
        'validate': '',
        'use_xent': '',          # falsy -> negative max-softmax score; truthy -> mean-logit minus logsumexp score
        'method_name': 'cifar10_wrn_OECC_tune',  # 'Method name.' Also selects the dataset and results sub-directory.
        'layers': 40,            # passed to WideResNet
        # Key must be 'widen_factor' (underscore): that is the spelling the
        # model-construction code looks up; the hyphenated form raised KeyError.
        'widen_factor': 2,
        'droprate': 0.3,         # passed to WideResNet as dropRate
        'load': './results',     # directory searched for checkpoints ('' skips restoring)
        'save': './results',     # directory that receives the *_test.txt results file
        'ngpu': 2,               # number of GPUs handed to DataParallel
        'prefetch': 4            # used as num_workers for the DataLoaders
        }

In [None]:
# Directory that holds (or receives, via download) the CIFAR archives.
root_dir = './'

# Seed both RNGs used by this notebook.
np.random.seed(1)
torch.manual_seed(1)

# Per-channel CIFAR-10 mean / standard deviation, rescaled from 0-255 to 0-1.
mean = [channel / 255 for channel in (125.3, 123.0, 113.9)]
std = [channel / 255 for channel in (63.0, 62.1, 66.7)]

test_transform = trn.Compose([
    trn.ToTensor(),
    trn.Normalize(mean, std),
])

# A method name containing 'cifar10_' evaluates on CIFAR-10; anything else
# falls back to CIFAR-100.
num_classes = 10 if 'cifar10_' in args['method_name'] else 100
test_data = (dset.CIFAR10 if num_classes == 10 else dset.CIFAR100)(
    root_dir,
    train=False,
    download=True,
    transform=test_transform,
)

test_loader = torch.utils.data.DataLoader(
    test_data,
    shuffle=False,
    batch_size=args['test_bs'],
    pin_memory=True,
    num_workers=args['prefetch'],
)

In [None]:
# Create model
net = WideResNet(args['layers'], num_classes, args['widen_factor'], dropRate=args['droprate'])

device = torch.device('cuda:0')
net = torch.nn.DataParallel(net, device_ids=list(range(args['ngpu']))).cuda()
cudnn.benchmark = True  # fire on all cylinders

# Sentinel: stays 0 unless a checkpoint is restored below. It must start at 0,
# otherwise the "could not resume" check can never trigger and a missing
# checkpoint would silently evaluate randomly initialised weights.
start_epoch = 0

# Results / checkpoint sub-directory is derived from the method name.
if 'baseline' in args['method_name']:
    subdir = 'baseline'
elif 'OECC' in args['method_name']:
    subdir = 'OECC_tune'
else:
    raise ValueError(
        "method_name must contain 'baseline' or 'OECC', got {!r}".format(args['method_name']))

# Results file; 'w+' truncates any previous results for this method.
f = open(
    os.path.join(args['save'], subdir, args['method_name'] + '_test.txt'), 'w+')

# Restore model: pick the checkpoint with the highest epoch number.
if args['load'] != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(
            args['load'], subdir,
            args['method_name'] + '_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch: ', i)
            f.write('Model restored! Epoch: {}'.format(i))
            start_epoch = i + 1
            break
    if start_epoch == 0:
        # Raised explicitly (rather than `assert False`) so it survives `python -O`.
        raise AssertionError("could not resume")

net.eval()
    

    


# /////////////// Detection Prelims ///////////////

# The number of OOD examples scored per dataset is one fifth of the test set size.
num_test_examples = len(test_data)
ood_num_examples = num_test_examples // 5
# Share of OOD examples in the combined (OOD + test) pool.
expected_ap = ood_num_examples / (ood_num_examples + num_test_examples)

def concat(x):
    """Join a sequence of arrays along their first axis."""
    return np.concatenate(x, axis=0)


def to_np(x):
    """Move tensor ``x`` to the CPU and return its values as a NumPy array."""
    return x.data.cpu().numpy()


def get_ood_scores(loader, in_dist=False):
    """Run ``net`` over ``loader`` and collect per-example detection scores.

    The score is the negative maximum softmax probability, or, when
    ``args['use_xent']`` is truthy, ``mean(logits) - logsumexp(logits)``.

    Args:
        loader: iterable of ``(data, target)`` batches.
        in_dist: ``True`` for the in-distribution test set (every batch is
            consumed and right/wrong scores are split out). With the default
            ``False`` only the first ``ood_num_examples // args['test_bs']``
            batches are scored.

    Returns:
        If ``in_dist`` is truthy: ``(max_softmax, score, right_score, wrong_score)``
        where the last two are the scores of correctly / incorrectly
        classified examples.
        Otherwise: ``(max_softmax, score[:ood_num_examples])``.
    """
    _score = []
    out_conf_score = []
    in_conf_score = []
    _right_score = []
    _wrong_score = []

    # Loop-invariant cap on the number of batches scored for OOD loaders.
    max_ood_batches = ood_num_examples // args['test_bs']

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(loader):
            if batch_idx >= max_ood_batches and in_dist is False:
                break

            data = data.cuda()

            output = net(data)
            smax = to_np(F.softmax(output, dim=1))
            max_conf = np.max(smax, axis=1)

            # Compute the batch score once and reuse it for the right/wrong split.
            if args['use_xent']:
                batch_score = to_np(output.mean(1) - torch.logsumexp(output, dim=1))
            else:
                batch_score = -max_conf
            _score.append(batch_score)

            if in_dist:
                in_conf_score.append(max_conf)
                preds = np.argmax(smax, axis=1)
                targets = target.numpy().squeeze()
                right_indices = preds == targets
                wrong_indices = np.invert(right_indices)

                _right_score.append(batch_score[right_indices])
                _wrong_score.append(batch_score[wrong_indices])
            else:
                # Always record the confidence for OOD batches. Previously this
                # only happened when use_xent was falsy, so the use_xent path
                # crashed on concatenating an empty list.
                out_conf_score.append(max_conf)

    if in_dist:
        return concat(in_conf_score).copy(), concat(_score).copy(), concat(_right_score).copy(), concat(_wrong_score).copy()
    else:
        return concat(out_conf_score).copy(), concat(_score)[:ood_num_examples].copy()


# Score the in-distribution test set once; `in_score` is what every OOD
# dataset is compared against in get_and_print_results.
in_conf_score, in_score, right_score, wrong_score = get_ood_scores(test_loader, in_dist=True)

num_right, num_wrong = len(right_score), len(wrong_score)
error_rate = 100 * num_wrong / (num_wrong + num_right)
# The error rate is printed only; it is not written to the results file `f`.
print('Error Rate {:.2f}'.format(error_rate))

# /////////////// End Detection Prelims ///////////////

if num_classes == 10:
    typical_data_msg = '\nUsing CIFAR-10 as typical data'
else:
    typical_data_msg = '\nUsing CIFAR-100 as typical data'
print(typical_data_msg)
f.write(typical_data_msg)

# One entry per OOD dataset, appended by get_and_print_results.
auroc_list, aupr_list, fpr_list = [], [], []


def get_and_print_results(ood_loader, num_to_avg=args['num_to_avg']):
    """Score ``ood_loader`` against the in-distribution scores and report the measures.

    The loader is scored ``num_to_avg`` times. The mean AUROC, AUPR and FPR
    are appended to the module-level ``auroc_list`` / ``aupr_list`` /
    ``fpr_list`` and printed (with standard deviation when ``num_to_avg >= 5``).

    Returns:
        The confidence array returned by the last ``get_ood_scores`` call.
    """
    per_run_measures = []
    for _ in range(num_to_avg):
        out_conf_score, out_score = get_ood_scores(ood_loader)
        per_run_measures.append(get_measures(out_score, in_score))

    aurocs = [run[0] for run in per_run_measures]
    auprs = [run[1] for run in per_run_measures]
    fprs = [run[2] for run in per_run_measures]

    auroc, aupr, fpr = np.mean(aurocs), np.mean(auprs), np.mean(fprs)
    auroc_list.append(auroc)
    aupr_list.append(aupr)
    fpr_list.append(fpr)

    if num_to_avg < 5:
        print_measures(auroc, aupr, fpr, f, args['method_name'])
    else:
        print_measures_with_std(aurocs, auprs, fprs, f, args['method_name'])
    return out_conf_score

# /////////////// Uniform Noise ///////////////

# Enough noise images for every averaging run.
num_noise_images = ood_num_examples * args['num_to_avg']
dummy_targets = torch.ones(num_noise_images)
# 3x32x32 images with every value drawn from Uniform[-1, 1].
uniform_noise = np.random.uniform(
    low=-1.0, high=1.0, size=(num_noise_images, 3, 32, 32)).astype(np.float32)
ood_data = torch.utils.data.TensorDataset(torch.from_numpy(uniform_noise), dummy_targets)
ood_loader = torch.utils.data.DataLoader(ood_data, shuffle=True, batch_size=args['test_bs'])

print('\n\nUniform[-1,1] Noise Detection')
get_and_print_results(ood_loader)


# /////////////// Arithmetic Mean of Images ///////////////

# Source images for the averaged pairs come from the CIFAR dataset that is NOT
# the in-distribution one. The test uses the same 'cifar10_' predicate as the
# test-set selection and the geometric-mean section; the previous 'cifar100'
# check picked the in-distribution dataset itself.
if 'cifar10_' in args['method_name']:
    ood_data = dset.CIFAR100('.', download=True, train=False, transform=test_transform)
else:
    ood_data = dset.CIFAR10('.', download=True, train=False, transform=test_transform)


class AvgOfPair(torch.utils.data.Dataset):
    """Dataset whose item ``i`` is the arithmetic mean of image ``i`` and a random other image."""

    def __init__(self, dataset):
        self.dataset = dataset
        # A shuffled permutation of the indices, stored on the instance.
        permutation = np.arange(len(dataset))
        np.random.shuffle(permutation)
        self.shuffle_indices = permutation

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        # Draw a partner index until it differs from ``i``.
        partner_idx = i
        while partner_idx == i:
            partner_idx = np.random.choice(len(self.dataset))

        first_image = self.dataset[i][0]
        second_image = self.dataset[partner_idx][0]
        # The label is a constant placeholder.
        return first_image / 2. + second_image / 2., 0


# Wrap the OOD source images so each sample is the average of a random pair.
avg_pair_data = AvgOfPair(ood_data)
ood_loader = torch.utils.data.DataLoader(
    avg_pair_data,
    shuffle=True,
    batch_size=args['test_bs'],
    pin_memory=True,
    num_workers=args['prefetch'],
)

print('\n\nArithmetic Mean of Random Image Pair Detection')
get_and_print_results(ood_loader)


# /////////////// Geometric Mean of Images ///////////////

# Loaded with ToTensor only (no normalization): GeomMeanOfPair normalizes
# after taking the square root of the product.
geometric_source_cls = dset.CIFAR100 if 'cifar10_' in args['method_name'] else dset.CIFAR10
ood_data = geometric_source_cls('.', download=True, train=False, transform=trn.ToTensor())


class GeomMeanOfPair(torch.utils.data.Dataset):
    """Dataset whose item ``i`` is the normalized geometric mean of image ``i`` and a random other image."""

    def __init__(self, dataset):
        self.dataset = dataset
        # A shuffled permutation of the indices, stored on the instance.
        permutation = np.arange(len(dataset))
        np.random.shuffle(permutation)
        self.shuffle_indices = permutation

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        # Draw a partner index until it differs from ``i``.
        partner_idx = i
        while partner_idx == i:
            partner_idx = np.random.choice(len(self.dataset))

        # Element-wise geometric mean of the pair, normalized afterwards.
        geometric_mean = torch.sqrt(self.dataset[i][0] * self.dataset[partner_idx][0])
        # The label is a constant placeholder.
        return trn.Normalize(mean, std)(geometric_mean), 0


# Wrap the un-normalized OOD images so each sample is a geometric mean of a random pair.
geom_pair_data = GeomMeanOfPair(ood_data)
ood_loader = torch.utils.data.DataLoader(
    geom_pair_data,
    shuffle=True,
    batch_size=args['test_bs'],
    pin_memory=True,
    num_workers=args['prefetch'],
)

print('\n\nGeometric Mean of Random Image Pair Detection')
get_and_print_results(ood_loader)

# /////////////// Jigsaw Images ///////////////

# Plain loader over the raw OOD dataset; the transform of its dataset is
# reassigned before each of the corruption evaluations that follow.
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    shuffle=True,
    batch_size=args['test_bs'],
    pin_memory=True,
    num_workers=args['prefetch'],
)

def jigsaw(x):
    """Rearrange blocks of an image tensor indexed as ``x[channel, row, col]``.

    For a (C, 32, 32) input the output has the same shape and is assembled as:
      * top-left:     rows 0-15 / cols 0-15 of ``x`` with its two 8-row bands swapped
      * top-right:    rows 16+ / cols 0-15 of ``x``
      * bottom-left:  rows 16+ / cols 16+ of ``x``
      * bottom-right: rows 0-15 / cols 16+ of ``x`` with its two 8-col bands swapped
    """
    top_left = torch.cat((x[:, 8:16, :16], x[:, :8, :16]), 1)
    top_half = torch.cat((top_left, x[:, 16:, :16]), 2)

    bottom_right = torch.cat((x[:, :16, 24:], x[:, :16, 16:24]), 2)
    bottom_half = torch.cat((x[:, 16:, 16:], bottom_right), 2)

    return torch.cat((top_half, bottom_half), 1)

# Pipeline: tensor conversion -> jigsaw rearrangement -> normalization.
jigsaw_transform = trn.Compose([
    trn.ToTensor(),
    jigsaw,
    trn.Normalize(mean, std),
])
ood_loader.dataset.transform = jigsaw_transform

print('\n\nJigsawed Images Detection')
get_and_print_results(ood_loader)

# /////////////// Speckled Images ///////////////

def speckle(x):
    """Add multiplicative standard-normal noise to ``x`` and clamp the result to [0, 1]."""
    noise = torch.randn_like(x)
    return torch.clamp(x + x * noise, 0, 1)
# Pipeline: tensor conversion -> speckle noise -> normalization.
speckle_transform = trn.Compose([
    trn.ToTensor(),
    speckle,
    trn.Normalize(mean, std),
])
ood_loader.dataset.transform = speckle_transform

print('\n\nSpeckle Noised Images Detection')
get_and_print_results(ood_loader)

# /////////////// Pixelated Images ///////////////

def pixelate(x):
    """Shrink ``x`` to 6x6 and blow it back up to 32x32, both with BOX resampling.

    ``x`` must provide a PIL-style ``resize(size, resample)`` method.
    """
    small_side = int(32 * 0.2)
    shrunk = x.resize((small_side, small_side), PILImage.BOX)
    return shrunk.resize((32, 32), PILImage.BOX)
# Pixelation runs first, before the image is converted to a tensor,
# because it relies on the image's own resize method.
pixelate_transform = trn.Compose([
    pixelate,
    trn.ToTensor(),
    trn.Normalize(mean, std),
])
ood_loader.dataset.transform = pixelate_transform

print('\n\nPixelate Detection')
get_and_print_results(ood_loader)

# /////////////// RGB Ghosted/Shifted Images ///////////////

def rgb_shift(x):
    """Permute the channels of ``x`` and mirror one of them.

    Output channel 0 is input channel 1 reversed along the last axis
    (32 entries are indexed), output channel 1 is input channel 2 onward,
    and the final channel is input channel 0.
    """
    reversed_cols = torch.LongTensor(list(range(32 - 1, -1, -1)))
    mirrored_channel = x[1:2].index_select(2, reversed_cols)
    return torch.cat((mirrored_channel, x[2:, :, :], x[0:1, :, :]), 0)
# Pipeline: tensor conversion -> channel shift/mirror -> normalization.
rgb_shift_transform = trn.Compose([
    trn.ToTensor(),
    rgb_shift,
    trn.Normalize(mean, std),
])
ood_loader.dataset.transform = rgb_shift_transform

print('\n\nRGB Ghosted/Shifted Image Detection')
get_and_print_results(ood_loader)

# /////////////// Inverted Images ///////////////

# not done on all channels to make image ood with higher probability
def invert(x):
    """Keep channel 0 of ``x`` unchanged and replace every later channel ``c`` with ``1 - c``."""
    kept = x[0:1, :, :]
    second = x[1:2, :, ]
    remaining = x[2:, :, :]
    return torch.cat((kept, 1 - second, 1 - remaining), 0)
# Pipeline: tensor conversion -> partial channel inversion -> normalization.
invert_transform = trn.Compose([
    trn.ToTensor(),
    invert,
    trn.Normalize(mean, std),
])
ood_loader.dataset.transform = invert_transform

print('\n\nInverted Image Detection')
get_and_print_results(ood_loader)

# /////////////// Mean Results ///////////////

# Average the per-dataset measures collected by get_and_print_results.
print('\n\nMean Validation Results')
print_measures(np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr_list), f, method_name=args['method_name'])

# This is the last write to the results file: close it so buffered output is
# flushed to disk and the handle is released while the kernel stays alive.
f.close()