In [None]:
import sys
from google.colab import drive

# Mount Google Drive, then put the project folder on the import path so the
# project-local modules imported in the next cell can be found.
drive.mount('/content/drive', force_remount=True)
sys.path.append('/content/drive/MyDrive/Duke/2022-2023/ECE 661/ECE 661 Project')

Mounted at /content/drive


In [None]:
import time

import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics as metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, Sampler

from resnet20 import ResNetCIFAR
from FP_layers import *

In [None]:
class CIFARSampler(Sampler):
    """Sampler that yields, in ascending order, the indices of samples whose label
    lies in the inclusive range [lbl_low_bd, lbl_up_bd].

    Args:
        dataset: any dataset exposing a ``targets`` sequence (list or tensor) of labels.
        lbl_low_bd: smallest label to keep (inclusive).
        lbl_up_bd: largest label to keep (inclusive).
    """

    def __init__(self, dataset, lbl_low_bd, lbl_up_bd):
        self.dataset = dataset
        self.n = len(dataset)
        self.lower_bound = lbl_low_bd
        self.upper_bound = lbl_up_bd
        # torchvision CIFAR datasets store targets as a Python list; convert so the
        # comparisons are elementwise. `&` (not `and`) combines the two boolean masks.
        targets = torch.as_tensor(dataset.targets)
        mask = (targets >= lbl_low_bd) & (targets <= lbl_up_bd)
        # A Sampler must yield plain integer indices, one per sample.
        self.indices = torch.nonzero(mask, as_tuple=True)[0].tolist()

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

In [None]:
def _plot_curve(x, y, chance_y, auc_value, title, xlabel, ylabel):
    """Draw one metric curve with a dashed chance line and an AUC legend entry."""
    fig, ax = plt.subplots(dpi=100)
    ax.plot(x, y, label='AUC = %0.3f' % auc_value)
    ax.plot([0, 1], chance_y, c='gray', linestyle='--', label='Chance')
    ax.set_xlim([-0.05, 1.05])
    ax.set_ylim([-0.05, 1.05])
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend(loc='lower right')
    ax.set_aspect('equal')
    plt.show()


def results(y_true, y_pred, plot_roc=False, plot_pr=False, title='', pl=1):
    """Compute OOD-detection metrics and optionally plot the ROC / PR curves.

    Args:
        y_true: array of ground-truth labels; ``pl`` marks the positive class.
        y_pred: array of detector scores (higher = more likely positive).
        plot_roc: if True, show the ROC curve.
        plot_pr: if True, show the precision-recall curve.
        title: suffix appended to the plot titles.
        pl: label value treated as the positive class.

    Returns:
        (aupr, fpr_at_95_tpr): area under the PR curve, and the FPR at the ROC
        operating point whose TPR is closest to 0.95.
    """
    # ROC
    fpr, tpr, _ = metrics.roc_curve(y_true, y_pred, pos_label=pl)
    auroc = metrics.auc(fpr, tpr)
    # Closest TPR to 0.95 (the earliest ROC point wins on ties).
    fpr_at_95_tpr = fpr[np.argmin(np.abs(tpr - 0.95))]
    if plot_roc:
        _plot_curve(fpr, tpr, [0, 1], auroc, 'ROC ' + title, 'FPR', 'TPR')

    # PR
    precision, recall, _ = metrics.precision_recall_curve(y_true, y_pred, pos_label=pl)
    aupr = metrics.auc(recall, precision)
    if plot_pr:
        # Chance-level precision is the fraction of samples in the positive class;
        # compare against `pl` so this stays correct when pl != 1.
        pos_frac = np.mean(np.asarray(y_true) == pl)
        _plot_curve(recall, precision, [pos_frac, pos_frac], aupr, 'PR ' + title, 'Recall', 'Precision')

    return aupr, fpr_at_95_tpr

