In [None]:
# Mount Google Drive in the Colab runtime (force_remount refreshes an existing mount).
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
import sys
# Put the project folder on the import path; presumably this is where the
# `resnet20` and `FP_layers` modules imported in the next cell live — confirm.
sys.path.append('/content/drive/MyDrive/ECE 661/ECE 661 Project')

Mounted at /content/drive


In [None]:
from resnet20 import ResNetCIFAR
import torch
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, Sampler
import torchvision
import torch.nn as nn
from FP_layers import *
import sklearn.metrics as metrics
import matplotlib.pyplot as plt
import time
import torch.optim as optim
import torch.nn.functional as F

In [None]:
class CustomDataset(Dataset):
  """Expose an indexable collection of samples as a Dataset with a dummy label.

  Each item is returned as ``(sample, 0)``. The constant ``0`` is only a
  placeholder so that items unpack like ``(input, target)`` pairs.
  """

  def __init__(self, values):
    super().__init__()
    # Anything supporting len() and integer indexing works (list, tensor, ...).
    self.values = values

  def __len__(self):
    n_items = len(self.values)
    return n_items

  def __getitem__(self, index):
    sample = self.values[index]
    placeholder_label = 0
    return sample, placeholder_label

In [None]:
class CIFARSampler(Sampler):
    """Sampler yielding the indices of samples whose label lies in a closed range.

    Args:
        dataset: object exposing a ``targets`` sequence of integer class labels
            (one per sample), as the torchvision CIFAR datasets do.
        lbl_low_bd: smallest label to keep (inclusive).
        lbl_up_bd: largest label to keep (inclusive).

    Indices are produced in ascending order (no shuffling) and are re-derived
    from ``dataset.targets`` on every call, so later changes to the targets are
    picked up.

    NOTE(review): CIFAR labels are 0-based (0..9 / 0..99). Callers in this
    notebook pass bounds such as (1, 5) and (6, 10), which never select class 0
    and make the second range only 4 classes wide — confirm this is intended.
    """

    def __init__(self, dataset, lbl_low_bd, lbl_up_bd):
        self.dataset = dataset
        self.lower_bound = lbl_low_bd
        self.upper_bound = lbl_up_bd

    def _selected_indices(self):
        """Return the matching sample indices as a list of plain Python ints."""
        labels = torch.as_tensor(self.dataset.targets)
        in_range = (labels >= self.lower_bound) & (labels <= self.upper_bound)
        # .tolist() gives builtin ints, which every Dataset.__getitem__ accepts
        # (0-dim tensors only work for datasets that tolerate tensor indices).
        return torch.where(in_range)[0].tolist()

    def __iter__(self):
        return iter(self._selected_indices())

    def __len__(self):
        # Needed so that len(DataLoader(..., sampler=self)) works.
        return len(self._selected_indices())

In [None]:
def _plot_curve(x, y, area, chance_y, title, xlabel, ylabel):
    """Draw one curve with its AUC in the legend plus a dashed chance line."""
    plt.figure(dpi=100)
    plt.plot(x, y, label='AUC = %0.3f' % area)
    plt.plot([0, 1], chance_y, c='gray', linestyle='--', label='Chance')
    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend(loc='lower right')
    plt.gca().set_aspect('equal')
    plt.show()


