This notebook compares the oracle performance of deep ensembles with that of sets of "expert" networks obtained by training mixture-of-experts (MoE) models end-to-end. The oracle performance (both loss and accuracy) is obtained by considering, for each sample, only the subnetwork with the lowest loss on that sample.

In [66]:
import numpy as np 
import torch
import torchvision
import matplotlib.pyplot as plt

import datasets.mnist as mnist
import datasets.cifar10 as cifar10
import constants
from configuration import Configuration
from metrics import basic_cross_entropy

from util import *

In [65]:
from tqdm.notebook import trange, tqdm

## Setup

In [94]:
def oracle_loss_and_correctness(preds, gt, criterion):
    """Score a batch as if an oracle picked the best subnetwork for every sample.

    For each sample the subnetwork with the smallest per-sample loss is
    selected; only that subnetwork's loss and prediction are counted.

    Args:
        preds: sequence of per-subnetwork output tensors. They are stacked
            along a new dimension 1, so they must all share one shape
            (assumed to be (batch, classes) -- TODO confirm against the trainer).
        gt: ground-truth labels, compared element-wise with the arg-max
            (over dimension 1) of the selected outputs.
        criterion: callable invoked as ``criterion(pred, gt, reduction='none')``;
            expected to return one loss value per sample.

    Returns:
        Tuple ``(oracle_loss, correct)``: the mean of the per-sample minimum
        losses (a tensor, not a Python float) and the number of samples whose
        oracle prediction equals ``gt`` (a Python int).
    """
    # One unreduced loss vector per subnetwork, laid out as columns.
    per_net_losses = [criterion(net_out, gt, reduction='none') for net_out in preds]
    loss_matrix = torch.stack(per_net_losses, dim=1)

    # For every sample: the smallest loss and which subnetwork achieved it.
    best = loss_matrix.min(dim=-1)
    best_loss = best.values
    best_net = best.indices

    # Pick, for every sample, the output of its best subnetwork.
    outputs = torch.stack(preds, dim=1)
    sample_idx = torch.arange(outputs.shape[0])
    chosen = outputs[sample_idx, best_net]

    _, labels = torch.max(chosen, 1)
    n_correct = (labels == gt).sum().item()

    return torch.mean(best_loss), n_correct

In [76]:
def oracle_test(loader, trainer, main_criterion=basic_cross_entropy, individual_criterion=torch.nn.functional.cross_entropy):
    """Evaluate a trainer on a data loader and report regular and oracle metrics.

    The model is switched to eval mode and all batches are processed without
    gradient tracking. For each batch ``trainer.predict_test`` supplies the
    combined prediction and the list of per-subnetwork predictions; the former
    feeds the regular metrics, the latter the oracle metrics.

    Args:
        loader: iterable yielding ``(x, y)`` batches; ``x.shape[0]`` is used as
            the batch size.
        trainer: object exposing ``model.eval()`` and ``predict_test(x)``
            returning ``(combined_prediction, per_subnetwork_predictions)``.
        main_criterion: loss applied to the combined prediction
            (assumed to return a batch-mean scalar tensor -- TODO confirm).
        individual_criterion: per-sample loss handed to
            ``oracle_loss_and_correctness`` for the subnetwork predictions.

    Returns:
        Tuple ``(accuracy, oracle_accuracy, loss, oracle_loss)`` averaged over
        all samples seen. The same four values are printed.

    NOTE(review): ``oracle_loss`` is accumulated from the tensor returned by
    ``oracle_loss_and_correctness``, so it comes back as a 0-dim tensor while
    the other three values are Python numbers.
    """
    trainer.model.eval()

    n_seen = 0
    loss_sum = 0
    hits = 0
    oracle_hits = 0
    oracle_loss_sum = 0

    with torch.no_grad(), tqdm(loader, unit="batch") as batches:
        for inputs, targets in batches:
            batch_size = inputs.shape[0]
            n_seen += batch_size

            combined, members = trainer.predict_test(inputs)

            # Weight each batch value by its size so the final division by
            # n_seen yields a per-sample average even if the last batch is short.
            loss_sum += main_criterion(combined, targets).item() * batch_size
            _, labels = torch.max(combined, 1)
            hits += (labels == targets).sum().item()

            batch_oracle_loss, batch_oracle_hits = oracle_loss_and_correctness(
                members, targets, individual_criterion
            )
            oracle_loss_sum += batch_oracle_loss * batch_size
            oracle_hits += batch_oracle_hits

    accuracy = hits / n_seen
    oracle_accuracy = oracle_hits / n_seen
    mean_loss = loss_sum / n_seen
    mean_oracle_loss = oracle_loss_sum / n_seen

    print(f'Results:\nAccuracy: {accuracy}\nOracle accuracy: {oracle_accuracy}\nLoss: {mean_loss}\nOracle loss {mean_oracle_loss}')

    return accuracy, oracle_accuracy, mean_loss, mean_oracle_loss



In [112]:
# Identifiers of the training runs whose checkpoints are evaluated below.

# MNIST: deep ensemble of 5 LeNets
run_id_ens_mnist = "run-20210620_200823-3pjsgj5t"

# MNIST: MoE of 5 LeNets with MLP gating
run_id_moe_mnist = "run-20210715_132419-25v4nj6r"

# CIFAR10: deep ensemble of 5 ResNet20s
run_id_ens_cifar = "run-20210714_204846-2xi5eo42"

