In [1]:
%reload_ext autoreload
%autoreload 2
import os, sys, argparse, time
sys.path.append('..')

import pickle
import torch
import seaborn as sns
import torch.nn as nn
import numpy as np
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.datasets as dset
import torch.nn.functional as F
from tqdm import tqdm_notebook, tqdm
from skimage.filters import gaussian as gblur
from PIL import Image as PILImage
import seaborn as sns

from CIFAR.models.wrn import WideResNet 
from utils.display_results import show_performance, get_measures, print_measures, print_measures_with_std
import utils.svhn_loader as svhn
import utils.lsun_loader as lsun_loader

In [2]:
cd ..

/home/outlier-detection


In [None]:
 # Evaluation settings read by the cells below.
 args = {
        # Batch size of every DataLoader; get_ood_scores also uses it to cap
        # how many OOD batches are scored.
        'test_bs': 200,
        # Number of OOD scoring runs averaged by get_and_print_results; also
        # scales the size of the synthetic noise/blob datasets.
        'num_to_avg': 1,
        # Not read anywhere in the visible cells.
        'validate': '', 
        # Truthy: score with (mean logit - logsumexp); falsy: score with the
        # negative maximum softmax probability (see get_ood_scores).
        'use_xent': '', 
        # Drives several choices by substring: 'cifar10_' selects CIFAR-10 as
        # the in-distribution set, 'allconv' selects the architecture, and
        # 'baseline' / 'OECC' select the results sub-directory. It is also the
        # prefix of the checkpoint and log file names.
        'method_name': 'cifar10_wrn_OECC_tune',
        # Passed to the WideResNet constructor (depth, widen factor, dropRate).
        'layers': 40,
        'widen-factor': 2,
        'droprate': 0.3,
        # Directory searched for '<method_name>_epoch_<i>.pt' checkpoints;
        # an empty string skips restoring.
        'load': './CIFAR/results',
        # Directory in which '<method_name>_test.txt' is written.
        'save': './CIFAR/results',
        # > 1 wraps the network in DataParallel over that many device ids;
        # > 0 moves the network to CUDA device 0.
        'ngpu': 2,
        # Used as num_workers for the DataLoaders.
        'prefetch': 4
        }

In [None]:
root_dir = 'CIFAR'

# Seed the torch and NumPy random number generators.
torch.manual_seed(1)
np.random.seed(1)

# Per-channel mean and standard deviation of CIFAR-10 images, scaled by 1/255.
mean = [channel / 255 for channel in (125.3, 123.0, 113.9)]
std = [channel / 255 for channel in (63.0, 62.1, 66.7)]

test_transform = trn.Compose([trn.ToTensor(), trn.Normalize(mean, std)])

# The method name decides which CIFAR variant is the in-distribution test set.
if 'cifar10_' in args['method_name']:
    in_dist_dataset, num_classes = dset.CIFAR10, 10
else:
    in_dist_dataset, num_classes = dset.CIFAR100, 100

test_data = in_dist_dataset(root_dir,
                            train=False,
                            download=True,
                            transform=test_transform)

# Fixed order (no shuffling) for the in-distribution test set.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=args['test_bs'],
    shuffle=False,
    num_workers=args['prefetch'],
    pin_memory=True,
)

In [None]:
# Create model
if 'allconv' in args['method_name']:
    # NOTE(review): AllConvNet is not imported in the visible import cell, so
    # this branch raises NameError unless it is defined elsewhere - confirm.
    net = AllConvNet(num_classes)
else:
    net = WideResNet(args['layers'],
                     num_classes,
                     args['widen-factor'],
                     dropRate=args['droprate'])

if args['ngpu'] > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args['ngpu'])))

start_epoch = 0

# Checkpoints and results live in a sub-directory chosen from the method name.
if 'baseline' in args['method_name']:
    subdir = 'baseline'
elif 'OECC' in args['method_name']:
    subdir = 'OECC_tune'
