# Data & Helper Functions

Define the helper functions that are used later during training.

In [4]:
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar that mimics xlua.progress.
    - format_time: compact formatting of a duration given in seconds.
'''
import os
import sys
import time
import math
import shutil

import torch.nn as nn
import torch.nn.init as init

# Seed value. NOTE(review): not referenced by any code visible in this file —
# confirm where (or whether) it is consumed.
seed_v = 42

def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std value of dataset.

    Samples are loaded one at a time; for every sample the mean and the
    standard deviation of each channel are computed, and those per-sample
    values are averaged over the dataset. The returned std is therefore the
    average of per-image stds, not the std of all pixels pooled together.

    Args:
        dataset: map-style dataset yielding ``(image, target)`` pairs, where
            ``image`` is a ``(channels, H, W)`` tensor.
        num_workers: number of DataLoader worker processes (0 loads the data
            in the calling process).

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel.

    Raises:
        ValueError: if the dataset is empty.
    '''
    # Imported here because this cell's own imports only bind `torch.nn`,
    # not the top-level `torch` name.
    import torch
    import torch.utils.data

    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError('cannot compute mean and std of an empty dataset')

    # The order of samples does not affect an average, so no shuffling.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    mean = None
    std = None
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        num_channels = inputs.size(1)
        if mean is None:
            # Size the accumulators from the data instead of assuming RGB.
            mean = torch.zeros(num_channels)
            std = torch.zeros(num_channels)
        # (1, C, H, W) -> (C, H*W): one row of pixels per channel.
        per_channel = inputs.transpose(0, 1).reshape(num_channels, -1)
        mean += per_channel.mean(dim=1)
        std += per_channel.std(dim=1)
    mean.div_(num_samples)
    std.div_(num_samples)
    return mean, std