def results(y_true, y_pred, plot_roc=False, plot_pr=False, title='', pl=1):
    """Compute AUPR and FPR@95%TPR for a binary detection problem.

    Args:
        y_true: array-like of ground-truth labels.
        y_pred: array-like of scores (higher means "more likely positive").
        plot_roc: if True, show the ROC curve.
        plot_pr: if True, show the precision-recall curve.
        title: text appended to the 'ROC ' / 'PR ' plot titles.
        pl: label value treated as the positive class.

    Returns:
        (aupr, fpr_at_95_tpr): area under the precision-recall curve, and the
        false-positive rate at the ROC operating point whose TPR is *closest*
        to 0.95 (that point's TPR may be slightly below 0.95).
    """
    # ROC
    fpr, tpr, _ = metrics.roc_curve(y_true, y_pred, pos_label=pl)
    auroc = metrics.auc(fpr, tpr)
    # argmin returns the first index of the minimum, matching a first-match
    # linear search but without the Python-level loop.
    fpr_at_95_tpr = fpr[int(np.argmin(np.abs(tpr - 0.95)))]
    if plot_roc:
        _plot_curve(fpr, tpr, auroc, [0, 1], 'ROC ' + title, 'FPR', 'TPR')

    # PR
    precision, recall, _ = metrics.precision_recall_curve(y_true, y_pred, pos_label=pl)
    aupr = metrics.auc(recall, precision)
    if plot_pr:
        # Chance-level precision is the fraction of positives; compare against
        # `pl` so the baseline stays correct when the positive label is not 1.
        id_pct = float(np.mean(np.asarray(y_true) == pl))
        _plot_curve(recall, precision, aupr, [id_pct, id_pct], 'PR ' + title, 'Recall', 'Precision')

    return aupr, fpr_at_95_tpr

In [None]:
def _cycle_batches(loader):
    """Yield batches from `loader` indefinitely, restarting it when exhausted."""
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError('trainloader_out yielded no batches')


def train(net, optimizer, scheduler, trainloader_in, trainloader_out, oe_weight=0.5):
    """Run one epoch of Outlier Exposure training and return the network.

    Each step concatenates one in-distribution batch with one outlier batch,
    applies cross-entropy to the in-distribution part, and adds `oe_weight`
    times the cross-entropy between the outlier predictions and the uniform
    distribution (mean logit minus logsumexp, negated).

    Args:
        net: model mapping a batch of inputs to per-class logits.
        optimizer: optimizer over `net`'s parameters.
        scheduler: learning-rate scheduler stepped once at the end of the
            epoch, or None to skip scheduling.
        trainloader_in: iterable of (inputs, targets) in-distribution batches;
            it defines the length of the epoch.
        trainloader_out: iterable of outlier batches whose first element is the
            inputs; restarted whenever it runs out before `trainloader_in`.
        oe_weight: weight of the outlier-exposure term.

    Returns:
        The same `net` object, updated in place.

    Raises:
        ValueError: if `trainloader_out` produces no batches.
    """
    net.train()
    # Use the device the model lives on instead of relying on a notebook global.
    dev = next(net.parameters()).device

    # zip() would end the epoch as soon as the shorter loader is exhausted, so a
    # small outlier set would silently discard most in-distribution batches.
    outlier_batches = _cycle_batches(trainloader_out)

    for in_batch in trainloader_in:
        out_batch = next(outlier_batches)
        n_in = len(in_batch[0])

        inputs = torch.cat((in_batch[0], out_batch[0]), 0).to(dev)
        targets = in_batch[1].to(dev)

        optimizer.zero_grad()
        outputs = net(inputs)

        loss = F.cross_entropy(outputs[:n_in], targets)
        out_logits = outputs[n_in:]
        loss = loss + oe_weight * -(out_logits.mean(1) - torch.logsumexp(out_logits, dim=1)).mean()

        loss.backward()
        optimizer.step()

    # The scheduler's milestones are expressed in epochs, so advance it once here.
    if scheduler is not None:
        scheduler.step()

    return net

In [None]:
def test(net, testloader_in, testloader_out, test_type, T=1000):
    net.eval()

    testloaders = [testloader_in, testloader_out]
    softmax = nn.Softmax(dim=1)
    scores = torch.tensor(()).to(device)
    dist_labels = torch.tensor(()).to(device)

    with torch.no_grad():
        for i in range(2):
            testloader = testloaders[i]
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)

                if test_type == 'baseline':
                  soft_probs = softmax(outputs)
                  scores_batch, _ = soft_probs.max(1)

                if test_type == 'ODIN':
                  soft_probs = softmax(outputs / T)
                  scores_batch, _ = soft_probs.max(1)

                if test_type == 'energy':
                    scores_batch = torch.logsumexp(outputs, dim=1)

                if i == 0:
                    dist_labels_batch = torch.ones(scores_batch.size()).to(device)
                else:
                    dist_labels_batch = torch.zeros(scores_batch.size()).to(device)

                scores = torch.cat((scores, scores_batch), 0)
                dist_labels = torch.cat((dist_labels, dist_labels_batch), 0)

    return scores, dist_labels