else:
    # Previously `subdir` was simply left undefined here, which surfaced later
    # as an unrelated-looking NameError.
    raise ValueError(
        "method_name {!r} must contain 'baseline' or 'OECC'".format(
            args['method_name']))

# Restore model: pick the checkpoint with the highest epoch number (< 1000).
# This happens BEFORE the results file is opened, so a failed restore no longer
# truncates an existing results file or leaves an open handle behind.
restored_epoch = None
if args['load'] != '':
    checkpoint_dir = os.path.join(args['load'], subdir)
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(
            checkpoint_dir, args['method_name'] + '_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            # torch.load unpickles the file - only load trusted checkpoints.
            # map_location='cpu' lets checkpoints saved on a GPU load on any
            # machine; load_state_dict copies the values into the parameters.
            net.load_state_dict(torch.load(model_name, map_location='cpu'))
            print('Model restored! Epoch: ', i)
            restored_epoch = i
            start_epoch = i + 1
            break
    if restored_epoch is None:
        # A real exception instead of `assert False`, which is stripped
        # when Python runs with -O.
        raise FileNotFoundError(
            'could not resume: no {}_epoch_<N>.pt checkpoint in {}'.format(
                args['method_name'], checkpoint_dir))

# Results log; it stays open for the rest of this cell and is closed by the
# f.close() call after the mean results are written.
results_dir = os.path.join(args['save'], subdir)
os.makedirs(results_dir, exist_ok=True)
f = open(os.path.join(results_dir, args['method_name'] + '_test.txt'), 'w+')
if restored_epoch is not None:
    f.write('Model restored! Epoch: {}'.format(restored_epoch))

net.eval()


if args['ngpu'] > 0:
    torch.cuda.set_device(0)
    device = torch.device('cuda:0')
    net = net.cuda()
    # torch.cuda.manual_seed(1)

cudnn.benchmark = True  # fire on all cylinders

# /////////////// Detection Prelims ///////////////

# One OOD example is scored for every five in-distribution test examples.
ood_num_examples = len(test_data) // 5
expected_ap = ood_num_examples / (ood_num_examples + len(test_data))


def concat(x):
    """Concatenate a sequence of arrays along axis 0."""
    return np.concatenate(x, axis=0)


def to_np(x):
    """Return the data of tensor `x` as a NumPy array in host memory."""
    return x.data.cpu().numpy()


def get_ood_scores(loader, in_dist=False):
    """Run `net` over `loader` and collect per-example confidence and OOD scores.

    The OOD score is `mean logit - logsumexp(logits)` when `args['use_xent']`
    is truthy, otherwise the negative maximum softmax probability. The
    confidence is always the maximum softmax probability.

    Args:
        loader: iterable of `(data, target)` batches.
        in_dist: when True the whole loader is consumed and the scores are also
            split by whether the arg-max prediction equals `target`; when
            False only the first `ood_num_examples // args['test_bs']` batches
            are scored and `target` is ignored.

    Returns:
        in_dist=True:  `(confidence, score, right_score, wrong_score)`.
        in_dist=False: `(confidence, score[:ood_num_examples])`.
        All entries are NumPy arrays.
    """
    scores = []
    confidences = []
    right_scores = []
    wrong_scores = []

    max_ood_batches = ood_num_examples // args['test_bs']

    # Feed the data to whatever device the network lives on instead of calling
    # .cuda() unconditionally, so the ngpu == 0 configuration works as well.
    first_param = next(net.parameters(), None)
    net_device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(loader):
            if not in_dist and batch_idx >= max_ood_batches:
                break

            output = net(data.to(net_device))
            smax = to_np(F.softmax(output, dim=1))
            max_conf = np.max(smax, axis=1)

            # Compute the batch score once and reuse it for the right/wrong split.
            if args['use_xent']:
                batch_score = to_np(output.mean(1) - torch.logsumexp(output, dim=1))
            else:
                batch_score = -max_conf

            scores.append(batch_score)
            # Recorded for both scoring modes; previously it was skipped when
            # use_xent was set, which made the OOD return path fail on an
            # empty list.
            confidences.append(max_conf)

            if in_dist:
                preds = np.argmax(smax, axis=1)
                targets = target.numpy().squeeze()
                right_indices = preds == targets
                right_scores.append(batch_score[right_indices])
                wrong_scores.append(batch_score[~right_indices])

    if in_dist:
        return (concat(confidences).copy(), concat(scores).copy(),
                concat(right_scores).copy(), concat(wrong_scores).copy())
    return concat(confidences).copy(), concat(scores)[:ood_num_examples].copy()


# Score the full in-distribution test set once; the results are reused by
# every OOD comparison below.
in_conf_score, in_score, right_score, wrong_score = get_ood_scores(
    test_loader, in_dist=True)

num_right = len(right_score)
num_wrong = len(wrong_score)
error_rate_pct = 100 * num_wrong / (num_wrong + num_right)
print('Error Rate {:.2f}'.format(error_rate_pct))
f.write('\nError Rate {:.2f}'.format(error_rate_pct))

# /////////////// End Detection Prelims ///////////////

if num_classes == 10:
    typical_data_msg = '\nUsing CIFAR-10 as typical data'
else:
    typical_data_msg = '\nUsing CIFAR-100 as typical data'
print(typical_data_msg)
f.write(typical_data_msg)

# /////////////// Error Detection ///////////////
# (disabled)

# print('\n\nError Detection')
# f.write('\n\nError Detection')
# show_performance(wrong_score, right_score, f, method_name=args['method_name'])

# /////////////// OOD Detection ///////////////
# One entry per OOD dataset, appended by get_and_print_results.
auroc_list, aupr_list, fpr_list = [], [], []


def get_and_print_results(ood_loader, num_to_avg=args['num_to_avg']):
    """Score `ood_loader` against the in-distribution scores and report the result.

    For each of `num_to_avg` runs the loader is scored with `get_ood_scores`
    and compared with the global `in_score` through `get_measures`. The
    per-run measures are averaged, the averages are appended to the global
    `auroc_list` / `aupr_list` / `fpr_list`, and the result is reported through
    `print_measures_with_std` (five or more runs) or `print_measures`.

    Returns:
        The confidence array returned by `get_ood_scores` on the last run.
    """
    per_run = ([], [], [])  # AUROC, AUPR, FPR - in the order get_measures yields them
    for _run in range(num_to_avg):
        out_conf_score, out_score = get_ood_scores(ood_loader)
        run_measures = get_measures(out_score, in_score)
        for position, bucket in enumerate(per_run):
            bucket.append(run_measures[position])
    aurocs, auprs, fprs = per_run

    auroc, aupr, fpr = np.mean(aurocs), np.mean(auprs), np.mean(fprs)
    for history, value in ((auroc_list, auroc), (aupr_list, aupr), (fpr_list, fpr)):
        history.append(value)

    if num_to_avg < 5:
        print_measures(auroc, aupr, fpr, f, args['method_name'])
    else:
        print_measures_with_std(aurocs, auprs, fprs, f, args['method_name'])
    return out_conf_score


# /////////////// Gaussian Noise ///////////////

num_noise_images = ood_num_examples * args['num_to_avg']
dummy_targets = torch.ones(num_noise_images)

# Zero-mean Gaussian pixels with scale 0.5, clipped to [-1, 1].
gaussian_pixels = np.random.normal(size=(num_noise_images, 3, 32, 32),
                                   scale=0.5)
gaussian_pixels = np.float32(np.clip(gaussian_pixels, -1, 1))

ood_data = torch.utils.data.TensorDataset(torch.from_numpy(gaussian_pixels),
                                          dummy_targets)
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args['test_bs'],
    shuffle=True,
    num_workers=args['prefetch'],
    pin_memory=True,
)