In [None]:
def train(net, optimizer, scheduler, trainloader_in, trainloader_out):
    """Run one epoch of Outlier Exposure (OE) training and return the model.

    Each step concatenates an in-distribution batch with an outlier batch. The loss is
    cross-entropy on the ID part plus 0.5 times the cross-entropy between the
    outlier predictions and the uniform distribution.

    Args:
        net: classifier to train (put in train mode).
        optimizer: optimizer over ``net``'s parameters.
        scheduler: LR scheduler; stepped exactly once, at the end of the epoch.
        trainloader_in: iterable of (inputs, targets) ID batches.
        trainloader_out: iterable of (inputs, _) outlier batches; labels are ignored.
            The two loaders are zipped, so the epoch stops at the shorter one.

    Returns:
        The same ``net`` object.
    """
    net.train()
    # Run on whichever device the model's parameters live on.
    device = next(net.parameters()).device

    for in_set, out_set in zip(trainloader_in, trainloader_out):
        n_in = len(in_set[0])
        inputs = torch.cat((in_set[0], out_set[0]), 0).to(device)
        targets = in_set[1].to(device)

        optimizer.zero_grad()
        outputs = net(inputs)

        loss = F.cross_entropy(outputs[:n_in], targets)
        # Cross-entropy to the uniform distribution: -(mean logit - logsumexp).
        out_logits = outputs[n_in:]
        loss += 0.5 * -(out_logits.mean(1) - torch.logsumexp(out_logits, dim=1)).mean()

        loss.backward()
        optimizer.step()

    # Epoch-based schedulers (e.g. MultiStepLR milestones) must advance once per epoch,
    # after the optimizer steps, not once per batch.
    scheduler.step()
    return net

In [None]:
def test(net, testloader_in, testloader_out, test_type, T=1000):
    net.eval()

    testloaders = [testloader_in, testloader_out]
    softmax = nn.Softmax(dim=1)
    scores = torch.tensor(()).to(device)
    dist_labels = torch.tensor(()).to(device)

    with torch.no_grad():
        for i in range(2):
            testloader = testloaders[i]
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)

                if test_type == 'baseline':
                  soft_probs = softmax(outputs)
                  scores_batch, _ = soft_probs.max(1)

                if test_type == 'ODIN':
                  soft_probs = softmax(outputs / T)
                  scores_batch, _ = soft_probs.max(1)

                if test_type == 'energy':
                    scores_batch = torch.logsumexp(outputs, dim=1)

                if i == 0:
                    dist_labels_batch = torch.ones(scores_batch.size()).to(device)
                else:
                    dist_labels_batch = torch.zeros(scores_batch.size()).to(device)

                scores = torch.cat((scores, scores_batch), 0)
                dist_labels = torch.cat((dist_labels, dist_labels_batch), 0)

    return scores, dist_labels

In [None]:
def train_and_test(net, trainloader_in, trainloader_out, testloader_in, testloader_out, path, epochs=10, lr=0.1, momentum=0.875, decay=0.0005, nest=False,):
    """Train ``net`` with Outlier Exposure and report ID-vs-OOD AUPR after each epoch.

    Args:
        net: classifier to train.
        trainloader_in / trainloader_out: ID and outlier training loaders passed to ``train``.
        testloader_in / testloader_out: ID and OOD test loaders scored with the
            'baseline' (max softmax) detector after every epoch.
        path: file the model's state_dict is written to; overwritten after every epoch.
        epochs: number of training epochs.
        lr, momentum, decay, nest: SGD learning rate, momentum, weight decay and
            Nesterov flag.
    """
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=decay, nesterov=nest)
    # LR drops 10x at 50% and 75% of training. The milestones are expressed in epochs,
    # so the scheduler must be stepped once per epoch.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(epochs*0.5), int(epochs*0.75)], gamma=0.1)

    start = time.time()

    for epoch in range(epochs):
        net = train(net, optimizer, scheduler, trainloader_in, trainloader_out)
        scores, labels = test(net, testloader_in, testloader_out, 'baseline')
        # test() runs under torch.no_grad(), so no detach is needed before numpy().
        aupr, _ = results(labels.cpu().numpy(), scores.cpu().numpy())
        print('Epoch:', epoch, ', AUPR:', aupr, 'Time:', time.time() - start)

        torch.save(net.state_dict(), path)

