In [1]:
import os
import sys
# Put the parent of the current working directory at the front of the import path,
# presumably so the project-local packages imported in the next cell resolve.
sys.path.insert(0, os.path.abspath('../'))

In [1]:
import numpy as np 
import torch
import torchvision
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict

import datasets.mnist as mnist
import datasets.cifar10 as cifar10
import constants
from configuration import Configuration, DEFAULT_DICT
from methods.moe.MixtureOfExperts import SimpleMoE
from methods.mcdropout.MCDropout import MCDropout
from methods.BaseTrainer import StatisticsTracker
from util import *


### Train a MoE model using a class-based allocation to experts

In [2]:
# Start from the project defaults and override the options that define this
# experiment: a 'fixed'-type mixture of experts with a simple gate, n = 5,
# LeNet as the model and Adam as the optimizer.
args = Configuration(DEFAULT_DICT)
for option, value in {
    'moe_gating': 'simple',
    'method': 'moe',
    'n': 5,
    'model': 'lenet',
    'optimizer': 'adam',
    'moe_type': 'fixed',
    'predict_gated': True,
}.items():
    setattr(args, option, value)


In [3]:
# Build the trainer for the configuration above on the CUDA device; the model is
# reachable afterwards as `t.model`.
# NOTE(review): get_trainer is not imported by name — presumably it comes from `from util import *`.
t = get_trainer(args, 'cuda')

Using a simple gate


In [4]:
# Allocate samples to experts by class label instead of by the gating network's output.
# NOTE(review): the allocation rule itself lives in the model code; the training log
# suggests labels k and k+5 share expert k — confirm.
t.model.gate_by_class = True

In [5]:
# MNIST training and validation loaders built from the configured data directory and batch size.
# NOTE(review): random_seed=1 presumably fixes the train/validation split — confirm in datasets.mnist.
train_loader, valid_loader = mnist.get_mnist_train_valid_loader(args.data_dir, args.batch_size, random_seed=1)

In [6]:
# Train for the same number of epochs as a regular MoE model would.
# The validation results reported during this run can be ignored: validation uses
# the gating network's output, not the class-based allocation used for training.
t.fit(train_loader, valid_loader, epochs=5, log=False)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.003
    weight_decay: 0
)
  0%|          | 0/216 [00:00<?, ?batch/s]Epoch 1
100%|██████████| 216/216 [00:06<00:00, 32.99batch/s, loss=0.0418]
  0%|          | 0/24 [00:00<?, ?batch/s]Loads for the epoch label 0: [5366.    0.    0.    0.    0.]
Loads for the epoch label 1: [   0. 6046.    0.    0.    0.]
Loads for the epoch label 2: [   0.    0. 5356.    0.    0.]
Loads for the epoch label 3: [   0.    0.    0. 5495.    0.]
Loads for the epoch label 4: [   0.    0.    0.    0. 5267.]
Loads for the epoch label 5: [4875.    0.    0.    0.    0.]
Loads for the epoch label 6: [   0. 5291.    0.    0.    0.]
Loads for the epoch label 7: [   0.    0. 5641.    0.    0.]
Loads for the epoch label 8: [   0.    0.    0. 5275.    0.]
Loads for the epoch label 9: [   0.    0.    0.    0. 5388.]
Loads for the epoch: [10241. 11337. 10997. 10770. 10655.]

