In [None]:
# Re-import edited project modules (e.g. the `utils` and `CIFAR` packages imported in the
# next cell) before each cell executes, so changes are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('..')
import utils.lsun_loader as lsun_loader
import utils.svhn_loader as svhn
from utils.display_results import show_performance, get_measures, print_measures, print_measures_with_std
from PIL import Image as PILImage
from skimage.filters import gaussian as gblur

from CIFAR.models.wrn import WideResNet
from tqdm import tqdm_notebook, tqdm
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as trn
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch
import time
import argparse
import pickle
import os
import numpy as np

In [None]:
# Show the kernel's current working directory; the relative dataset/checkpoint paths used
# later in this notebook are resolved against it.
!pwd

In [None]:
# !pip install --user lmdb

In [None]:
cd /home/outlier-detection

In [None]:
# Evaluation settings shared by every cell below.
args = {
    # ---- evaluation ----
    'test_bs': 200,  # batch size of every evaluation DataLoader
    'num_to_avg': 1,  # 'Average measures across num_to_avg runs.'
    'validate': '',  # empty string is falsy; not read elsewhere in this notebook
    'use_xent': '',  # truthy -> score = mean logit - logsumexp(logits) instead of -max softmax
    'method_name': 'svhn_wrn_OECC_tune',  # picks the architecture, results subdir and file prefix
    # ---- WideResNet architecture ----
    'layers': 16,
    'widen-factor': 4,
    'droprate': 0.4,
    # ---- checkpoints / logs ----
    'load': './SVHN/results',
    'save': './SVHN/results',
    # ---- hardware / data loading ----
    'ngpu': 1,
    'prefetch': 2,  # passed to DataLoader as num_workers
}

In [None]:
# In-distribution evaluation data: the SVHN test split, converted to tensors only.
root_dir = './SVHN'

# Seeding is left commented out, so the random noise OOD sets below differ between runs.
# torch.manual_seed(1)
# np.random.seed(1)

test_data = svhn.SVHN(
    root_dir,
    split='test',
    transform=trn.ToTensor(),
    download=False,
)
num_classes = 10

test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=args['test_bs'],
    shuffle=False,  # deterministic order for the in-distribution scores
    num_workers=args['prefetch'],
    pin_memory=True,
)

In [None]:
# Create model
if 'allconv' in args['method_name']:
    # NOTE(review): AllConvNet is not imported anywhere in this notebook, so this branch
    # raises NameError as written; import it before using an 'allconv' method name.
    net = AllConvNet(num_classes)
else:
    net = WideResNet(args['layers'],
                     num_classes,
                     args['widen-factor'],
                     dropRate=args['droprate'])

# Checkpoints and the results log live in a per-method sub-directory.
if 'baseline' in args['method_name']:
    subdir = 'baseline'
elif 'OECC' in args['method_name']:
    subdir = 'OECC_tune'
else:
    # Without this, `subdir` is undefined and the open() below fails with a NameError.
    raise ValueError(
        f"cannot infer results sub-directory from method_name={args['method_name']!r}; "
        "expected it to contain 'baseline' or 'OECC'")

results_dir = os.path.join(args['save'], subdir)
os.makedirs(results_dir, exist_ok=True)
f = open(os.path.join(results_dir, args['method_name'] + '_test.txt'), 'w+')

# Restore the checkpoint with the highest epoch index in [0, 999].
# start_epoch == 0 is the "nothing restored" sentinel: a successful restore sets it to
# epoch + 1 >= 1. (It used to start at 9, which made the check below unreachable and let
# the notebook silently evaluate an untrained network.)
start_epoch = 0
if args['load'] != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(args['load'], subdir,
                                  args['method_name'] + '_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            # Load onto CPU first so GPU-saved checkpoints also work on CPU-only hosts;
            # the model is moved to the GPU below when ngpu > 0.
            net.load_state_dict(torch.load(model_name, map_location='cpu'))
            print('Model restored! Epoch:', i)
            f.write(f'Model restored! Epoch: {i}')
            start_epoch = i + 1
            break
    if start_epoch == 0:
        f.close()
        # An explicit raise (not `assert`) so the check survives `python -O`.
        raise FileNotFoundError(
            f"could not resume: no '{args['method_name']}_epoch_<N>.pt' checkpoint found in "
            f"{os.path.join(args['load'], subdir)}")

net.eval()

if args['ngpu'] > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args['ngpu'])))