In [None]:
# Training datasets for Outlier Exposure (OE): in-distribution = CIFAR-10 training
# images with labels 1..5 (inclusive); outliers = the CIFAR-100 training set.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset_in = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainset_out = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)

# Reshuffle both streams every epoch. train() zips the two loaders and stops at the
# shorter one (~25k ID images vs 50k OE images). An unshuffled OE loader would therefore
# only ever show the same leading half of CIFAR-100.
train_in_idxs = [i for i, t in enumerate(trainset_in.targets) if 1 <= t <= 5]
trainloader_in = torch.utils.data.DataLoader(
    trainset_in, batch_size=100, num_workers=2,
    sampler=torch.utils.data.SubsetRandomSampler(train_in_idxs))
trainloader_out = torch.utils.data.DataLoader(trainset_out, batch_size=100, shuffle=True, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data


In [None]:
class CustomSet(Dataset):
    """Map-style dataset over a pre-built, indexable collection of inputs.

    Every item is returned together with the dummy label 0, so the dataset can be
    consumed by loops that unpack ``(inputs, targets)`` batches.
    """

    def __init__(self, values):
        super().__init__()
        self.values = values

    def __len__(self):
        return len(self.values)

    def __getitem__(self, index):
        sample = self.values[index]
        return sample, 0

In [None]:
# create ID and OOD testloaders
def select_testloaders(oe=False, ood_set='CIFAR', LSUN_type='random'):
    """Build the in-distribution (CIFAR-10 test set) and OOD test DataLoaders.

    Args:
        oe: if True, restrict the ID set to CIFAR-10 labels 1..5 (inclusive). For
            ood_set='CIFAR', also use CIFAR-10 labels 6..10 as the OOD set.
        ood_set: one of 'CIFAR' (CIFAR-100 test set), 'SVHN' (random 10,000-image
            subset), 'LSUN', 'MNIST', 'uni' (uniform noise) or 'gauss' (clamped
            Gaussian noise).
        LSUN_type: 'random' (random 32x32 crop) or 'scale' (resize to 32x32);
            only used when ood_set='LSUN'.

    Returns:
        (testloader_in, testloader_out): both unshuffled, with batch_size=100.

    Raises:
        ValueError: for an unknown ood_set, or an unknown LSUN_type with ood_set='LSUN'.
    """
    ood_sets = ('CIFAR', 'SVHN', 'LSUN', 'MNIST', 'uni', 'gauss')
    if ood_set not in ood_sets:
        raise ValueError('ood_set must be one of %s, got %r' % (ood_sets, ood_set))
    if ood_set == 'LSUN' and LSUN_type not in ('random', 'scale'):
        raise ValueError("LSUN_type must be 'random' or 'scale', got %r" % (LSUN_type,))

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_CIFAR = transforms.Compose([transforms.ToTensor(), normalize])
    # MNIST digits: pad 28x28 -> 32x32 and replicate the gray channel to 3 channels.
    transform_MNIST = transforms.Compose([transforms.Pad(2), transforms.Grayscale(3), transforms.ToTensor(), normalize])
    transform_LSUN = {
        'random': transforms.Compose([transforms.RandomCrop(32), transforms.ToTensor(), normalize]),
        'scale': transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), normalize]),
    }
    lsun_root = '/content/drive/MyDrive/Duke/2022-2023/ECE 661/ECE 661 Project/LSUN_data'

    def make_loader(dataset):
        # Shared loader settings for every test set.
        return torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=False, num_workers=2)

    def label_subset(dataset, low, high):
        # Keep only samples whose label lies in [low, high] (inclusive), in index order.
        idxs = [i for i, t in enumerate(dataset.targets) if low <= t <= high]
        return torch.utils.data.Subset(dataset, idxs)

    # ID set is always CIFAR-10
    testset_in = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_CIFAR)
    testloader_in = make_loader(label_subset(testset_in, 1, 5) if oe else testset_in)

    if ood_set == 'CIFAR':
        if oe:
            # NOTE(review): OOD = held-out CIFAR-10 labels 6..10 (the complement of the
            # ID split). The label filter and the loaded dataset must be the same set.
            # Confirm this, rather than a CIFAR-100 subset, is the intended OOD set.
            testset_out = label_subset(testset_in, 6, 10)
        else:
            testset_out = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_CIFAR)
    elif ood_set == 'SVHN':
        svhn = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_CIFAR)
        testset_out = torch.utils.data.Subset(svhn, np.random.choice(len(svhn), 10000, replace=False))
    elif ood_set == 'LSUN':
        testset_out = torchvision.datasets.LSUN(root=lsun_root, classes='test', transform=transform_LSUN[LSUN_type])
    elif ood_set == 'MNIST':
        testset_out = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_MNIST)
    elif ood_set == 'uni':
        # Noise is drawn in [0, 1] pixel space (like ToTensor output), then normalized
        # exactly like the real images before reaching the network.
        testset_out = CustomSet(normalize(torch.rand((10000, 3, 32, 32))))
    else:  # 'gauss'
        values = torch.clamp(torch.normal(0.5, 1, size=(10000, 3, 32, 32)), min=0, max=1)
        testset_out = CustomSet(normalize(values))

    return testloader_in, make_loader(testset_out)