Validating
100%|██████████| 24/24 [00:00<00:00, 41.08ba

### Test using the same class-based gating

In [7]:
# Uncorrupted MNIST test set (corrupted=False). The call previously carried extra,
# now disabled, arguments for a corrupted variant: intensity=i, corruption='rotation'.
test_loader = mnist.get_test_loader(args.data_dir, args.batch_size, corrupted=False)

In [8]:
# Metrics reported at test time, keyed by display name. Each value is called as
# metric(predictions, targets); the evaluation loop treats the result as a
# per-batch mean.
metric_dict = {
    'NLL': lambda pred, target: metrics.basic_cross_entropy(pred, target).item(),
    'ECE': metrics.wrap_ece(bins=20),
    'Brier': metrics.wrap_brier(),
}

In [9]:
# Evaluate the class-gated MoE on the test set. The true labels are handed to the
# model (labels=y) so the class-based allocation can be applied at test time.
t.model.eval()

stat_tracker = StatisticsTracker(args.n)
metric_accumulators = defaultdict(int)

with torch.no_grad(), tqdm(test_loader, unit="batch") as tepoch:
    for X, y in tepoch:
        X, y = X.to(t.device), y.to(t.device)

        # The load statistics returned by the model are not needed for evaluation.
        y_hat, preds, batch_loads, batch_loads_by_label, load_loss = t.model(X, labels=y)

        batch_size = X.shape[0]
        for metric_name, metric_fn in metric_dict.items():
            # Metrics are assumed to be mean-reduced, so weight each batch by its size.
            metric_accumulators[metric_name] += metric_fn(y_hat, y) * batch_size

        stat_tracker.update(y_hat, preds, y)

correct = stat_tracker.correct
total = stat_tracker.total
test_accuracy = correct / total

print(f'Results: \nAccuracy: {test_accuracy}')
for metric_name in list(metric_accumulators):
    # Turn the size-weighted sums into per-sample averages.
    metric_accumulators[metric_name] = metric_accumulators[metric_name] / total
    print(f'{metric_name}: {metric_accumulators[metric_name]}')

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 40/40 [00:00<00:00, 40.67batch/s]Results: 
Accuracy: 0.996
NLL: 0.01472758359159343
ECE: 0.005053635632991796
Brier: 0.006774707189106266



It does not appear that the overfitting of the individual networks within a MoE model has a negative effect on their performance in terms of accuracy: we have shown above that when using the "oracle" expert - the one that has been trained on the class of a given sample - we can obtain near-perfect classification accuracy. There could, however, be an issue with the gating network and its level of over- or under-fitting.

### Train a post-hoc gating network via the ensemble loss

In [16]:
from methods.moe.gate_models import SimpleConvGate, GateWrapper

In [17]:
# Switch off class-based allocation so samples are routed through the gating network
# again, and put the model in training mode before a new gate is fitted.
t.model.gate_by_class = False
# NOTE(review): the configuration option is spelled `predict_gated`; confirm the
# trainer actually reads an attribute named `gated_predict`.
t.gated_predict = True
t.model.train();

In [18]:

# Replace the model's gate with a newly constructed, wrapped convolutional gate.
sg = GateWrapper(SimpleConvGate(28, 5))
sg.to(t.device)
t.model.gating_network = sg

# Only the gate is trained from here on: its parameters receive gradients,
# while the experts' parameters are frozen.
for gate_param in t.model.gating_network.parameters():
    gate_param.requires_grad_(True)

for expert_param in t.model.experts.parameters():
    expert_param.requires_grad_(False)

# Hand the trainer an optimizer that only sees the gate's parameters.
optim = torch.optim.SGD(t.model.gating_network.parameters(), lr=args.lr)
t.optimizer = optim

Using a simple convolutional gate


In [19]:
# Fit for 10 epochs with the optimizer that only updates the gate.
# The per-expert loads printed during this run are not informative: every non-zero
# gate weight is counted as load, and the gating has been changed to be dense.
t.fit(train_loader, valid_loader, epochs=10, log=False)

  0%|          | 0/216 [00:00<?, ?batch/s]SGD (
Parameter Group 0
    dampening: 0
    lr: 0.003
    momentum: 0
    nesterov: False
    weight_decay: 0
)
Epoch 1
100%|██████████| 216/216 [00:02<00:00, 86.98batch/s, loss=0.657]
  0%|          | 0/24 [00:00<?, ?batch/s]Loads for the epoch label 0: [5366. 5366. 5366. 5366. 5366.]
Loads for the epoch label 1: [6046. 6046. 6046. 6046. 6046.]
Loads for the epoch label 2: [5356. 5356. 5356. 5356. 5356.]
Loads for the epoch label 3: [5495. 5495. 5495. 5495. 5495.]
Loads for the epoch label 4: [5267. 5267. 5267. 5267. 5267.]
Loads for the epoch label 5: [4875. 4875. 4875. 4875. 4875.]
Loads for the epoch label 6: [5291. 5291. 5291. 5291. 5291.]
Loads for the epoch label 7: [5641. 5641. 5641. 5641. 5641.]
Loads for the epoch label 8: [5275. 5275. 5275. 5275. 5275.]
Loads for the epoch label 9: [5388. 5388. 5388. 5388. 5388.]
Loads for the epoch: [54000. 54000. 54000. 54000. 54000.]

Validating
100%|██████████| 24/24 [00:00<00:00, 49.73batch/s, 

In [20]:
# Evaluate the MoE with the post-hoc trained gate on the test set using the metrics in metric_dict.
# The trailing semicolon suppresses the notebook's display of the return value.
t.test(test_loader, metric_dict);

  0%|          | 0/40 [00:00<?, ?batch/s]
Testing
100%|██████████| 40/40 [00:00<00:00, 68.87batch/s]Results: 
Accuracy: 0.9568
NLL: 0.179173562861979
ECE: 0.06775618422031403
Brier: 0.07396664386615157



Training the gating network post-hoc, but primarily via the ensemble loss, is a very indirect approach to post-hoc gating training for the MoE; it is reminiscent of the end-to-end approach, only split apart so that the two parts take turns. We might expect better results if we define a loss specifically for the gating network and train it in isolation.

In [10]:
from methods.moe.laplace_gating import get_adjusted_loader

# The model is moved to the CPU while the adjusted loaders are built and moved back
# to the trainer's device afterwards.
# NOTE(review): presumably get_adjusted_loader runs the experts over each loader to
# derive per-sample targets for the gate — confirm in methods.moe.laplace_gating.
t.model.to('cpu')

gate_train_loader = get_adjusted_loader(t.model.experts, train_loader)
gate_val_loader = get_adjusted_loader(t.model.experts, valid_loader)

t.model.to(t.device);


In [23]:
# New convolutional gate to be trained on its own; the second argument equals args.n (5).
# NOTE(review): 28 is presumably the MNIST image side length — confirm against SimpleConvGate.
sg = SimpleConvGate(28, 5)
gate_train_epochs = 10  # number of passes over gate_train_loader in the training cell


Using a simple convolutional gate


In [24]:
# Train the gate in isolation: plain cross-entropy between the gate's output and the
# labels supplied by the adjusted loaders, optimised with Adam plus weight decay.
optim = torch.optim.Adam(sg.parameters(), weight_decay=0.001)

sg.to(t.device)


def _run_gate_epoch(gate, loader, device, optimizer=None):
    """Run one pass of ``gate`` over ``loader`` and return its accuracy.

    Args:
        gate: the gating network; called as ``gate(X)`` and expected to return
            one score per class for every sample in the batch.
        loader: iterable yielding ``(X, y)`` batches.
        device: device the batches are moved to before the forward pass.
        optimizer: when given, the gate is put in training mode and updated on
            every batch. When ``None`` the pass is evaluation-only: the gate is
            put in eval mode and gradient tracking is disabled.

    Returns:
        Fraction of samples whose arg-max prediction equals the target, or
        ``nan`` if the loader yielded no samples.
    """
    training = optimizer is not None
    # Validation must not run in training mode: eval() disables train-time
    # behaviour of layers such as dropout or batch norm, if the gate has any.
    gate.train(training)

    correct = 0
    total = 0
    with torch.set_grad_enabled(training), tqdm(loader, unit="batch") as tepoch:
        for X, y in tepoch:
            X, y = X.to(device), y.to(device)

            y_hat = gate(X)
            loss = torch.nn.functional.cross_entropy(y_hat, y)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            tepoch.set_postfix(loss=loss.item())
            _, predicted = torch.max(y_hat, 1)
            correct += (predicted == y).sum().item()
            total += X.shape[0]

    return correct / total if total else float('nan')


for i in range(gate_train_epochs):
    print(f'Epoch {i + 1}')

    train_accuracy = _run_gate_epoch(sg, gate_train_loader, t.device, optimizer=optim)
    print(f'\nAccuracy: {train_accuracy}')

    val_accuracy = _run_gate_epoch(sg, gate_val_loader, t.device)
    print(f'\nValidation accuracy: {val_accuracy}')

  1%|▏         | 3/216 [00:00<00:09, 21.76batch/s, loss=1.08]Epoch 1
100%|██████████| 216/216 [00:06<00:00, 33.00batch/s, loss=0.123]
 21%|██        | 5/24 [00:00<00:00, 42.06batch/s, loss=0.13] Accuracy: 0.9203518518518519
100%|██████████| 24/24 [00:00<00:00, 39.56batch/s, loss=0.155]
  1%|          | 2/216 [00:00<00:12, 17.56batch/s, loss=0.165]Validation accuracy: 0.9668333333333333
Epoch 2
100%|██████████| 216/216 [00:07<00:00, 30.74batch/s, loss=0.0855]
 17%|█▋        | 4/24 [00:00<00:00, 39.52batch/s, loss=0.0778]Accuracy: 0.9757037037037037
100%|██████████| 24/24 [00:00<00:00, 33.23batch/s, loss=0.091]
  1%|▏         | 3/216 [00:00<00:08, 26.10batch/s, loss=0.0478]Validation accuracy: 0.9771666666666666
Epoch 3
100%|██████████| 216/216 [00:06<00:00, 31.12batch/s, loss=0.0557]
 17%|█▋        | 4/24 [00:00<00:00, 33.91batch/s, loss=0.0837]Accuracy: 0.982925925925926
100%|██████████| 24/24 [00:00<00:00, 38.68batch/s, loss=0.0449]
  1%|          | 2/216 [00:00<00:12, 17.38batch/s, l

In [25]:

# Wrap the separately trained gate and install it as the MoE's gating network.
t.model.gating_network = GateWrapper(sg).to(t.device)


In [26]:
# Evaluate the MoE on the test set with the currently installed gating network.
# The trailing semicolon suppresses the notebook's display of the return value.
t.test(test_loader, metric_dict);

  0%|          | 0/40 [00:00<?, ?batch/s]
Testing
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 40/40 [00:00<00:00, 71.44batch/s]Results: 
Accuracy: 0.9821
NLL: 0.06378110595978796
ECE: 0.018323120385408398
Brier: 0.02862546475371346



Here we observe more of a problem: the training loss of the post-hoc trained gating network is nearly zero and its accuracy is over 95%, indicating that only limited further training would be possible; however, the generalisation error, which in this case reflects the overall MoE error nearly perfectly, is still somewhat higher. Even with some level of weight regularisation introduced when training the gating network, the problem persists. We are pushing the limits of the generalisation error here, and it is likely that the sheer "learnability" of the dataset is affecting this to some extent.

In [14]:
from methods.moe.gate_models import SimpleConvGate, GateWrapper
from methods.moe.laplace_gating import get_adjusted_loader
import gate_train as gt
from importlib import reload

In [18]:
from methods.moe.gate_models import get_gating_network

In [None]:
# Reload the gate_train module so edits made to it on disk take effect without restarting the kernel.
gt = reload(gt)

In [15]:
# Rebuild the adjusted loaders, this time with return_original=True. The model is
# moved to the CPU for the construction and back to the trainer's device afterwards.
# NOTE(review): return_original presumably makes the loaders also yield the original
# class labels — confirm in methods.moe.laplace_gating.
t.model.to('cpu')

gate_train_loader = get_adjusted_loader(t.model.experts, train_loader, return_original=True)
gate_val_loader = get_adjusted_loader(t.model.experts, valid_loader, return_original=True)

t.model.to(t.device);

In [26]:
# Fraction of samples in the gate-training dataset whose adjusted label is 4.
gate_labels = gate_train_loader.dataset.new_labels
(gate_labels == 4).sum() / gate_labels.shape[0]

tensor(0.1965)

In [None]:
# Identifier of a stored run: MNIST, 5 LeNet experts, MoE with convolutional gating.
# NOTE(review): not referenced anywhere else in the visible notebook; the format looks
# like an experiment-tracker run directory name — confirm what consumes it.
run_id_moe_mnist_conv = 'run-20210713_152705-r9znwdrf'


In [19]:
# Build a convolutional gate through the project's factory and fit it against the
# trained experts with gate_train's routine, which returns the experts and the gate.
sg = get_gating_network(None, 'conv', 28 * 28, 5)
sg.to('cuda')
exps, g = gt.fit_gating(
    t.model.experts,
    sg,
    gate_train_loader,
    gate_val_loader,
    0.001,  # NOTE(review): presumably the learning rate — confirm against gt.fit_gating
    0,  # NOTE(review): presumably the weight decay — confirm
    gt.loss_sum_criterion,
    'cuda',
    10,  # NOTE(review): presumably the number of epochs — confirm
)


  1%|          | 2/216 [00:00<00:11, 18.27batch/s, loss=8.37]Using a simple convolutional gate
Epoch 1
100%|██████████| 216/216 [00:08<00:00, 26.47batch/s, loss=0.768]
 17%|█▋        | 4/24 [00:00<00:00, 25.79batch/s, loss=0.618]
Training
--------------
Ensemble accuracy 0.9082407407407408
Gate oracle accuracy 0.9062407407407408
Loss 2.231543489076473
100%|██████████| 24/24 [00:00<00:00, 26.54batch/s, loss=0.58]
  1%|          | 2/216 [00:00<00:11, 18.81batch/s, loss=0.531]
Validation
--------------
Ensemble accuracy 0.9656666666666667
Gate oracle accuracy 0.9691666666666666
Loss 0.7032889698942503
Epoch 2
100%|██████████| 216/216 [00:08<00:00, 26.92batch/s, loss=0.467]
 17%|█▋        | 4/24 [00:00<00:00, 32.48batch/s, loss=0.314]
Training
--------------
Ensemble accuracy 0.9743703703703703
Gate oracle accuracy 0.9758333333333333
Loss 0.5031139244221978
100%|██████████| 24/24 [00:00<00:00, 29.87batch/s, loss=0.333]
  0%|          | 1/216 [00:00<00:37,  5.69batch/s, loss=0.32]
Validatio

In [32]:
# Configuration for a throwaway trainer: same options as the main experiment
# except that the MoE type is 'dense'.
dummy_overrides = {
    'moe_gating': 'simple',
    'method': 'moe',
    'n': 5,
    'model': 'lenet',
    'optimizer': 'adam',
    'moe_type': 'dense',
    'predict_gated': True,
}
dummy_args = Configuration(DEFAULT_DICT)
for option_name, option_value in dummy_overrides.items():
    setattr(dummy_args, option_name, option_value)

In [35]:
# Build a second trainer from the dense configuration on the CUDA device; its model
# serves as a container for the separately fitted gate and experts.
dummy_trainer = get_trainer(dummy_args, 'cuda')

Using a simple gate


In [37]:
# Install the gate (g) and experts (exps) returned by gt.fit_gating into the dense trainer's model.
dummy_trainer.model.gating_network = g
dummy_trainer.model.experts = exps

In [38]:
# Evaluate the dense MoE assembled from the fitted gate and experts on the test set.
# The trailing semicolon suppresses the notebook's display of the return value.
dummy_trainer.test(test_loader, metric_dict);

  0%|          | 0/40 [00:00<?, ?batch/s]
Testing
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 40/40 [00:00<00:00, 63.39batch/s]Results: 
Accuracy: 0.9833
NLL: 0.08413592033321038
ECE: 0.015074500668048853
Brier: 0.028405361668774276