In [None]:
def train_and_test(
    net,
    trainloader_in,
    trainloader_out,
    testloader_in,
    testloader_out,
    path,
    epochs=20,
    lr=0.1,
    momentum=0.875,
    decay=0.0005,
    nest=False,
):
    """Train `net` for `epochs` epochs, reporting AUPR and checkpointing each epoch.

    An SGD optimizer and a MultiStepLR schedule (x0.1 at 50% and 75% of
    `epochs`) are created here and handed to `train`; this function itself
    never calls `scheduler.step()`. After every epoch the model is scored with
    the 'baseline' detector, the AUPR and elapsed wall-clock time are printed,
    and the state dict is written to `path` (overwriting the previous epoch).

    Args:
        net: model to train.
        trainloader_in / trainloader_out: training loaders passed to `train`.
        testloader_in / testloader_out: evaluation loaders passed to `test`.
        path: file the state dict is saved to after each epoch.
        epochs: number of epochs to run.
        lr, momentum, decay, nest: SGD learning rate, momentum, weight decay
            and Nesterov flag.

    Returns:
        None.
    """
    sgd = optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=decay, nesterov=nest)
    lr_milestones = [int(epochs*0.5), int(epochs*0.75)]
    lr_schedule = optim.lr_scheduler.MultiStepLR(sgd, milestones=lr_milestones, gamma=0.1)

    t_begin = time.time()

    for epoch in range(epochs):
        net = train(net, sgd, lr_schedule, trainloader_in, trainloader_out)

        # Evaluate with the max-softmax baseline detector.
        scores, dist_labels = test(net, testloader_in, testloader_out, 'baseline')
        scores_np = scores.cpu().detach().numpy()
        dist_labels_np = dist_labels.cpu().detach().numpy()
        aupr, _ = results(dist_labels_np, scores_np)

        elapsed = time.time() - t_begin
        print('Epoch:', epoch, ', AUPR:', aupr, 'Time:', elapsed)

        # Checkpoint after every epoch so an interrupted run keeps its progress.
        torch.save(net.state_dict(), path)