if args['ngpu'] > 0:
    net.cuda()
    # torch.cuda.manual_seed(1)

cudnn.benchmark = True  # let cuDNN benchmark conv algorithms (helps with fixed-size inputs)

# /////////////// Detection Prelims ///////////////

# Every OOD set is scored on (at most) one fifth as many examples as the SVHN test set.
ood_num_examples = test_data.data.shape[0] // 5
# Share of OOD examples in the pooled in-dist + OOD set (not referenced later in this notebook).
expected_ap = ood_num_examples / (test_data.data.shape[0] + ood_num_examples)


def concat(x):
    """Join a sequence of NumPy arrays end-to-end along their first axis."""
    joined = np.concatenate(x, 0)
    return joined


def to_np(x):
    """Return a tensor's values as a host NumPy array, detached from the autograd graph."""
    return x.detach().cpu().numpy()


def get_ood_scores(loader, in_dist=False):
    """Score the examples yielded by ``loader`` with the global ``net``.

    The score is ``-max(softmax(logits))`` by default, or
    ``mean(logits) - logsumexp(logits)`` when ``args['use_xent']`` is truthy; in both
    cases larger values correspond to a less confident prediction.

    Args:
        loader: iterable of ``(data, target)`` batches.
        in_dist: if True, score the entire loader and also split the scores by whether
            the argmax prediction equals ``target``. If False (OOD data), stop as soon as
            ``ood_num_examples`` examples have been scored.

    Returns:
        ``(scores, right_scores, wrong_scores)`` when ``in_dist`` is True; otherwise a
        1-D array of at most ``ood_num_examples`` scores.
    """
    use_xent = args['use_xent']
    _score = []
    _right_score = []
    _wrong_score = []
    n_scored = 0

    with torch.no_grad():
        for data, target in loader:
            # Count examples rather than batches: the previous test
            # `batch_idx >= ood_num_examples // test_bs` rounded down, dropping the
            # remainder and collecting nothing at all when ood_num_examples < test_bs.
            if not in_dist and n_scored >= ood_num_examples:
                break

            # Only move to the GPU when the model was moved there (ngpu > 0).
            if args['ngpu'] > 0:
                data = data.cuda()

            output = net(data)
            smax = to_np(F.softmax(output, dim=1))

            if use_xent:
                batch_score = to_np(output.mean(1) - torch.logsumexp(output, dim=1))
            else:
                batch_score = -np.max(smax, axis=1)
            _score.append(batch_score)
            n_scored += len(batch_score)

            if in_dist:
                preds = np.argmax(smax, axis=1)
                targets = target.numpy().squeeze()
                right_indices = preds == targets
                wrong_indices = np.invert(right_indices)
                # Same per-example score, partitioned by prediction correctness.
                _right_score.append(batch_score[right_indices])
                _wrong_score.append(batch_score[wrong_indices])

    if in_dist:
        return concat(_score).copy(), concat(_right_score).copy(), concat(
            _wrong_score).copy()
    return concat(_score)[:ood_num_examples].copy()


# Score the whole in-distribution test set once; every OOD comparison below reuses in_score.
in_score, right_score, wrong_score = get_ood_scores(test_loader, in_dist=True)

num_right, num_wrong = len(right_score), len(wrong_score)
error_rate = 100 * num_wrong / (num_wrong + num_right)
print('Error Rate {:.2f}'.format(error_rate))
f.write('\nError Rate {:.2f}'.format(error_rate))

# /////////////// End Detection Prelims ///////////////

print('\nUsing SVHN as typical data')
f.write('\nUsing SVHN as typical data')

# /////////////// Error Detection ///////////////

# print('\n\nError Detection')
# f.write('\n\nError Detection')
# show_performance(wrong_score, right_score, f, method_name=args['method_name'])

# /////////////// OOD Detection ///////////////
# Per-OOD-dataset averages, appended by get_and_print_results for the final summary.
auroc_list = []
aupr_list = []
fpr_list = []


