This notebook implements the Madras et al. (2018) learning-to-defer baseline on CIFAR-10. Running it end to end produces results for a single synthetic expert, parameterized by `k`.

# Data

In [None]:
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import shutil
import time
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Seed every RNG the notebook draws from (Python's `random` for the synthetic expert,
# NumPy, and torch on CPU/GPU) so runs are repeatable. `cudnn.benchmark = True`
# (set in run_madras) can still introduce some GPU nondeterminism.
import random
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
class BasicBlock(nn.Module):
    """Pre-activation residual block: BN-ReLU-conv3x3, BN-ReLU-(dropout)-conv3x3, plus shortcut.

    When ``in_planes == out_planes`` the shortcut is the raw input. Otherwise a
    strided 1x1 convolution (``convShortcut``) projects the *pre-activated* input
    so that it matches the residual branch's channels and resolution.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        # Sub-modules are created in the same order as before, so parameter order
        # and seeded initialisation are unchanged.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = in_planes == out_planes
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=False)

    def forward(self, x):
        pre_act = self.relu1(self.bn1(x))
        hidden = self.relu2(self.bn2(self.conv1(pre_act)))
        if self.droprate > 0:
            hidden = F.dropout(hidden, p=self.droprate, training=self.training)
        residual = self.conv2(hidden)
        if self.equalInOut:
            identity = x
        else:
            # Projection shortcut is taken from the pre-activated input.
            identity = self.convShortcut(pre_act)
        return torch.add(identity, residual)


class NetworkBlock(nn.Module):
    """One stage of ``nb_layers`` residual blocks wrapped in an ``nn.Sequential``.

    Only the first block maps ``in_planes -> out_planes`` and applies ``stride``;
    the remaining blocks keep ``out_planes`` channels with stride 1.
    ``nb_layers`` may be a float; it is truncated with ``int``.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        stage = [
            block(in_planes if idx == 0 else out_planes,
                  out_planes,
                  stride if idx == 0 else 1,
                  dropRate)
            for idx in range(int(nb_layers))
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    """Wide ResNet classifier that outputs class probabilities (softmax over dim 1).

    Args:
        depth: total depth; must satisfy ``(depth - 4) % 6 == 0``, giving
            ``(depth - 4) / 6`` residual blocks per stage.
        num_classes: size of the output distribution.
        widen_factor: channel multiplier; stages use 16/32/64 * widen_factor channels.
        dropRate: dropout probability inside each residual block.

    Stages 2 and 3 downsample by 2 each and the head applies an 8x8 average pool,
    so inputs are expected to be 32x32. Other sizes now raise in ``fc`` instead of
    being silently folded into the batch dimension.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]
        # Explicit dim: the implicit-dim form is deprecated (it resolved to 1 for 2-D input).
        self.softmax = nn.Softmax(dim=1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        # flatten(1) keeps the batch dimension intact; view(-1, nChannels) would
        # silently turn leftover spatial positions into extra "samples".
        out = out.flatten(1)
        out = self.fc(out)
        return self.softmax(out)


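We modify the WideResNet architecture to implement the baseline: besides the classifier, the model contains a second, independently parameterized WideResNet branch (the rejector) whose two-way softmax decides whether to predict or to defer to the expert.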

In [None]:
class WideResNet_madras(nn.Module):
    """Classifier + rejector pair for the Madras et al. (2018) deferral baseline.

    Two independent Wide ResNet trunks share the same input:
      * the classifier branch (``conv1``/``block*``/``bn1``/``fc``) outputs
        ``num_classes`` probabilities;
      * the rejector branch (``*_rej`` modules) outputs 2 probabilities.
    ``forward`` returns ``[class_probs, rej_probs]``, each softmax-normalised over dim 1.

    Args mirror ``WideResNet``: ``(depth - 4) % 6 == 0`` is required, and the 8x8
    average pool expects 32x32 inputs.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet_madras, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]
        # Explicit dim: the implicit-dim form is deprecated (it resolved to 1 for 2-D input).
        self.softmax = nn.Softmax(dim=1)

        # Rejector branch: same architecture, separate parameters.
        # 1st conv before any network block
        self.conv1_rej = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                                   padding=1, bias=False)
        # 1st block
        self.block1_rej = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2_rej = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3_rej = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and 2-way head
        self.bn1_rej = nn.BatchNorm2d(nChannels[3])
        self.relu_rej = nn.ReLU(inplace=True)
        self.fc_rej = nn.Linear(nChannels[3], 2)
        self.softmax_rej = nn.Softmax(dim=1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    @staticmethod
    def _trunk_features(x, conv, block1, block2, block3, bn, relu):
        """Run one branch's conv trunk and return pooled (batch, channels) features."""
        feats = conv(x)
        feats = block1(feats)
        feats = block2(feats)
        feats = block3(feats)
        feats = relu(bn(feats))
        feats = F.avg_pool2d(feats, 8)
        # flatten(1) keeps the batch dimension; view(-1, C) would silently fold
        # leftover spatial positions into the batch for non-32x32 inputs.
        return feats.flatten(1)

    def forward(self, x):
        out = self._trunk_features(x, self.conv1, self.block1, self.block2,
                                   self.block3, self.bn1, self.relu)
        out = self.softmax(self.fc(out))

        rej = self._trunk_features(x, self.conv1_rej, self.block1_rej, self.block2_rej,
                                   self.block3_rej, self.bn1_rej, self.relu_rej)
        rej = self.softmax_rej(self.fc_rej(rej))

        return [out, rej]


In [None]:
class AverageMeter(object):
    """Track the most recent value and the count-weighted running mean of all values."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Everything starts at integer zero.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: (batch, n_classes) scores; larger means more likely.
        target: (batch,) integer class labels.
        topk: iterable of k values.

    Returns:
        List of 0-dim tensors, one per k: 100 * fraction of rows whose true label
        is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout, so `.view(-1)`
        # fails for k > 1; `.reshape` copies only when necessary.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

import random

def metrics_madras(net, expert_fn, n_classes, loader):
    """Evaluate the classifier/rejector system on ``loader``, print and return a summary.

    The rejector's second column (``rej[:, 1]``) is the weight ``madras_loss`` puts on
    the expert's loss, i.e. the deferral probability; an example is deferred to the
    expert when it is >= 0.5 and handled by the classifier otherwise.

    Args:
        net: model returning ``(class_probs, rej)``; it is switched to eval mode for
            the pass (so BatchNorm uses running statistics) and restored afterwards.
        expert_fn: callable ``(images, labels) -> sequence of predicted labels``.
        n_classes: unused; kept for call compatibility.
        loader: iterable of ``(images, labels)`` batches.

    Returns:
        The printed dict: coverage ("<kept> out of <total>") and system / expert /
        classifier / alone-classifier accuracies in percent (0.0 when a subset is empty).
    """
    n_total = n_kept = n_deferred = 0
    clf_correct = exp_correct = alone_correct = 0
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                outputs, rej = net(images)
                _, predicted = torch.max(outputs, 1)
                exp_prediction = torch.as_tensor(expert_fn(images, labels),
                                                 device=labels.device)

                defer = rej[:, 1] >= 0.5
                keep = ~defer
                clf_hit = predicted == labels
                exp_hit = exp_prediction == labels

                n_total += labels.size(0)
                n_kept += int(keep.sum())
                n_deferred += int(defer.sum())
                clf_correct += int((clf_hit & keep).sum())
                exp_correct += int((exp_hit & defer).sum())
                alone_correct += int(clf_hit.sum())
    finally:
        net.train(was_training)

    def pct(num, den):
        return 100 * num / den if den else 0.0

    to_print = {"coverage": f"{n_kept} out of {n_total}",
                "system accuracy": pct(clf_correct + exp_correct, n_total),
                "expert accuracy": pct(exp_correct, n_deferred),
                "classifier accuracy": pct(clf_correct, n_kept),
                "alone classifier": pct(alone_correct, n_total)}
    print(to_print)
    return to_print


# Expert and data loaders

In [None]:
# Per-channel normalisation constants, given on a 0-255 scale and divided by 255.
normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                 std=[x/255.0 for x in [63.0, 62.1, 66.7]])

# No augmentation: train and test use identical preprocessing.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

n_dataset = 10  # number of classes
dataset = 'cifar10'
SPLIT_SEED = 0  # makes the 90/10 train/validation split reproducible
# Pinned host memory only helps when copying to a GPU.
kwargs = {'num_workers': 1, 'pin_memory': torch.cuda.is_available()}

train_dataset_all = datasets.__dict__[dataset.upper()]('../data', train=True, download=True,
                                                      transform=transform_train)

train_size = int(0.90 * len(train_dataset_all))
train_size_val = len(train_dataset_all) - train_size

train_dataset, train_dataset_val = torch.utils.data.random_split(
    train_dataset_all, [train_size, train_size_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128, shuffle=True, **kwargs)

# Held-out 10% of the training set; evaluation does not need shuffling.
train_loader_val = torch.utils.data.DataLoader(
    train_dataset_val,
    batch_size=128, shuffle=False, **kwargs)

# NOTE: despite its name, `val_loader` iterates over the CIFAR-10 *test* split (train=False).
val_loader = torch.utils.data.DataLoader(
    datasets.__dict__[dataset.upper()]('../data', train=False, download=True,
                                       transform=transform_test),
    batch_size=128, shuffle=False, **kwargs)


class synth_expert:
    """Synthetic expert whose competence depends on the true class.

    For a true label <= ``k`` it returns that label. For any other label it returns
    a class drawn uniformly from ``0 .. n_classes - 1`` via ``random.randint``.
    NOTE(review): the ``<= k`` test makes the expert exact on k + 1 classes (0..k);
    confirm this matches the intended definition of k.
    """

    def __init__(self, k, n_classes):
        self.k = k
        self.n_classes = n_classes

    def predict(self, input, labels):
        """Return a Python list with one predicted label per entry of ``labels``.

        ``input`` is ignored; predictions depend only on the labels.
        """
        predictions = []
        for idx in range(labels.size()[0]):
            true_label = labels[idx].item()
            if true_label > self.k:
                predictions.append(random.randint(0, self.n_classes - 1))
            else:
                predictions.append(true_label)
        return predictions





In [None]:
# Synthetic expert: exact on true labels 0..k (inclusive), uniformly random otherwise.
k = 5  # expert parameter
expert = synth_expert(k=k, n_classes=n_dataset)

# Baseline: Madras et al. 2018

In [None]:
def madras_loss(outputs, rej, labels, expert, eps=10e-12):
    """Mixture-of-experts loss of Madras et al. (2018), measured in bits.

    Args:
        outputs: (batch, n_classes) classifier probabilities.
        rej: (batch, 2) rejector probabilities; column 0 weights the classifier's
            loss and column 1 weights the expert's loss.
        labels: (batch,) true class indices.
        expert: (batch, n_classes) expert probabilities.
        eps: added inside every log to avoid log(0).

    Returns:
        Batch mean of ``rej0 * clf_nll.detach() + rej1 * exp_nll + clf_nll``. The
        rejector-weighted classifier term is detached, so the classifier only
        receives gradient from the trailing unweighted ``clf_nll`` term.
    """
    n = outputs.size()[0]
    rows = range(n)
    clf_nll_fixed = -torch.log2(outputs.detach()[rows, labels] + eps)
    clf_nll = -torch.log2(outputs[rows, labels] + eps)
    exp_nll = -torch.log2(expert[rows, labels] + eps)
    mixture = rej[rows, 0] * clf_nll_fixed + rej[rows, 1] * exp_nll
    mixture += clf_nll
    return torch.sum(mixture) / n



def cross_entropy_loss(outputs, labels):
    """Mean base-2 negative log-likelihood of ``labels`` under probability rows ``outputs``.

    No epsilon is added, so a zero probability on a true label yields ``inf``.
    """
    n = outputs.size()[0]
    picked = outputs[range(n), labels]
    return torch.sum(-torch.log2(picked)) / n


def _expert_targets(expert_pred, target, n_classes):
    """Encode expert predictions as per-example probability rows for ``madras_loss``.

    A row whose expert prediction equals the true label holds ``1 - 10e-12`` at that
    label and 0.005555 elsewhere; any other row is uniform ``1 / n_classes``. Each row
    is an independent tensor row (the old nested-list construction aliased rows).
    """
    rows = torch.full((target.size(0), n_classes), 0.005555, device=target.device)
    expert_pred = torch.as_tensor(expert_pred, device=target.device)
    hit = expert_pred == target
    rows[~hit] = 1.0 / n_classes
    hit_idx = torch.nonzero(hit, as_tuple=True)[0]
    rows[hit_idx, target[hit_idx]] = 1 - 10e-12
    return rows


def train_madras(train_loader, model, optimizer, scheduler, epoch, expert_fn, n_classes):
    """Train for one epoch on the training set with the Madras et al. loss.

    Args:
        train_loader: iterable of ``(input, target)`` batches.
        model: network returning ``(class_probs, rej)``.
        optimizer, scheduler: stepped once per batch.
        epoch: epoch index, used only for logging.
        expert_fn: callable ``(input, target) -> list of expert labels``.
        n_classes: number of classes (width of the expert probability rows).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        target = target.to(device)
        input = input.to(device)

        # compute output (madras_loss detaches the rejector-weighted classifier term itself)
        output, rej = model(input)
        # expert predictions, encoded as probability rows
        expert_pred = expert_fn(input, target)
        expert_pred_n = _expert_targets(expert_pred, target, n_classes)

        loss = madras_loss(output, rej, target, expert_pred_n)

        optimizer.zero_grad()
        loss.backward()

        # measure accuracy and record loss
        prec1 = accuracy(output.data, target, topk=(1,))[0]
        losses.update(loss.data.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # SGD step; the scheduler is stepped per batch
        optimizer.step()
        scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      loss=losses, top1=top1))


def validate_madras(val_loader, model, epoch, expert_fn, n_classes):
    """Perform validation on ``val_loader`` with the Madras et al. loss.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: network returning ``(class_probs, rej)``; put in eval mode.
        epoch: unused; kept for call compatibility.
        expert_fn: callable ``(input, target) -> list of expert labels``.
        n_classes: number of classes (width of the expert probability rows).

    Returns:
        Average top-1 precision (percent) of the classifier head over the loader.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.to(device)
            input = input.to(device)

            # compute output
            output, rej = model(input)

            # Expert probabilities for madras_loss, which only reads the true-label
            # column: 1 - 10e-12 where the expert is right, 1 / n_classes otherwise.
            # (Previously this referenced an undefined `model_expert`.)
            expert_pred = torch.as_tensor(expert_fn(input, target), device=target.device)
            expert_probs = torch.full((target.size(0), n_classes), 1.0 / n_classes,
                                      device=target.device)
            hits = torch.nonzero(expert_pred == target, as_tuple=True)[0]
            expert_probs[hits, target[hits]] = 1 - 10e-12
            loss = madras_loss(output, rej, target, expert_probs)

            # measure accuracy and record loss
            prec1 = accuracy(output, target, topk=(1,))[0]
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))

    return top1.avg
best_prec1 = 0
def run_madras(model, data_aug, n_dataset, expert_fn, epochs, loader=None, val_loader=None):
    """Train ``model`` with the Madras et al. loss for ``epochs`` epochs.

    Args:
        model: network returning ``(class_probs, rej)``; moved to ``device``.
        data_aug: unused; kept for call compatibility.
        n_dataset: number of classes.
        expert_fn: callable ``(images, labels) -> list of expert labels``.
        epochs: number of training epochs.
        loader: training loader; defaults to the notebook-level ``train_loader``.
        val_loader: optional; when given, ``validate_madras`` runs after every epoch
            and the best top-1 precision is kept in the global ``best_prec1``.
    """
    global best_prec1
    if loader is None:
        loader = train_loader

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # For multi-GPU training: model = torch.nn.DataParallel(model).cuda()
    model = model.to(device)

    cudnn.benchmark = True

    optimizer = torch.optim.SGD(model.parameters(), 0.1,
                                momentum=0.9, nesterov=True,
                                weight_decay=5e-4)

    # Cosine schedule stepped once per batch. T_max must cover the whole run: a
    # shorter hard-coded horizon makes the learning rate climb back up at the end.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, max(1, len(loader) * epochs))

    for epoch in range(epochs):
        # train for one epoch
        train_madras(loader, model, optimizer, scheduler, epoch, expert_fn, n_dataset)
        if val_loader is not None:
            prec1 = validate_madras(val_loader, model, epoch, expert_fn, n_dataset)
            best_prec1 = max(prec1, best_prec1)

    if val_loader is not None:
        print('Best accuracy: ', best_prec1)


In [None]:
model = WideResNet_madras(depth=10, num_classes=n_dataset, widen_factor=4, dropRate=0)


In [None]:
run_madras(model, data_aug=True, n_dataset=n_dataset, expert_fn=expert.predict, epochs=220)


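## Evaluate
- coverage: number of test examples on which the classifier predicts (i.e. does not defer), out of the total
- system accuracy: accuracy of the combined system (classifier on non-deferred examples, expert on deferred ones)
- classifier accuracy: accuracy of the classifier on non-deferred examples
- expert accuracy: accuracy of the expert on deferred examples
- alone classifier: accuracy of the classifier on all examples, ignoring deferral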

In [None]:
metrics_madras(model, expert_fn=expert.predict, n_classes=n_dataset, loader=val_loader);


The functions below implement Gumbel-softmax sampling, which approximates discrete (categorical or binary) choices with differentiable samples. They are not used by the code above.

In [None]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

def sample_gumbel(shape, eps=1e-20):
    """Draw standard Gumbel noise of the given shape and move it to ``device``.

    Inverse-CDF transform ``-log(-log(U))`` with ``U ~ Uniform[0, 1)`` drawn on the
    CPU generator; ``eps`` guards both logarithms against ``log(0)``.
    """
    uniform = torch.rand(shape).to(device)
    inner = -torch.log(uniform + eps) + eps
    return -torch.log(inner)

def gumbel_softmax_sample(logits, temperature):
    """Relaxed categorical sample: softmax((logits + Gumbel noise) / temperature) over the last dim."""
    perturbed = logits + sample_gumbel(logits.size())
    scaled = perturbed / temperature
    return F.softmax(scaled, dim=-1)

def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-softmax (Concrete) distribution.

    input: [*, n_class] unnormalized log-probabilities
    return: [*, n_class]; the relaxed sample when ``hard`` is False. When ``hard`` is
        True, a one-hot vector (argmax of the relaxed sample) in the forward pass whose
        gradient is that of the relaxed sample (straight-through estimator).
    """
    y = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return y
    # One-hot at the argmax along the class dimension, same shape as y (the previous
    # reshape referenced undefined `latent_dim` / `categorical_dim`).
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).scatter_(-1, ind.unsqueeze(-1), 1.0)
    # Straight-through: forward value is y_hard, gradients flow through y.
    return (y_hard - y).detach() + y



#print(gumbel_softmax(Variable(torch.tensor([[math.log(0.2), math.log(0.8)]] * 20000)),     0.0001).sum(dim=0))
#gumbel_softmax(Variable(torch.tensor([math.log(0.5),math.log(0.5)])),0.8)


def gumbel_binary_sample(logits, t=0.5, eps=1e-20):
    """Draw a sample from the binary Gumbel-Softmax (relaxed Bernoulli) distribution.

    Despite its name, ``logits`` is treated as a probability ``p`` in [0, 1]: the
    "on" and "off" scores are ``log p`` and ``log(1 - p)``, each perturbed by
    independent Gumbel noise and divided by the temperature ``t``.

    Returns:
        Tensor shaped like ``logits`` with values in [0, 1]: the relaxed probability
        of the "on" state. With zero noise and ``t == 1`` it equals ``p``.
    """
    gumbel_noise_on = sample_gumbel(logits.size())
    gumbel_noise_off = sample_gumbel(logits.size())
    # Natural log to match the Gumbel noise (built from torch.log); log2 rescaled the
    # log-probabilities relative to the noise and biased samples away from p.
    concrete_on = (torch.log(logits + eps) + gumbel_noise_on) / t
    concrete_off = (torch.log(1 - logits + eps) + gumbel_noise_off) / t
    # exp(a) / (exp(a) + exp(b)) == sigmoid(a - b), without exp overflow/underflow
    # (the ratio form produced 0/0 = NaN at low temperatures).
    return torch.sigmoid(concrete_on - concrete_off)