def init_params(net):
    '''Init layer parameters in place.

    Walks every sub-module of ``net`` and re-initializes:
        - Conv2d: Kaiming-normal weights (``mode='fan_out'``), zero bias.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: normal weights with std 1e-3, zero bias.

    Layers of any other type are left untouched.

    Args:
        net: the ``nn.Module`` whose parameters are re-initialized.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Truth-testing a multi-element tensor raises, so compare the
            # optional bias against None explicitly.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False both parameters are None.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


# shutil.get_terminal_size() returns (columns, lines); the progress bar pads
# and backspaces by the terminal *width*, i.e. the column count.
term_width = int(shutil.get_terminal_size().columns)

TOTAL_BAR_LENGTH = 65.  # length of the bar itself, in characters
last_time = time.time()  # time of the previous progress_bar() call
begin_time = last_time  # time at which the current bar was started
def progress_bar(current, total, msg=None):
    '''Draw a one-line text progress bar with timing information on stdout.

    Args:
        current: zero-based index of the step that just finished; passing 0
            restarts the total-time clock.
        total: total number of steps.
        msg: optional text appended after the timing fields.

    The line ends with a carriage return so the next call overwrites it,
    except on the last step (``current == total - 1``), which ends with a
    newline.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    filled = int(TOTAL_BAR_LENGTH*current/total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1

    out = sys.stdout
    out.write(' [' + '=' * filled + '>' + '.' * remaining + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    status = '  Step: %s' % format_time(step_time)
    status += ' | Tot: %s' % format_time(tot_time)
    if msg:
        status += ' | ' + msg
    out.write(status)

    # Pad the rest of the line with spaces (no-op if the count is negative).
    out.write(' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(status)-3))

    # Go back to the center of the bar.
    out.write('\b' * (term_width-int(TOTAL_BAR_LENGTH/2)+2))
    out.write(' %d/%d ' % (current+1, total))

    out.write('\r' if current < total-1 else '\n')
    out.flush()

def format_time(seconds):
    '''Format a duration in seconds using at most its two largest non-zero units.

    Units are days (D), hours (h), minutes (m), seconds (s) and
    milliseconds (ms). Returns '0ms' when every unit is zero.
    '''
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    whole_seconds = int(seconds)
    seconds = seconds - whole_seconds
    millis = int(seconds*1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    pieces = []
    for amount, suffix in units:
        # Only the first two non-zero units are shown.
        if amount > 0 and len(pieces) < 2:
            pieces.append(str(amount) + suffix)
    return ''.join(pieces) or '0ms'

# Model 

Here, we define our model, the **SimpleDLA** class. In our experiments, we use the DLA architecture (https://arxiv.org/abs/1707.06484) with ZigZag-aligned modifications: the first convolution takes 4 input channels instead of 3, where the extra channel carries the additional input of the second inference.

In [5]:
'''Simplified version of DLA in PyTorch.
Note this implementation is not identical to the original paper version.
But it seems to work fine.
See dla.py for the original paper version.
Reference:
    Deep Layer Aggregation. https://arxiv.org/abs/1707.06484
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    '''Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    The shortcut is the identity unless the stride or the channel count
    changes, in which case it is a 1x1 conv followed by batch-norm.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion*planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride != 1 or in_planes != out_planes:
            # Shapes differ: project the input so it can be added.
            projection = [
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)


class Root(nn.Module):
    '''Aggregation node: concatenates its inputs along the channel axis and
    applies conv -> batch-norm -> ReLU.'''

    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        # Padding that keeps the spatial size for odd kernel sizes.
        same_padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=1, padding=same_padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, xs):
        merged = torch.cat(xs, 1)
        merged = self.conv(merged)
        merged = self.bn(merged)
        return F.relu(merged)


class Tree(nn.Module):
    '''Two children applied in sequence whose outputs are merged by a Root.

    At ``level == 1`` the children are plain ``block`` instances; otherwise
    they are ``Tree`` instances one level lower. Only the left child receives
    ``stride``; the right child always uses stride 1.
    '''

    def __init__(self, block, in_channels, out_channels, level=1, stride=1):
        super().__init__()
        self.root = Root(2*out_channels, out_channels)

        if level == 1:
            make_child = block
        else:
            def make_child(c_in, c_out, stride):
                return Tree(block, c_in, c_out, level=level-1, stride=stride)

        self.left_tree = make_child(in_channels, out_channels, stride=stride)
        self.right_tree = make_child(out_channels, out_channels, stride=1)

    def forward(self, x):
        left = self.left_tree(x)
        right = self.right_tree(left)
        return self.root([left, right])


class SimpleDLA(nn.Module):
    '''Simplified DLA classifier whose first convolution takes 4 input channels.

    Three conv -> batch-norm -> ReLU stems are followed by four ``Tree``
    stages (the last three with stride 2), a 4x4 average pooling and a linear
    classifier over 512 features.
    '''

    def __init__(self, block=BasicBlock, num_classes=10):
        super().__init__()
        self.base = self._conv_bn_relu(4, 16)
        self.layer1 = self._conv_bn_relu(16, 16)
        self.layer2 = self._conv_bn_relu(16, 32)

        self.layer3 = Tree(block,  32,  64, level=1, stride=1)
        self.layer4 = Tree(block,  64, 128, level=2, stride=2)
        self.layer5 = Tree(block, 128, 256, level=2, stride=2)
        self.layer6 = Tree(block, 256, 512, level=1, stride=2)
        self.linear = nn.Linear(512, num_classes)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        '''Build a size-preserving 3x3 conv -> batch-norm -> in-place ReLU stem.'''
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward(self, x, y=None):
        # `y` is accepted but not used by the forward pass.
        features = x
        stages = (self.base, self.layer1, self.layer2, self.layer3,
                  self.layer4, self.layer5, self.layer6)
        for stage in stages:
            features = stage(features)
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

# Initialize Training

Here, we define the optimization parameters and the functions for **train** and **test**. As described in the paper, the first inference of our method is performed with a "blank" additional input, while the second inference incorporates the class labels.

In [6]:
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import tqdm

    
# Creating our model
# Use the GPU only when one is actually available; a hard-coded "cuda" makes
# net.to(device) fail on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('==> Building model..')
net = SimpleDLA()
net = net.to(device)
if device == 'cuda':
    # Replicate the model across the visible GPUs and let cuDNN auto-tune
    # its convolution algorithms.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

# Class that stores hyper-parameters for the optimization
class Args:
    '''Hyper-parameters and checkpoint settings for training.

    Args:
        lr: initial learning rate passed to the optimizer.
        resume: whether to restore the model from the saved checkpoint.
        checkpoint: checkpoint file name (without the ``.pth`` extension).
        batch_size: number of samples per batch in the data loaders.
    '''

    def __init__(self, lr=0.001, resume=False, checkpoint="cifar_zigzag",
                 batch_size=56):
        self.lr = lr
        self.resume = resume
        self.checkpoint = checkpoint
        self.batch_size = batch_size

    def __repr__(self):
        return (f"Args(lr={self.lr!r}, resume={self.resume!r}, "
                f"checkpoint={self.checkpoint!r}, batch_size={self.batch_size!r})")

args = Args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Do not request more loader processes than the machine has CPU cores.
num_workers = min(64, os.cpu_count() or 1)

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=args.batch_size, shuffle=False, num_workers=num_workers)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    checkpoint_path = f'./checkpoint/{args.checkpoint}.pth'
    # Raise explicitly: an `assert` is stripped when Python runs with -O.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f'Error: checkpoint file {checkpoint_path} not found!')
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only
    # machine as well.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    # NOTE(review): this restarts at the epoch stored in the checkpoint, so
    # that epoch is trained again — confirm this is intended.
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

C = 0 # this parameter defines "blank" value for the input
S = 10 # this parameter defines scale of the additional input

# Training
def train(epoch):
    '''Train the network for one epoch with the two-inference input scheme.

    Every batch is fed to the network twice, stacked along the batch axis:
    once with a "blank" extra channel filled with ``C / S`` and once with an
    extra channel filled with ``(label + 1) / S``. The loss is the
    cross-entropy over both halves against the same targets.

    Args:
        epoch: epoch index, used only for logging.
    '''
    # Imported locally because a later cell in this file rebinds the global
    # name `tqdm` from the module to the class (`from tqdm import tqdm`).
    from tqdm import tqdm as progress

    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    # Wrapping the loader gives the bar the exact number of batches.
    t = progress(trainloader, desc='Current Loss = ', leave=True)

    for batch_idx, (inputs, targets) in enumerate(t):

        inputs, targets = inputs.to(device), targets.to(device)
        batch_size, _, H, W = inputs.shape

        # creating "blank" inputs for the first inference
        # in this case, it's a 4d tensor filled with constant C
        blank = inputs.new_full((batch_size, 1, H, W), C / S)
        x1 = torch.cat([inputs, blank], dim=1)

        # creating additional input for the second inference
        # in this case, the channel is filled with the ground truth class
        # shifted by one
        label_plane = (targets.reshape(-1, 1, 1, 1) + 1).to(inputs.dtype)
        label_plane = label_plane.expand(-1, -1, H, W) / S
        x2 = torch.cat([inputs, label_plane], dim=1)

        inputs = torch.cat([x1, x2], dim=0)
        targets = torch.cat([targets, targets], dim=0)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Show the running mean loss and the running accuracy under their
        # own labels (the accuracy used to be printed as "Current Loss").
        avg_loss = train_loss / (batch_idx + 1)
        accuracy = 100.*correct/total
        t.set_description(
            f"Epoch {epoch} Loss = {avg_loss:.3f} | Acc = {accuracy:.3f}%",
            refresh=True)

def test(epoch):
    '''Evaluate on the test loader and checkpoint the model when it improves.

    Only the first inference is used: every input gets a "blank" extra
    channel filled with ``C / S``. If the accuracy exceeds ``best_acc``, the
    model state, accuracy and epoch are saved to
    ``./checkpoint/<args.checkpoint>.pth`` and ``best_acc`` is updated.

    Args:
        epoch: epoch index stored in the checkpoint.
    '''
    global best_acc
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:

            inputs, targets = inputs.to(device), targets.to(device)
            batch_size, _, H, W = inputs.shape

            # creating "blank" inputs for the first inference
            # in this case, it's a 4d tensor filled with constant C
            blank = inputs.new_full((batch_size, 1, H, W), C / S)
            inputs = torch.cat([inputs, blank], dim=1)

            outputs = net(inputs)

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100.*correct/total

    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # Creates the directory if needed; no separate existence check.
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, f'./checkpoint/{args.checkpoint}.pth')
        best_acc = acc

    print("BEST ACCURACY: ", best_acc, "Current Accuracy: ", acc)


==> Building model..
==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


# Training 

In [None]:
import tqdm
training_epochs = 40

# Phase 1: train with the initial learning-rate schedule.
for epoch in range(start_epoch, start_epoch + training_epochs):
    train(epoch)
    test(epoch)
    scheduler.step()

# Phase 2: continue training at 1% of the current learning rate with a
# fresh cosine schedule.
for param_group in optimizer.param_groups:
    param_group['lr'] *= 0.01

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Continue the epoch numbering so that the logs and the 'epoch' field stored
# in checkpoints do not repeat the indices of phase 1.
phase2_start = start_epoch + training_epochs
for epoch in range(phase2_start, phase2_start + training_epochs):
    train(epoch)
    test(epoch)
    scheduler.step()

# OOD Evaluation

**CIFAR vs. SVHN OOD Evaluation:** In this subsection, we evaluate out-of-distribution (OOD) detection by testing our model trained on CIFAR10 against SVHN images, which contain classes not present in CIFAR10. We perform inference on both datasets and generate uncertainty estimates for each sample. We then compute standard ROC-AUC and PR-AUC metrics to assess in-distribution vs. out-of-distribution classification performance.

In [None]:
# Creating the model
print('==> Building model..')
# The model is wrapped in DataParallel unconditionally here
# (presumably so parameter names match the checkpoint loaded below — confirm).
net = torch.nn.DataParallel(SimpleDLA().to(device))

### Download pretrained weights

In [None]:
!gdown https://drive.google.com/uc?id=1qSCCl6hMMX65AVXms5kS7SYIA6e0fPkv

In [42]:
# Pretrained weights; replace the path with your own weights file if needed.
# The file is expected to hold the model state under the "net" key.
net.load_state_dict(
    torch.load("./zigzag_pretrained_cifar.pth", map_location='cpu')["net"]
)

<All keys matched successfully>

### Loading In- and Out-of-distribution Dataset

In [43]:
# Same normalization for both datasets.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Loading SVHN data as OOD
# NOTE: despite the "train" prefix in the names, this is the SVHN *test* split.
trainset_svhn = torchvision.datasets.SVHN(
    root='./data', split="test", download=True, transform=transform_test)
# No shuffling for evaluation: the order of the collected scores stays
# reproducible, and the AUC metrics do not depend on sample order.
trainloader_svhn = torch.utils.data.DataLoader(
    trainset_svhn, batch_size=32, shuffle=False, num_workers=32)

# Loading CIFAR test split as in-distribution
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=32, shuffle=False, num_workers=32)

Using downloaded and verified file: ./data/test_32x32.mat
Files already downloaded and verified


In [44]:
# Some initializations before running inference

import numpy as np
from tqdm import tqdm 

# Empty float accumulators that the inference loops extend batch by batch.
entropies = np.empty(0)  # per-sample uncertainty scores
ys = np.empty(0)  # 0 for in-distribution samples, 1 for OOD samples

# Switch the network to evaluation mode before inference.
net.eval()

rc, counter = 0, 0

C = 0 # this parameter defines "blank" value for the input
S = 10 # this parameter defines scale of the additional input

### Predict Uncertainties for In-distribution

In [45]:
# Inference only: disable autograd so no graph is built for the forward passes.
with torch.no_grad():
    for batch in tqdm(testloader, total=len(testloader)):

        X, y = batch

        batch_size, _, H, W = X.shape

        # creating "blank" inputs for the first inference
        # in this case, it's a 4d tensor filled with constant C
        targets_1 = C * torch.ones(batch_size, 1, H, W).to(X.device)
        X1 = torch.cat([X, targets_1 / S], dim=1)

        predictions = torch.nn.functional.softmax(net(X1), dim=1)
        pred_labels = predictions.argmax(dim=1).cpu()

        # creating additional input for the second inference
        # in this case, the channel is filled with the predicted class
        # shifted by one
        targets_2 = torch.ones(batch_size, 1, H, W).to(X.device)
        targets_2 = targets_2 * pred_labels.reshape(-1, 1, 1, 1) + 1
        X2 = torch.cat([X, targets_2 / S], dim=1)

        predictions_2 = torch.nn.functional.softmax(net(X2), dim=1)

        # Uncertainty score: largest per-class drop in probability between
        # the first and the second inference.
        unc = (predictions - predictions_2).max(dim=1)[0].detach().cpu().numpy()
        # In-distribution samples are labelled 0.
        y = np.zeros_like(y.cpu().numpy())

        entropies = np.concatenate([entropies, unc])
        ys = np.concatenate([ys, y])

100%|██████████| 313/313 [00:09<00:00, 33.63it/s]


### Predict Uncertainties for Out-of-distribution

In [46]:
# Inference only: disable autograd so no graph is built for the forward passes.
with torch.no_grad():
    for batch in tqdm(trainloader_svhn, total=len(trainloader_svhn)):

        X, y = batch

        batch_size, _, H, W = X.shape

        # creating "blank" inputs for the first inference
        # in this case, it's a 4d tensor filled with constant C
        targets_1 = C * torch.ones(batch_size, 1, H, W).to(X.device)
        X1 = torch.cat([X, targets_1 / S], dim=1)

        predictions = torch.nn.functional.softmax(net(X1), dim=1)
        pred_labels = predictions.argmax(dim=1).cpu()

        # creating additional input for the second inference
        # in this case, the channel is filled with the predicted class
        # shifted by one
        targets_2 = torch.ones(batch_size, 1, H, W).to(X.device)
        targets_2 = targets_2 * pred_labels.reshape(-1, 1, 1, 1) + 1
        X2 = torch.cat([X, targets_2 / S], dim=1)

        predictions_2 = torch.nn.functional.softmax(net(X2), dim=1)

        # Uncertainty score: largest per-class drop in probability between
        # the first and the second inference.
        unc = (predictions - predictions_2).max(dim=1)[0].detach().cpu().numpy()
        # Out-of-distribution samples are labelled 1.
        y = np.ones_like(y.cpu().numpy())

        entropies = np.concatenate([entropies, unc])
        ys = np.concatenate([ys, y])

100%|██████████| 814/814 [00:17<00:00, 45.71it/s]


### Computing AUC metrics

In [47]:
import sklearn.metrics
# `ys` is 0 for the CIFAR10 samples and 1 for the SVHN samples, so SVHN (OOD)
# is the positive class; `entropies` holds the per-sample uncertainty scores
# used to rank the samples.
roc_auc = sklearn.metrics.roc_auc_score(ys, entropies)

# PR-AUC: area under the precision-recall curve built from the same labels
# and scores. `thresholds` is not used by the code visible in this file.
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(ys, entropies)
pr_auc = sklearn.metrics.auc(recall, precision)

In [48]:
# Report both OOD detection metrics rounded to three decimals.
print(f"OOD ROC-AUC: {round(roc_auc, 3)}, PR-AUC: {round(pr_auc, 3)}")

OOD ROC-AUC: 0.901, PR-AUC: 0.933