# CIFAR10: 5-expert MoE with MLP gating
run_id_moe_cifar = "run-20210709_132246-30tjmo1b"

# CIFAR10: 5-expert MoE with convolutional gating
run_id_moe_cifar_conv = "run-20210714_173933-3dt16gha"

# CIFAR10: 5-expert MoE with MC Dropout gating
run_id_moe_cifar_mcd = "run-20210704_235514-2k6cty25"


In [103]:
# The loaders need data_dir / batch_size from a run's saved arguments. Nothing
# above this cell defines `model_args`, so fetch it here; otherwise the cell
# only works with leftover kernel state and fails with NameError on a fresh
# "Restart & Run All".
# NOTE(review): assumes data_dir and batch_size are shared by every run
# evaluated below, since one set of arguments drives both loaders -- confirm.
model_args = load_trainer(run_id_ens_mnist, 40, device='cpu')[1]

# corrupted=False: presumably the clean test split (verify in the datasets
# modules). Corruption options sketched earlier for these calls were
# `intensity` / `corruption='rotation'` for MNIST and
# `corruptions=constants.CORRUPTIONS` / `intensities=[...]` for CIFAR10.
test_loader_mnist = mnist.get_test_loader(model_args.data_dir, model_args.batch_size, corrupted=False)
test_loader_cifar = cifar10.get_test_loader(model_args.data_dir, model_args.batch_size, corrupted=False)

### Run a test for a regular deep ensemble on MNIST

In [98]:
# Restore the MNIST deep ensemble (5 LeNets) on CPU.
# NOTE(review): 40 is presumably the checkpoint/epoch to restore -- confirm against load_trainer.
trainer, model_args = load_trainer(run_id_ens_mnist, 40, device='cpu')
# Prints the four metrics; the trailing ';' keeps the notebook from also echoing the returned tuple.
oracle_test(test_loader_mnist, trainer);

Initialising an ensemble of 5 networks


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.9897
Oracle accuracy: 0.9951
Loss: 0.029992355357296763
Oracle loss 0.014500808902084827


### Run a test for a MoE model on MNIST

In [99]:
# Restore the MNIST MoE (5 LeNets, MLP gating) on CPU.
# NOTE(review): 60 is presumably the checkpoint/epoch to restore -- confirm against load_trainer.
trainer, model_args = load_trainer(run_id_moe_mnist, 60, device='cpu')
# Prints the four metrics; the trailing ';' keeps the notebook from also echoing the returned tuple.
oracle_test(test_loader_mnist, trainer);

  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.9847
Oracle accuracy: 0.9969
Loss: 0.05549217186557362
Oracle loss 0.010589959099888802


### Run a test for a regular deep ensemble on CIFAR10


In [105]:
# Restore the CIFAR10 deep ensemble (5 ResNet20s) on CPU.
# NOTE(review): 180 is presumably the checkpoint/epoch to restore -- confirm against load_trainer.
trainer, model_args = load_trainer(run_id_ens_cifar, 180, device='cpu')
# Prints the four metrics; the trailing ';' keeps the notebook from also echoing the returned tuple.
oracle_test(test_loader_cifar, trainer);

Initialising an ensemble of 5 networks
SGD optimizer


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.936
Oracle accuracy: 0.9773
Loss: 0.20362979245185853
Oracle loss 0.07903696596622467


### Run tests for MoE models on CIFAR10 (MLP, convolutional and MC Dropout gating)

In [110]:
# Restore the CIFAR10 5-expert MoE with MLP gating on CPU.
# NOTE(review): 180 is presumably the checkpoint/epoch to restore -- confirm against load_trainer.
trainer, model_args = load_trainer(run_id_moe_cifar, 180, device='cpu')
# Prints the four metrics; the trailing ';' keeps the notebook from also echoing the returned tuple.
oracle_test(test_loader_cifar, trainer);

Using a simple gate
SGD optimizer
using multistep scheduler


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.8918
Oracle accuracy: 0.9727
Loss: 0.3614530630111694
Oracle loss 0.08147614449262619


In [111]:
# Restore the CIFAR10 5-expert MoE with convolutional gating on CPU.
# NOTE(review): 180 is presumably the checkpoint/epoch to restore -- confirm against load_trainer.
trainer, model_args = load_trainer(run_id_moe_cifar_conv, 180, device='cpu')
# Prints the four metrics; the trailing ';' keeps the notebook from also echoing the returned tuple.
oracle_test(test_loader_cifar, trainer);

Using a simple convolutional gate
SGD optimizer
using multistep scheduler


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.9023
Oracle accuracy: 0.9751
Loss: 0.3513138517379761
Oracle loss 0.07271169871091843


In [113]:
# Restore the CIFAR10 5-expert MoE with MC Dropout gating on CPU.
# NOTE(review): 180 is presumably the checkpoint/epoch to restore -- confirm against load_trainer.
trainer, model_args = load_trainer(run_id_moe_cifar_mcd, 180, device='cpu')
# Prints the four metrics; the trailing ';' keeps the notebook from also echoing the returned tuple.
oracle_test(test_loader_cifar, trainer);

mc drop gate with p = 0.9
SGD optimizer
using multistep scheduler


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.9235
Oracle accuracy: 0.9738
Loss: 0.2539546750307083
Oracle loss 0.09243524819612503
