This notebook compares the oracle performance of deep ensembles with that of sets of "expert" networks obtained by training mixture-of-experts (MoE) models end-to-end. The oracle performance (both loss and accuracy) is obtained by considering, for each sample, only the subnetwork that achieves the lowest loss on that sample.

In [1]:
import os
import sys
sys.path.insert(0, os.path.abspath('../'))

In [2]:
import numpy as np 
import torch
import torchvision
import matplotlib.pyplot as plt

import datasets.mnist as mnist
import datasets.cifar10 as cifar10
import constants
from configuration import Configuration
from metrics import basic_cross_entropy

from util import *

In [3]:
from tqdm.notebook import trange, tqdm

## Setup

In [9]:
def oracle_loss_and_correctness(preds, gt, criterion):
    """Score a batch as if an oracle chose the best sub-network for each sample.

    Args:
        preds: sequence of prediction tensors, one per sub-network. They are
            stacked along a new dim 1, so they must all share the same shape.
        gt: ground-truth targets; handed to ``criterion`` and compared with the
            arg-max (over dim 1) of the selected predictions.
        criterion: callable invoked as ``criterion(pred, gt, reduction='none')``;
            expected to yield one loss value per sample.

    Returns:
        Tuple ``(oracle_loss, correct, loads)``:
            * ``oracle_loss`` - mean over the batch of the per-sample minimum
              loss (a 0-dim tensor),
            * ``correct`` - number of samples whose selected prediction matches
              ``gt`` (Python int),
            * ``loads`` - float NumPy array of length ``len(preds)`` counting how
              many samples each sub-network was selected for.
    """
    n_networks = len(preds)

    # One column of per-sample losses for every sub-network.
    per_network_losses = torch.stack(
        [criterion(p, gt, reduction='none') for p in preds], dim=1
    )
    best_loss, best_net = per_network_losses.min(dim=-1)
    oracle_loss = best_loss.mean()

    # For every sample keep only the output of its lowest-loss sub-network.
    stacked = torch.stack(preds, dim=1)
    sample_idx = torch.arange(stacked.shape[0])
    chosen = stacked[sample_idx, best_net]

    _, predicted = torch.max(chosen, 1)
    correct = (predicted == gt).sum().item()

    # How often each sub-network was the oracle's choice in this batch.
    loads = np.array(
        [(best_net == k).sum().item() for k in range(n_networks)], dtype=float
    )

    return oracle_loss, correct, loads