In [None]:
# instantiate a resnet, load the appropriate model checkpoint
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = ResNetCIFAR(num_layers=34).to(device)
path = '/content/drive/MyDrive/Duke/2022-2023/ECE 661/ECE 661 Project/Model Checkpoints/trained_model_34.pt'
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# still loads on a CPU-only runtime.
net.load_state_dict(torch.load(path, map_location=device))

<All keys matched successfully>

In [None]:
# Pick the OOD test set: 'CIFAR', 'SVHN', 'LSUN', 'MNIST', 'uni' or 'gauss'
# (LSUN_type only matters when ood_set='LSUN').
testloader_in, testloader_out = select_testloaders(
    oe=False,
    ood_set='gauss',
    LSUN_type='scale',
)

Files already downloaded and verified


In [None]:
# train and provide AUPR results for the OE model if needed
#train_and_test(net, trainloader_in, trainloader_out, testloader_in, testloader_out, path)

In [None]:
# Score the ID and OOD test sets with the chosen detector
logits, labels = test(net, testloader_in, testloader_out, 'baseline')

print(f'Logits: {logits.shape} {logits}')
print(f'Labels: {labels.shape} {labels}')

Logits: torch.Size([20000]) tensor([0.9974, 0.9999, 0.9009,  ..., 0.4394, 0.5017, 0.4633], device='cuda:0')
Labels: torch.Size([20000]) tensor([1., 1., 1.,  ..., 0., 0., 0.], device='cuda:0')


In [None]:
# find and display results
labels_np = labels.cpu().detach().numpy()
logits_np = logits.cpu().detach().numpy()
# Keep this title in sync with the detector passed to test() and the ood_set passed
# to select_testloaders() above (currently 'baseline' and 'gauss').
title = 'Baseline (ID: CIFAR-10, OOD: Gaussian noise)'

aupr, fpr_at_95_tpr = results(labels_np, logits_np, plot_roc=False, plot_pr=False, title=title)
print('AUPR: {:.4}'.format(aupr))
print('FPR@95TPR: {:.4}'.format(fpr_at_95_tpr))

AUPR: 0.9709
FPR@95TPR: 0.2763