section_title = '\n\nGaussian Noise (sigma = 0.5) Detection'
print(section_title)
f.write(section_title)
get_and_print_results(ood_loader)

#/////////////// Rademacher Noise ///////////////

num_noise_images = ood_num_examples * args['num_to_avg']
dummy_targets = torch.ones(num_noise_images)

# Fair coin flips in {0, 1}, mapped to {-1, +1}.
coin_flips = np.random.binomial(n=1,
                                p=0.5,
                                size=(num_noise_images, 3, 32, 32))
ood_data = torch.from_numpy(coin_flips.astype(np.float32)) * 2 - 1
ood_data = torch.utils.data.TensorDataset(ood_data, dummy_targets)
ood_loader = torch.utils.data.DataLoader(ood_data,
                                         shuffle=True,
                                         batch_size=args['test_bs'])

section_title = '\n\nRademacher Noise Detection'
print(section_title)
f.write(section_title)
get_and_print_results(ood_loader)

# /////////////// Blob ///////////////

import inspect

# scikit-image replaced gaussian()'s `multichannel` keyword with `channel_axis`
# and later removed the old one, so `multichannel=False` raises TypeError on
# current releases. Pick whichever spelling the installed version accepts;
# both request the same thing: blur over all three axes, no channel axis.
if 'channel_axis' in inspect.signature(gblur).parameters:
    blur_kwargs = {'channel_axis': None}