In [None]:
# create ID and OOD testloaders
def select_testloaders(oe=False, ood_set='CIFAR', LSUN_type='random', seed=None):
  """Build the in-distribution (ID) and out-of-distribution (OOD) test loaders.

  Args:
    oe: if True, the ID loader is restricted to CIFAR-10 test samples with
      labels 1..5, and the 'CIFAR' OOD set becomes the CIFAR-10 test samples
      with labels 6..10 instead of CIFAR-100.
    ood_set: one of 'CIFAR', 'SVHN', 'LSUN', 'MNIST', 'uni', 'gauss'.
    LSUN_type: 'random' (random 32x32 crop) or 'scale' (resize to 32x32);
      only consulted when `ood_set` is 'LSUN'.
    seed: optional integer seed for the random parts (SVHN subset choice and
      the 'uni' / 'gauss' noise). None keeps the global RNG state in charge.

  Returns:
    (testloader_in, testloader_out): DataLoaders with batch size 100, unshuffled.

  Raises:
    ValueError: if `ood_set` (or, for LSUN, `LSUN_type`) is not recognised.

  NOTE(review): CIFAR-10 labels are 0..9, so the 1..5 / 6..10 ranges never
  select class 0 and the OOD range covers only labels 6..9 — confirm intended.
  NOTE(review): the 'uni' and 'gauss' tensors are fed as-is in [0, 1] without
  the Normalize step applied to the image datasets — confirm intended.
  """
  known_sets = ('CIFAR', 'SVHN', 'LSUN', 'MNIST', 'uni', 'gauss')
  if ood_set not in known_sets:
    raise ValueError('ood_set must be one of {}, got {!r}'.format(known_sets, ood_set))
  if ood_set == 'LSUN' and LSUN_type not in ('random', 'scale'):
    raise ValueError("LSUN_type must be 'random' or 'scale', got {!r}".format(LSUN_type))

  lsun_root = '/content/drive/MyDrive/ECE 661/ECE 661 Project/LSUN_data'

  # Every image dataset ends with the same tensor conversion and CIFAR-10
  # channel statistics; only the preprocessing before it differs.
  normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

  def make_transform(*pre_steps):
    return transforms.Compose(list(pre_steps) + [transforms.ToTensor(), normalize])

  transform_CIFAR = make_transform()
  transform_LSUN_random = make_transform(transforms.RandomCrop(32))
  transform_LSUN_scale = make_transform(transforms.Resize((32,32)))
  # MNIST is 28x28 grayscale: pad to 32x32 and replicate to 3 channels.
  transform_MNIST = make_transform(transforms.Pad(2), transforms.Grayscale(3))

  # With no seed, fall back to the global RNGs exactly as before.
  np_rng = np.random if seed is None else np.random.RandomState(seed)
  torch_gen = None if seed is None else torch.Generator().manual_seed(seed)

  # ID set is always CIFAR-10
  testset_in = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_CIFAR)
  sampler = CIFARSampler(testset_in, 1, 5) if oe else None
  testloader_in = torch.utils.data.DataLoader(testset_in, batch_size=100, shuffle=False, num_workers=1, sampler=sampler)

  # OOD test set: OE uses 6-10 from CIFAR-10, others use CIFAR-100
  if ood_set == 'CIFAR':
    if oe:
      testset_out = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_CIFAR)
      testloader_out = torch.utils.data.DataLoader(testset_out, batch_size=100, shuffle=False, num_workers=1, sampler=CIFARSampler(testset_out, 6, 10))
    else:
      testset_out = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_CIFAR)
      testloader_out = torch.utils.data.DataLoader(testset_out, batch_size=100, shuffle=False, num_workers=1)

  elif ood_set == 'SVHN':
    testset_out = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_CIFAR)
    # Evaluate on a random 10000-image subset of the SVHN test split.
    testset_out = torch.utils.data.Subset(testset_out, np_rng.choice(len(testset_out), 10000, replace=False))
    testloader_out = torch.utils.data.DataLoader(testset_out, batch_size=100, shuffle=False, num_workers=2)

  elif ood_set == 'LSUN':
    lsun_transform = transform_LSUN_random if LSUN_type == 'random' else transform_LSUN_scale
    testset_out = torchvision.datasets.LSUN(root=lsun_root, classes='test', transform=lsun_transform)
    testloader_out = torch.utils.data.DataLoader(testset_out, batch_size=100, shuffle=False, num_workers=2)

  elif ood_set == 'MNIST':
    testset_out = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_MNIST)
    testloader_out = torch.utils.data.DataLoader(testset_out, batch_size=100, shuffle=False, num_workers=2)

  elif ood_set == 'uni':
    # Uniform noise in [0, 1).
    values = torch.rand((10000, 3, 32, 32), generator=torch_gen)
    testset_out = CustomDataset(values)
    testloader_out = torch.utils.data.DataLoader(testset_out, batch_size=100, shuffle=False, num_workers=2)

  else:  # 'gauss'
    # Gaussian noise (mean 0.5, std 1) clipped to [0, 1].
    values = torch.normal(0.5, 1, size=(10000, 3, 32, 32), generator=torch_gen)
    values = torch.clamp(values, min=0, max=1)
    testset_out = CustomDataset(values)
    testloader_out = torch.utils.data.DataLoader(testset_out, batch_size=100, shuffle=False, num_workers=2)

  return testloader_in, testloader_out