def get_and_print_results(ood_loader, num_to_avg=args['num_to_avg']):
    """Compare OOD scores from ``ood_loader`` against ``in_score`` and report them.

    Performs ``num_to_avg`` scoring passes, averages the first three values returned by
    ``get_measures`` (AUROC, AUPR, FPR) over the passes, appends the averages to the
    module-level result lists, and logs them via ``print_measures_with_std`` when there
    are 5 or more passes, otherwise via ``print_measures``.
    """
    runs = []
    for _run in range(num_to_avg):
        runs.append(get_measures(get_ood_scores(ood_loader), in_score))

    aurocs = [m[0] for m in runs]
    auprs = [m[1] for m in runs]
    fprs = [m[2] for m in runs]

    auroc, aupr, fpr = np.mean(aurocs), np.mean(auprs), np.mean(fprs)
    for bucket, value in ((auroc_list, auroc), (aupr_list, aupr), (fpr_list, fpr)):
        bucket.append(value)

    if num_to_avg < 5:
        print_measures(auroc, aupr, fpr, f, args['method_name'])
    else:
        print_measures_with_std(aurocs, auprs, fprs, f, args['method_name'])


# /////////////// Gaussian Noise ///////////////

# Gaussian noise images: N(0.5, 0.5^2) per pixel and channel, clipped into [0, 1].
n_noise = ood_num_examples * args['num_to_avg']
dummy_targets = torch.ones(n_noise)
ood_data = torch.from_numpy(
    np.float32(
        np.clip(
            np.random.normal(loc=0.5, scale=0.5,
                             size=(n_noise, 3, 32, 32)).astype(np.float32),
            0,
            1,
        )))
ood_data = torch.utils.data.TensorDataset(ood_data, dummy_targets)
ood_loader = torch.utils.data.DataLoader(
    ood_data, batch_size=args['test_bs'], shuffle=True)

print('\n\nGaussian Noise Detection (sigma = 0.5)')
f.write('\n\nGaussian Noise Detection (sigma = 0.5)')
get_and_print_results(ood_loader)

# /////////////// Bernoulli  Noise ///////////////

# Bernoulli(0.5) noise images: every pixel/channel is independently 0 or 1.
dummy_targets = torch.ones(ood_num_examples * args['num_to_avg'])
bernoulli_shape = (ood_num_examples * args['num_to_avg'], 3, 32, 32)
ood_data = torch.from_numpy(
    np.random.binomial(n=1, p=0.5, size=bernoulli_shape).astype(np.float32))
ood_data = torch.utils.data.TensorDataset(ood_data, dummy_targets)
ood_loader = torch.utils.data.DataLoader(
    ood_data, batch_size=args['test_bs'], shuffle=True)

print('\n\nBernoulli Noise Detection')
f.write('\n\nBernoulli Noise Detection')
get_and_print_results(ood_loader)

# /////////////// Blob ///////////////

# Blob images: Bernoulli(p=0.7) masks of shape (32, 32, 3), Gaussian-blurred with
# sigma=1.5, then every value below 0.75 is zeroed.
num_blobs = ood_num_examples * args['num_to_avg']
ood_data = np.float32(
    np.random.binomial(n=1, p=0.7, size=(num_blobs, 32, 32, 3)))

# scikit-image replaced the `multichannel` keyword with `channel_axis`; releases that no
# longer accept `multichannel` raise TypeError. Keep the original `multichannel=False`
# call where it is supported and otherwise fall back to `channel_axis=None` (no channel
# axis, i.e. the array is blurred as one volume). The probe uses a dummy array so it
# does not touch the random stream.
try:
    gblur(np.zeros((2, 2, 2), dtype=np.float32), sigma=1.5, multichannel=False)
    _blur_kwargs = {'multichannel': False}
except TypeError:
    _blur_kwargs = {'channel_axis': None}

for i in range(num_blobs):
    ood_data[i] = gblur(ood_data[i], sigma=1.5, **_blur_kwargs)
    ood_data[i][ood_data[i] < 0.75] = 0.0

