In [1]:
import numpy as np 
import torch
import torchvision
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict

import datasets.mnist as mnist
import datasets.cifar10 as cifar10
import constants
from configuration import Configuration, DEFAULT_DICT
from methods.moe.MixtureOfExperts import SimpleMoE
from methods.mcdropout.MCDropout import MCDropout
from methods.BaseTrainer import StatisticsTracker
from util import *


### Train a MoE model using a class-based allocation to experts

In [2]:
# Experiment configuration: start from the project defaults and override only
# the settings this experiment depends on.
args = Configuration(DEFAULT_DICT)
args.moe_gating = 'simple'
args.method = 'moe'
args.n = 5                  # presumably the number of experts
args.model = 'lenet'
args.optimizer = 'adam'
# args.cpu = True           # NOTE(review): presumably forces CPU execution — confirm
args.moe_type = 'fixed'
args.predict_gated = True

# Seed the global RNGs so weight initialisation (and anything else drawing from
# them) is reproducible under Restart Kernel -> Run All.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)


In [4]:
# Build the trainer for the MoE configured above (get_trainer is presumably
# provided by `from util import *`). Fall back to the CPU when no GPU is visible,
# so the notebook can still be re-run end to end on a CPU-only machine.
t = get_trainer(args, 'cuda' if torch.cuda.is_available() else 'cpu')

Using a simple gate


In [5]:
# Allocate samples to experts by their class label rather than by the gating
# network's output (see the per-label loads printed during training).
t.model.gate_by_class = True

In [6]:
# MNIST training/validation loaders. random_seed presumably fixes the
# train/validation split so it is reproducible — confirm in datasets.mnist.
train_loader, valid_loader = mnist.get_mnist_train_valid_loader(args.data_dir, args.batch_size, random_seed=1)

In [7]:
# Train for the same number of epochs as a regular MoE model would.
# The reported validation results can be ignored: validation uses the gating
# network's output, not the class-based allocation used during training.
t.fit(train_loader, valid_loader, epochs=40, log=False)

 [   0. 6046.    0.    0.    0.]
Loads for the epoch label 2: [   0.    0. 5356.    0.    0.]
Loads for the epoch label 3: [   0.    0.    0. 5495.    0.]
Loads for the epoch label 4: [   0.    0.    0.    0. 5267.]
Loads for the epoch label 5: [4875.    0.    0.    0.    0.]
Loads for the epoch label 6: [   0. 5291.    0.    0.    0.]
Loads for the epoch label 7: [   0.    0. 5641.    0.    0.]
Loads for the epoch label 8: [   0.    0.    0. 5275.    0.]
Loads for the epoch label 9: [   0.    0.    0.    0. 5388.]
Loads for the epoch: [10241. 11337. 10997. 10770. 10655.]

Validating
100%|██████████| 24/24 [00:00<00:00, 46.13batch/s, loss=2.22]
  0%|          | 0/216 [00:00<?, ?batch/s]Validation loss: 2.12387016415596; accuracy: 0.239

Epoch 21
100%|██████████| 216/216 [00:05<00:00, 43.10batch/s, loss=0.000261]
  0%|          | 0/24 [00:00<?, ?batch/s]Loads for the epoch label 0: [5366.    0.    0.    0.    0.]
Loads for the epoch label 1: [   0. 6046.    0.    0.    0.]
Loads for the

### Test using the same class-based gating

In [8]:
# Clean (uncorrupted) MNIST test set. The commented-out arguments at the end of
# the line are kept for switching to a corrupted variant (e.g. rotation at
# intensity i).
test_loader = mnist.get_test_loader(args.data_dir, args.batch_size, corrupted=False)#, intensity=i, corruption='rotation')

In [9]:
# Test-time metrics, each called as metric(predictions, labels) and returning a
# scalar. The evaluation code treats every metric as mean-reduced over a batch.
metric_dict = dict(
    NLL=lambda p, g: metrics.basic_cross_entropy(p, g).item(),
    ECE=metrics.wrap_ece(bins=20),
    Brier=metrics.wrap_brier(),
)

In [10]:
# Evaluate the MoE with the "oracle" class-based allocation: the true labels are
# passed to the model alongside the inputs, so each sample is allocated by class.
t.model.eval()

stat_tracker = StatisticsTracker(args.n)

with torch.no_grad():
    with tqdm(test_loader, unit="batch") as tepoch:
        metric_accumulators = defaultdict(int)
        for X, y in tepoch:
            X = X.to(t.device)
            y = y.to(t.device)
            batch_size = X.size(0)

            # only y_hat and preds are used; load statistics and load loss are discarded
            y_hat, preds, _loads, _loads_by_label, _load_loss = t.model(X, labels=y)

            # Metrics are assumed mean-reduced: weight each batch value by the
            # batch size so the final division by the sample count averages it.
            for name in metric_dict:
                metric_val = metric_dict[name](y_hat, y)
                metric_accumulators[name] += metric_val * batch_size

            stat_tracker.update(y_hat, preds, y)

    correct = stat_tracker.correct
    total = stat_tracker.total
    test_accuracy = correct / total

    print(f'Results: \nAccuracy: {test_accuracy}')
    for name in list(metric_accumulators):
        metric_accumulators[name] = metric_accumulators[name] / total
        print(f'{name}: {metric_accumulators[name]}')

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 40/40 [00:00<00:00, 53.20batch/s]Results: 
Accuracy: 0.9967
NLL: 0.016465958053686336
ECE: 0.0033668548166751926
Brier: 0.00576458890633767