else:
    blur_kwargs = {'multichannel': False}

num_blob_images = ood_num_examples * args['num_to_avg']
ood_data = np.float32(
    np.random.binomial(n=1, p=0.7, size=(num_blob_images, 32, 32, 3)))
for i in range(num_blob_images):
    ood_data[i] = gblur(ood_data[i], sigma=1.5, **blur_kwargs)
    # Zero everything below the threshold so only the bright blobs remain.
    ood_data[i][ood_data[i] < 0.75] = 0.0

dummy_targets = torch.ones(num_blob_images)
# NHWC -> NCHW, then rescale [0, 1] -> [-1, 1].
ood_data = torch.from_numpy(ood_data.transpose((0, 3, 1, 2))) * 2 - 1
ood_data = torch.utils.data.TensorDataset(ood_data, dummy_targets)
ood_loader = torch.utils.data.DataLoader(ood_data,
                                         batch_size=args['test_bs'],
                                         shuffle=True,
                                         num_workers=args['prefetch'],
                                         pin_memory=True)

print('\n\nBlob Detection')
f.write('\n\nBlob Detection')
get_and_print_results(ood_loader)

# /////////////// Textures ///////////////

# Resize + centre-crop to 32x32, then normalise with the CIFAR statistics.
texture_transform = trn.Compose([
    trn.Resize(32),
    trn.CenterCrop(32),
    trn.ToTensor(),
    trn.Normalize(mean, std),
])
ood_data = dset.ImageFolder(root="./dtd/images", transform=texture_transform)
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args['test_bs'],
    shuffle=True,
    num_workers=args['prefetch'],
    pin_memory=True,
)

section_title = '\n\nTexture Detection'
print(section_title)
f.write(section_title)
texture_out_score = get_and_print_results(ood_loader)

# /////////////// SVHN ///////////////

# Resize to 32, then normalise with the CIFAR statistics.
svhn_transform = trn.Compose([
    trn.Resize(32),
    trn.ToTensor(),
    trn.Normalize(mean, std),
])
ood_data = svhn.SVHN(root='SVHN',
                     split="test",
                     transform=svhn_transform,
                     download=True)
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args['test_bs'],
    shuffle=True,
    num_workers=args['prefetch'],
    pin_memory=True,
)

section_title = '\n\nSVHN Detection'
print(section_title)
f.write(section_title)
svhn_out_score = get_and_print_results(ood_loader)

# /////////////// Places365 ///////////////