dummy_targets = torch.ones(num_blobs)
ood_data = torch.from_numpy(ood_data.transpose((0, 3, 1, 2)))  # NHWC -> NCHW
ood_data = torch.utils.data.TensorDataset(ood_data, dummy_targets)
ood_loader = torch.utils.data.DataLoader(ood_data,
                                         batch_size=args['test_bs'],
                                         shuffle=True,
                                         num_workers=args['prefetch'],
                                         pin_memory=True)

print('\n\nBlob Detection')
f.write('\n\nBlob Detection')
get_and_print_results(ood_loader)

# /////////////// Icons-50 ///////////////

ood_data = dset.ImageFolder(root="Icons-50",
                            transform=trn.Compose(
                                [trn.Resize((32, 32)),
                                 trn.ToTensor()]))

# Drop every icon whose path contains 'numbers' (presumably because digits overlap with
# the SVHN in-distribution classes). Each entry is a (path, class_index) pair.
filtered_imgs = [img for img in ood_data.imgs if 'numbers' not in img[0]]

# torchvision's DatasetFolder indexes `samples` in __getitem__/__len__ and `imgs` is only
# an alias of it, so assigning `imgs` alone leaves the dataset unfiltered. Update every view.
ood_data.imgs = filtered_imgs
if hasattr(ood_data, 'samples'):
    ood_data.samples = filtered_imgs
if hasattr(ood_data, 'targets'):
    ood_data.targets = [sample[1] for sample in filtered_imgs]

ood_loader = torch.utils.data.DataLoader(ood_data,
                                         batch_size=args["test_bs"],
                                         shuffle=True,
                                         num_workers=args['prefetch'],
                                         pin_memory=True)

print('\n\nIcons-50 Detection')
f.write('\n\nIcons-50 Detection')
get_and_print_results(ood_loader)

# /////////////// Textures ///////////////

# Texture images from dtd/images: shorter side resized to 32, then centre-cropped to 32x32.
ood_data = dset.ImageFolder(root="dtd/images",
                            transform=trn.Compose([
                                trn.Resize(32),
                                trn.CenterCrop(32),
                                trn.ToTensor(),
                            ]))
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args['test_bs'],
    shuffle=True,
    num_workers=args['prefetch'],
    pin_memory=True,
)

print('\n\nTexture Detection')
f.write('\n\nTexture Detection')
get_and_print_results(ood_loader)

# /////////////// Places365 ///////////////

# Places365 scene images: shorter side resized to 32, then centre-cropped to 32x32.
ood_data = dset.ImageFolder(root="Places365",
                            transform=trn.Compose([
                                trn.Resize(32),
                                trn.CenterCrop(32),
                                trn.ToTensor(),
                            ]))
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args['test_bs'],
    shuffle=True,
    num_workers=args['prefetch'],
    pin_memory=True,
)

print('\n\nPlaces365 Detection')
f.write('\n\nPlaces365 Detection')
get_and_print_results(ood_loader)

# /////////////// LSUN ///////////////

# LSUN 'test' images via the project's lsun_loader, resized and centre-cropped to 32x32.
ood_data = lsun_loader.LSUN("lsun_dataset",
                            classes='test',
                            transform=trn.Compose([
                                trn.Resize(32),
                                trn.CenterCrop(32),
                                trn.ToTensor(),
                            ]))
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args['test_bs'],
    shuffle=True,
    num_workers=args['prefetch'],
    pin_memory=True,
)

print('\n\nLSUN Detection')
f.write('\n\nLSUN Detection')
get_and_print_results(ood_loader)

# /////////////// CIFAR Data ///////////////

# CIFAR-10 test split (train=False) read from ./CIFAR without downloading.
ood_data = dset.CIFAR10(root='CIFAR',
                        train=False,
                        download=False,
                        transform=trn.ToTensor())
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=args['test_bs'],
    shuffle=True,
    num_workers=args['prefetch'],
    pin_memory=True,
)

print('\n\nCIFAR-10 Detection')
f.write('\n\nCIFAR-10 Detection')
get_and_print_results(ood_loader)

# /////////////// Mean Results ///////////////

print('\n\nMean Test Results')
f.write('\n\nMean Test Results')
# Average each metric over all OOD datasets evaluated above, then close the results log.
print_measures(*map(np.mean, (auroc_list, aupr_list, fpr_list)),
               f,
               method_name=args['method_name'])
f.close()