It does not appear that the expert networks overfitting as part of an MoE model harms their accuracy. As shown above, when each sample is handled by its "oracle" expert — the one trained on that sample's class — we obtain near-perfect classification accuracy (99.67%). There could, however, be an issue with the gating network and its degree of over- or under-fitting.

### Train a post-hoc gating network via the ensemble loss

In [11]:
from methods.moe.gate_models import SimpleConvGate, GateWrapper

In [12]:
# Switch from class-based allocation back to the gating network and put the
# model in training mode for the post-hoc gate training below.
# NOTE(review): the configuration flag above is `args.predict_gated`, while this
# sets `t.gated_predict`; confirm the trainer reads this attribute, otherwise
# the assignment has no effect.
t.model.gate_by_class = False
t.gated_predict = True
t.model.train();

In [13]:

# Replace the MoE's gate with a fresh convolutional gating network (wrapped in
# GateWrapper) and set up training for the gate only, with the experts frozen.
# NOTE(review): 28 is presumably the MNIST image side length and the second
# argument the number of experts — confirm against SimpleConvGate.
sg = GateWrapper(SimpleConvGate(28, args.n))
sg.to(t.device)
t.model.gating_network = sg

for param in t.model.gating_network.parameters():
    param.requires_grad = True

# Freeze the experts so only the gate can be updated.
for param in t.model.experts.parameters():
    param.requires_grad = False

# The optimizer receives only the gate's parameters.
# NOTE(review): this uses SGD at args.lr although args.optimizer is 'adam' —
# confirm this is intentional.
optim = torch.optim.SGD(t.model.gating_network.parameters(), lr=args.lr)

t.optimizer = optim

Using a simple convolutional gate


In [14]:
# Train only the gate (the experts were frozen when the gate was installed) for
# 40 epochs. The printed loads are not meaningful here: every non-zero gate
# weight is counted and this gate is dense, so every expert is counted for
# every sample.
t.fit(train_loader, valid_loader, epochs=40, log=False)

ds for the epoch label 2: [5356. 5356. 5356. 5356. 5356.]
Loads for the epoch label 3: [5495. 5495. 5495. 5495. 5495.]
Loads for the epoch label 4: [5267. 5267. 5267. 5267. 5267.]
Loads for the epoch label 5: [4875. 4875. 4875. 4875. 4875.]
Loads for the epoch label 6: [5291. 5291. 5291. 5291. 5291.]
Loads for the epoch label 7: [5641. 5641. 5641. 5641. 5641.]
Loads for the epoch label 8: [5275. 5275. 5275. 5275. 5275.]
Loads for the epoch label 9: [5388. 5388. 5388. 5388. 5388.]
Loads for the epoch: [54000. 54000. 54000. 54000. 54000.]

Validating
100%|██████████| 24/24 [00:00<00:00, 44.02batch/s, loss=0.121]
  0%|          | 0/216 [00:00<?, ?batch/s]Validation loss: 0.1364691744868954; accuracy: 0.965

Epoch 21
100%|██████████| 216/216 [00:03<00:00, 67.52batch/s, loss=0.0972]
  0%|          | 0/24 [00:00<?, ?batch/s]Loads for the epoch label 0: [5366. 5366. 5366. 5366. 5366.]