In [None]:
# instantiate a resnet, load the appropriate model checkpoint
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Training Datasets
# Crop/flip augmentation followed by normalisation with CIFAR-10 channel statistics.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset_in = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainset_out = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
# In-distribution training data: CIFAR-10 samples whose label is in 1..5.
trainloader_in = torch.utils.data.DataLoader(trainset_in, batch_size=100, shuffle=False, num_workers=1, sampler=CIFARSampler(trainset_in, 1, 5))

# choose the test OOD set
testloader_in, testloader_out = select_testloaders(oe=True, ood_set='SVHN')

# train and provide AUPR results for the OE model if needed
# One model per outlier-set size: CIFAR-100 labels 1..n are used as outliers,
# and n == 100 means the whole CIFAR-100 training set (no sampler).
CHECKPOINT_DIR = '/content/drive/MyDrive/ECE 661/ECE 661 Project/Model Checkpoints'
OE_CLASS_COUNTS = (1, 5, 20, 50, 100)

samplers = [CIFARSampler(trainset_out, 1, n) if n < 100 else None
            for n in OE_CLASS_COUNTS]

save_paths = ['%s/new_oe_model_%d.pt' % (CHECKPOINT_DIR, n) for n in OE_CLASS_COUNTS]

for sampler, save_path in zip(samplers, save_paths):
  print('START TRAINING:', save_path)
  # Fresh network for every outlier-set size.
  net = ResNetCIFAR(num_layers=20).to(device)
  trainloader_out = torch.utils.data.DataLoader(trainset_out, batch_size=100, shuffle=False, num_workers=1, sampler=sampler)
  train_and_test(net, trainloader_in, trainloader_out, testloader_in, testloader_out, save_path)
  print('TRAINING COMPLETED:', save_path)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/test_32x32.mat
START TRAINING: /content/drive/MyDrive/ECE 661/ECE 661 Project/Model Checkpoints/oe_model_1.pt
Epoch: 0 , AUPR: 0.3746976103045627 Time: 8.223967552185059
Epoch: 1 , AUPR: 0.6235124011602367 Time: 18.100631713867188
Epoch: 2 , AUPR: 0.5889103764681903 Time: 26.110844135284424
Epoch: 3 , AUPR: 0.46907956565496944 Time: 34.2521755695343
Epoch: 4 , AUPR: 0.6040973839828401 Time: 42.35422468185425
Epoch: 5 , AUPR: 0.44140061498658856 Time: 50.40211892127991
Epoch: 6 , AUPR: 0.37089111892338694 Time: 59.85407280921936
Epoch: 7 , AUPR: 0.48028023040879864 Time: 67.83442044258118
Epoch: 8 , AUPR: 0.5222187415682153 Time: 75.8900351524353
Epoch: 9 , AUPR: 0.46570401310107223 Time: 84.01820611953735
Epoch: 10 , AUPR: 0.42861673547323503 Time: 92.18977546691895
Epoch: 11 , AUPR: 0.46086076402049025 Time: 100.8061990737915
Epoc

In [None]:
# Test with the given method
# Scores the current `net` (left over from the training loop above) on the current
# ID/OOD loaders using the max-softmax 'baseline' detector. Despite the name,
# `logits` holds detection scores, not raw network logits.
logits, labels = test(net, testloader_in, testloader_out, 'baseline')

# Prints the full tensors as well as their shapes.
print('Logits:', logits.shape, logits)
print('Labels:', labels.shape, labels)

In [None]:
# find and display results
# Move the scores/labels produced by the previous cell to host NumPy arrays for scikit-learn.
labels_np = labels.cpu().detach().numpy()
logits_np = logits.cpu().detach().numpy()
# NOTE(review): the title names CIFAR-100 as the OOD set, but the loaders built in the
# training cell use ood_set='SVHN' with oe=True — confirm which is intended. The title
# only matters when plotting is enabled (both plot flags are False in the call below).
title = 'Baseline (ID: CIFAR-10, OOD: CIFAR-100)'

aupr, fpr_at_95_tpr = results(labels_np, logits_np, False, False, title)
print('AUPR: {:.4}'.format(aupr))
print('FPR@95TPR: {:.4}'.format(fpr_at_95_tpr))