# Resize + centre-crop to 32x32, then normalise with the CIFAR statistics.
places_transform = trn.Compose([
    trn.Resize(32),
    trn.CenterCrop(32),
    trn.ToTensor(),
    trn.Normalize(mean, std),
])
ood_data = dset.ImageFolder(root="./Places365/", transform=places_transform)
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args['test_bs'],
    shuffle=True,
    num_workers=args['prefetch'],
    pin_memory=True,
)

section_title = '\n\nPlaces365 Detection'
print(section_title)
f.write(section_title)
# Kept for the histogram drawn in the plotting cell.
places_out_score = get_and_print_results(ood_loader)

# /////////////// LSUN ///////////////

# Resize + centre-crop to 32x32, then normalise with the CIFAR statistics.
lsun_transform = trn.Compose([
    trn.Resize(32),
    trn.CenterCrop(32),
    trn.ToTensor(),
    trn.Normalize(mean, std),
])
ood_data = lsun_loader.LSUN("./lsun_dataset",
                            classes='test',
                            transform=lsun_transform)
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args['test_bs'],
    shuffle=True,
    num_workers=args['prefetch'],
    pin_memory=True,
)

section_title = '\n\nLSUN Detection'
print(section_title)
f.write(section_title)
lsun_out_score = get_and_print_results(ood_loader)

# /////////////// CIFAR Data ///////////////

# NOTE(review): this OOD set is loaded with random flip/crop augmentation,
# unlike the other real-image OOD sets above, which use only deterministic
# resizing and normalisation - confirm that this is intended for evaluation.
train_transform = trn.Compose([
    trn.RandomHorizontalFlip(),
    trn.RandomCrop(32, padding=4),
    trn.ToTensor(),
    trn.Normalize(mean, std),
])

# The "other" CIFAR variant serves as the OOD set.
if 'cifar10_' in args['method_name']:
    other_cifar, section_title = dset.CIFAR100, '\n\nCIFAR-100 Detection'
else:
    other_cifar, section_title = dset.CIFAR10, '\n\nCIFAR-10 Detection'

ood_data = other_cifar(root_dir,
                       train=False,
                       download=True,
                       transform=train_transform)

ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args['test_bs'],
    shuffle=True,
    num_workers=args['prefetch'],
    pin_memory=True,
)

print(section_title)
f.write(section_title)
get_and_print_results(ood_loader)

# /////////////// Mean Results ///////////////

summary_title = '\n\nMean Test Results'
print(summary_title)
f.write(summary_title)

# Average each metric over all OOD datasets evaluated above.
mean_auroc = np.mean(auroc_list)
mean_aupr = np.mean(aupr_list)
mean_fpr = np.mean(fpr_list)
print_measures(mean_auroc,
               mean_aupr,
               mean_fpr,
               f,
               method_name=args['method_name'])

# Done with the results log.
f.close()

In [None]:
import matplotlib.pyplot as plt

# The seaborn style sheets were renamed to 'seaborn-v0_8-*' in newer matplotlib
# releases and the old names were later removed. Use whichever name the
# installed version provides.
for style_name in ('seaborn-v0_8-notebook', 'seaborn-notebook'):
    if style_name in plt.style.available:
        plt.style.use(style_name)
        break

# Histograms of the maximum softmax probability for the first 1000
# in-distribution and Places365 examples.
plt.hist(in_conf_score[:1000], bins=10)
plt.hist(places_out_score[:1000], bins=10)
plt.legend(labels=['cifar10', 'places365'], loc='upper center', prop={'size': 20})
plt.ylabel('Number of examples', fontdict={'fontsize': 20})
# Axis label typo fixed ("probablities" -> "probabilities").
plt.xlabel('Softmax probabilities', fontdict={'fontsize': 20})
plt.xticks(size=20)
plt.yticks(size=20)
plt.savefig('ours_cifar10_places365_distribution.png')