Loads for the epoch label 1: [6046. 6046. 6046. 6046. 6046.]
Loads for the epoch label 2: [5356. 5356. 5356. 

In [15]:
# Evaluate the MoE with the post-hoc gate trained via the ensemble loss.
t.test(test_loader, metric_dict);

  0%|          | 0/40 [00:00<?, ?batch/s]
Testing
100%|██████████| 40/40 [00:00<00:00, 54.80batch/s]Results: 
Accuracy: 0.977
NLL: 0.08236458208411931
ECE: 0.026623837739229207
Brier: 0.03558165775611997



Training the gating network post hoc, but primarily via the ensemble loss, is a very indirect way to train an MoE gate. It resembles the end-to-end approach, only with expert training and gate training split into separate phases that take turns. We might expect better results if we define a loss specifically for the gating network and train it in isolation.

In [16]:
from methods.moe.laplace_gating import get_adjusted_loader

# Build the gate-training loaders while the model sits on the CPU, then move
# the model back to the trainer's device.
t.model.to('cpu')
gate_train_loader, gate_val_loader = (
    get_adjusted_loader(t.model, source_loader, precompute_labels=True)
    for source_loader in (train_loader, valid_loader)
)
t.model.to(t.device);


In [17]:
# Fresh, unwrapped gating network, trained below on the loaders built by
# get_adjusted_loader. The second argument presumably sets the number of
# experts, so it is taken from the configuration rather than hard-coded.
sg = SimpleConvGate(28, args.n)
gate_train_epochs = 20


Using a simple convolutional gate


In [18]:
# Train the stand-alone gate `sg` with cross-entropy against the targets yielded
# by gate_train_loader, with weight decay as regularisation, and report training
# and validation accuracy after every epoch.
optim = torch.optim.Adam(sg.parameters(), weight_decay=0.001)

sg.to(t.device)


def run_gate_epoch(gate, loader, device, optimizer=None):
    """Run `gate` once over every batch of `loader`.

    With an optimizer the gate is put in training mode and updated with
    cross-entropy. Without one it is put in eval mode and run with gradients
    disabled. Validation therefore neither changes train-mode layer state
    (e.g. BatchNorm running statistics, dropout) nor builds autograd graphs.

    Args:
        gate: the gating module to run.
        loader: iterable of (inputs, targets) batches.
        device: device each batch is moved to.
        optimizer: optimizer over the gate's parameters, or None to evaluate.

    Returns:
        Fraction of samples whose argmax prediction equals the target, or NaN
        if the loader yielded no samples.
    """
    training = optimizer is not None
    gate.train(training)
    correct, total = 0, 0
    with torch.set_grad_enabled(training), tqdm(loader, unit="batch") as tepoch:
        for X, y in tepoch:
            X, y = X.to(device), y.to(device)

            y_hat = gate(X)
            loss = torch.nn.functional.cross_entropy(y_hat, y)

            if training:
                # backpropagate
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            tepoch.set_postfix(loss=loss.item())
            _, predicted = torch.max(y_hat, 1)
            correct += (predicted == y).sum().item()
            total += X.shape[0]
    return correct / total if total else float('nan')


for epoch in range(gate_train_epochs):
    print(f'Epoch {epoch + 1}')
    print(f'Accuracy: {run_gate_epoch(sg, gate_train_loader, t.device, optim)}')
    print(f'Validation accuracy: {run_gate_epoch(sg, gate_val_loader, t.device)}')

  0%|          | 1/216 [00:00<00:31,  6.77batch/s, loss=1.53]Epoch 1
100%|██████████| 216/216 [00:06<00:00, 31.54batch/s, loss=0.0932]
 21%|██        | 5/24 [00:00<00:00, 41.19batch/s, loss=0.126]Accuracy: 0.9217037037037037
100%|██████████| 24/24 [00:00<00:00, 34.95batch/s, loss=0.08]
  0%|          | 1/216 [00:00<00:32,  6.54batch/s, loss=0.128]Validation accuracy: 0.9721666666666666
Epoch 2
100%|██████████| 216/216 [00:07<00:00, 30.36batch/s, loss=0.0724]
 17%|█▋        | 4/24 [00:00<00:00, 34.88batch/s, loss=0.0738]Accuracy: 0.9750740740740741
100%|██████████| 24/24 [00:00<00:00, 36.46batch/s, loss=0.0716]
  0%|          | 1/216 [00:00<00:29,  7.38batch/s, loss=0.0675]Validation accuracy: 0.9778333333333333
Epoch 3
100%|██████████| 216/216 [00:07<00:00, 29.31batch/s, loss=0.0488]
 17%|█▋        | 4/24 [00:00<00:00, 34.61batch/s, loss=0.0849]Accuracy: 0.9832037037037037
100%|██████████| 24/24 [00:00<00:00, 37.29batch/s, loss=0.0738]
  0%|          | 1/216 [00:00<00:30,  7.08batch/s,

In [19]:

# Install the directly trained gate in the MoE, wrapped in GateWrapper like the
# earlier post-hoc gate, on the trainer's device.
t.model.gating_network = GateWrapper(sg).to(t.device)


In [20]:
# Evaluate the MoE using the gate trained directly on its own loss.
t.test(test_loader, metric_dict);

  0%|          | 0/40 [00:00<?, ?batch/s]
Testing
100%|██████████| 40/40 [00:00<00:00, 56.96batch/s]Results: 
Accuracy: 0.9857
NLL: 0.05441958163864911
ECE: 0.016462561118602743
Brier: 0.02290211016079411



Here we observe more of a problem. The training loss of the post-hoc trained gating network is nearly zero and its training accuracy is over 95%, indicating that only limited further training is possible. However, the generalisation error, which in this case almost exactly mirrors the overall MoE error, is still somewhat higher. The problem persists even with some weight regularisation (weight decay) introduced when training the gating network. We are pushing the limits of generalisation error here, and it is likely that the sheer "learnability" of the dataset is affecting this to some extent.