In [None]:
# Evaluate every saved model against every OOD set with the max-softmax baseline.
# NOTE(review): these are the 'oe_model_*.pt' files, whereas the training cell
# writes 'new_oe_model_*.pt' — confirm which checkpoints should be evaluated.
checkpoints = ['/content/drive/MyDrive/ECE 661/ECE 661 Project/Model Checkpoints/pretrained_model.pt',
              '/content/drive/MyDrive/ECE 661/ECE 661 Project/Model Checkpoints/oe_model_1.pt',
              '/content/drive/MyDrive/ECE 661/ECE 661 Project/Model Checkpoints/oe_model_5.pt',
              '/content/drive/MyDrive/ECE 661/ECE 661 Project/Model Checkpoints/oe_model_20.pt',
              '/content/drive/MyDrive/ECE 661/ECE 661 Project/Model Checkpoints/oe_model_50.pt',
              '/content/drive/MyDrive/ECE 661/ECE 661 Project/Model Checkpoints/oe_model_100.pt']
ood_sets = ['CIFAR', 'SVHN', 'LSUN', 'MNIST', 'uni', 'gauss']

for checkpoint in checkpoints:
  print('MODEL:', checkpoint)
  print('')
  net = ResNetCIFAR(num_layers=20).to(device)
  # map_location lets a checkpoint saved on a GPU runtime load on a CPU-only one.
  net.load_state_dict(torch.load(checkpoint, map_location=device))
  for ood_set in ood_sets:
    print('DATASET:', ood_set)
    testloader_in, testloader_out = select_testloaders(oe=True, ood_set=ood_set)
    logits, labels = test(net, testloader_in, testloader_out, 'baseline')
    labels_np = labels.cpu().detach().numpy()
    logits_np = logits.cpu().detach().numpy()
    aupr, fpr_at_95_tpr = results(labels_np, logits_np)
    print('AUPR: {:.4}'.format(aupr))
    print('FPR@95TPR: {:.4}'.format(fpr_at_95_tpr))
  print('')
  print('')

MODEL: /content/drive/MyDrive/ECE 661/ECE 661 Project/Model Checkpoints/pretrained_model.pt

DATASET: CIFAR
Files already downloaded and verified
Files already downloaded and verified
AUPR: 0.5348
FPR@95TPR: 0.9722
DATASET: SVHN
Files already downloaded and verified
Using downloaded and verified file: ./data/test_32x32.mat
AUPR: 0.8457
FPR@95TPR: 0.7547
DATASET: LSUN
Files already downloaded and verified
AUPR: 0.8676
FPR@95TPR: 0.5801
DATASET: MNIST
Files already downloaded and verified
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

AUPR: 0.8634
FPR@95TPR: 0.6403
DATASET: uni
Files already downloaded and verified
AUPR: 0.8036
FPR@95TPR: 0.9339
DATASET: gauss
Files already downloaded and verified
AUPR: 0.7572
FPR@95TPR: 0.937


MODEL: /content/drive/MyDrive/ECE 661/ECE 661 Project/Model Checkpoints/oe_model_1.pt

DATASET: CIFAR
Files already downloaded and verified
Files already downloaded and verified
AUPR: 0.6492
FPR@95TPR: 0.9517
DATASET: SVHN
Files already downloaded and verified
Using downloaded and verified file: ./data/test_32x32.mat
AUPR: 0.2971
FPR@95TPR: 0.9092
DATASET: LSUN
Files already downloaded and verified
AUPR: 0.3966
FPR@95TPR: 0.8629
DATASET: MNIST
Files already downloaded and verified
AUPR: 0.8015
FPR@95TPR: 0.816
DATASET: uni
Files already downloaded and verified
AUPR: 0.712
FPR@95TPR: 1.0
DATASET: gauss
Files already downloaded and verified
AUPR: 0.3113
FPR@95TPR: 1.0


MODEL: /content/drive/MyDrive/ECE 661/ECE 661 Proj