In [7]:
def oracle_test(loader, trainer, main_criterion=basic_cross_entropy, individual_criterion=torch.nn.functional.cross_entropy):
    """Evaluate ``trainer`` on ``loader`` and report combined and oracle metrics.

    The model is switched to eval mode and every batch is run under
    ``torch.no_grad()``. For each batch, ``trainer.predict_test(x)`` must return
    a pair ``(cum_pred, preds)``: the combined prediction and the collection of
    per-network predictions that is forwarded to ``oracle_loss_and_correctness``.

    Args:
        loader: iterable of ``(x, y)`` batches.
        trainer: object exposing ``model``, ``n`` (used as the length of the
            load vector) and ``predict_test``.
        main_criterion: loss applied to the combined prediction; its result is
            weighted by the batch size (presumably a batch mean - confirm).
        individual_criterion: per-network loss passed on to
            ``oracle_loss_and_correctness``.

    Returns:
        ``(accuracy, oracle_accuracy, loss, oracle_loss, oracle_loads)``, each
        divided by the number of samples seen. The same values are printed.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    trainer.model.eval()

    # Running sums over the whole loader; normalised by `total` at the end.
    total = 0
    loss = 0
    correct = 0
    oracle_loss = 0
    oracle_correct = 0
    oracle_loads = np.zeros((trainer.n))

    with torch.no_grad(), tqdm(loader, unit="batch") as tepoch:
        for x, y in tepoch:
            batch_size = x.shape[0]
            total += batch_size

            cum_pred, preds = trainer.predict_test(x)

            # Metrics of the combined prediction.
            loss += main_criterion(cum_pred, y).item() * batch_size
            _, predicted = torch.max(cum_pred, 1)
            correct += (predicted == y).sum().item()

            # Metrics when the best sub-network is picked per sample.
            batch_loss, batch_correct, batch_loads = oracle_loss_and_correctness(
                preds, y, individual_criterion
            )
            oracle_loss += batch_loss * batch_size
            oracle_correct += batch_correct
            oracle_loads += batch_loads

    print(f'Results:\nAccuracy: {correct/total}\nOracle accuracy: {oracle_correct/total}\nLoss: {loss / total}\nOracle loss {oracle_loss/total}\nOracle Loads: {oracle_loads/total}')

    return (
        correct / total,
        oracle_correct / total,
        loss / total,
        oracle_loss / total,
        oracle_loads / total,
    )



In [4]:
# Identifiers of the stored training runs evaluated in this notebook.

# MNIST: deep ensemble of 5 LeNets
run_id_ens_mnist = 'run-20210620_200823-3pjsgj5t'

# MNIST: MoE with class-specific gating
run_id_classgated_mnist = 'run-20210805_150649-1h06xm2k'

# MNIST: 5-LeNet MoE with MLP gating
run_id_moe_mnist = 'run-20210715_132419-25v4nj6r'

# MNIST: 5-LeNet MoE with convolutional gating
run_id_moe_mnist_conv = 'run-20210713_152705-r9znwdrf'

# MNIST: 5-LeNet MoE trained via sum loss
run_id_moe_mnist_sumloss = 'run-20210720_195742-3tszfwy7'

# CIFAR-10: deep ensemble of 5 ResNet20s
run_id_ens_cifar = 'run-20210714_204846-2xi5eo42'

# CIFAR-10: 5-expert MoE with MLP gating
run_id_moe_cifar = 'run-20210709_132246-30tjmo1b'

# CIFAR-10: 5-expert MoE with convolutional gating
run_id_moe_cifar_conv = 'run-20210714_173933-3dt16gha'

# CIFAR-10: 5-expert MoE with MC Dropout gating
run_id_moe_cifar_mcd = 'run-20210704_235514-2k6cty25'

# CIFAR-10: hard class allocations
run_id_cifar_class = 'run-20210719_224716-31fxykw8'

# CIFAR-100: ensemble of 5
run_id_cifar100_ens = 'cifar100-baseline-2137/run-20210727_061234-13c3x9go'

# CIFAR-100: 5-expert MoE with fixed class gate
run_id_cifar100_class_gate = 'run-20210805_144739-bttnarpp'

# Dataset root. The original machine-specific location is kept as the default,
# but it can be overridden without editing the notebook by setting the
# MOE_DATA_DIR environment variable before starting the kernel.
data_dir = os.environ.get('MOE_DATA_DIR', '/scratch/gp491/data')

In [5]:

# Value passed as the second argument of every get_test_loader call below
# (presumably the batch size - confirm against the datasets modules).
TEST_BATCH_SIZE = 128

# Uncorrupted test loaders (corrupted=False) used by all evaluations below.
test_loader_mnist = mnist.get_test_loader(data_dir, TEST_BATCH_SIZE, corrupted=False)
test_loader_cifar = cifar10.get_test_loader(data_dir, TEST_BATCH_SIZE, corrupted=False)
# is_cifar10=False: judging by the variable name this selects CIFAR-100 - confirm.
test_loader_cifar_100 = cifar10.get_test_loader(data_dir, TEST_BATCH_SIZE, corrupted=False, is_cifar10=False)

### Run a test for a regular deep ensemble on MNIST

In [10]:
# Evaluate the MNIST deep-ensemble run (run_id_ens_mnist) on CPU.
# NOTE(review): the second argument (40) is presumably the checkpoint epoch - confirm against load_trainer.
trainer, model_args = load_trainer(run_id_ens_mnist, 40, device='cpu')
# The trailing ';' keeps the returned tuple from being echoed as cell output.
oracle_test(test_loader_mnist, trainer);

Initialising an ensemble of 5 networks


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.9897
Oracle accuracy: 0.9951
Loss: 0.02999235608279705
Oracle loss 0.014500806108117104
Oracle Loads: [0.1337 0.2073 0.1601 0.2491 0.2498]


### Run a test for a MoE model on MNIST

In [23]:
# Evaluate the MNIST MoE run with MLP gating (run_id_moe_mnist) on CPU.
# NOTE(review): the second argument (60) is presumably the checkpoint epoch - confirm against load_trainer.
trainer, model_args = load_trainer(run_id_moe_mnist, 60, device='cpu')
# The trailing ';' keeps the returned tuple from being echoed as cell output.
oracle_test(test_loader_mnist, trainer);

  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.9847
Oracle accuracy: 0.9969
Loss: 0.055492172188602855
Oracle loss 0.010589958168566227
Loads: [0.0897 0.1799 0.1581 0.2881 0.2842]


In [24]:
# Evaluate the MNIST MoE run with convolutional gating (run_id_moe_mnist_conv) on CPU.
# NOTE(review): the second argument (40) is presumably the checkpoint epoch - confirm against load_trainer.
trainer, model_args = load_trainer(run_id_moe_mnist_conv, 40, device='cpu')
# The trailing ';' keeps the returned tuple from being echoed as cell output.
oracle_test(test_loader_mnist, trainer);

Using a simple convolutional gate


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.983
Oracle accuracy: 0.9973
Loss: 0.060985859308578076
Oracle loss 0.009276136755943298
Loads: [0.194  0.1462 0.2236 0.2507 0.1855]


### Run a test for a regular deep ensemble on CIFAR10


In [25]:
# Evaluate the CIFAR-10 deep-ensemble run (run_id_ens_cifar) on CPU.
# NOTE(review): the second argument (180) is presumably the checkpoint epoch - confirm against load_trainer.
trainer, model_args = load_trainer(run_id_ens_cifar, 180, device='cpu')
# The trailing ';' keeps the returned tuple from being echoed as cell output.
oracle_test(test_loader_cifar, trainer);

Initialising an ensemble of 5 networks
SGD optimizer


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.936
Oracle accuracy: 0.9773
Loss: 0.20362979168891907
Oracle loss 0.07903696596622467
Loads: [0.4186 0.1942 0.1514 0.1208 0.115 ]


### Run a test for a MoE model on CIFAR

In [26]:
# Evaluate the CIFAR-10 MoE run with MLP gating (run_id_moe_cifar) on CPU.
# NOTE(review): the second argument (180) is presumably the checkpoint epoch - confirm against load_trainer.
trainer, model_args = load_trainer(run_id_moe_cifar, 180, device='cpu')
# The trailing ';' keeps the returned tuple from being echoed as cell output.
oracle_test(test_loader_cifar, trainer);

Using a simple gate
SGD optimizer
using multistep scheduler


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.8918
Oracle accuracy: 0.9727
Loss: 0.36145306406021116
Oracle loss 0.08147614449262619
Loads: [0.2979 0.23   0.2068 0.0644 0.2009]


In [25]:
# Evaluate the CIFAR-10 MoE run with convolutional gating (run_id_moe_cifar_conv) on CPU.
# NOTE(review): the second argument (180) is presumably the checkpoint epoch - confirm against load_trainer.
trainer, model_args = load_trainer(run_id_moe_cifar_conv, 180, device='cpu')
# The trailing ';' keeps the returned tuple from being echoed as cell output.
oracle_test(test_loader_cifar, trainer);

Initialising a Mixture of Experts
criterion: <function ensemble_criterion at 0x7f8833f0f8b0>
Using a simple convolutional gate
SGD optimizer
using multistep scheduler


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.9023
Oracle accuracy: 0.9751
Loss: 0.35131385192871095
Oracle loss 0.07271169871091843
Oracle Loads: [0.279  0.2218 0.2381 0.0291 0.232 ]


In [28]:
# Evaluate the CIFAR-10 MoE run with MC Dropout gating (run_id_moe_cifar_mcd) on CPU.
# NOTE(review): the second argument (180) is presumably the checkpoint epoch - confirm against load_trainer.
trainer, model_args = load_trainer(run_id_moe_cifar_mcd, 180, device='cpu')
# The trailing ';' keeps the returned tuple from being echoed as cell output.
oracle_test(test_loader_cifar, trainer);

mc drop gate with p = 0.9
SGD optimizer
using multistep scheduler


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.9234
Oracle accuracy: 0.9738
Loss: 0.25389006974697115
Oracle loss 0.09243524819612503
Loads: [0.2616 0.2046 0.1997 0.2002 0.1339]


In [11]:
# Evaluate the CIFAR-10 run with hard class allocations (run_id_cifar_class) on CPU.
# NOTE(review): the second argument (180) is presumably the checkpoint epoch - confirm against load_trainer.
trainer, model_args = load_trainer(run_id_cifar_class, 180, device='cpu')
# NOTE(review): the recorded output of this cell reports `Loss: nan`; presumably
# main_criterion hits log(0) under hard gating - confirm before quoting the loss.
# The trailing ';' keeps the returned tuple from being echoed as cell output.
oracle_test(test_loader_cifar, trainer);

Using a simple convolutional gate
SGD optimizer
using multistep scheduler


  0%|          | 0/79 [00:00<?, ?batch/s]

Results:
Accuracy: 0.5425
Oracle accuracy: 0.9872
Loss: nan
Oracle loss 0.05161862075328827
Oracle Loads: [0.2    0.2005 0.1995 0.1996 0.2004]
