In [1]:
import os
import sys

# Make the project root (the notebook's parent directory) importable so the
# local packages (datasets, methods, configuration, util, ...) resolve.
# Any existing entries are dropped before re-inserting at the front, so
# re-running this cell does not keep growing sys.path with duplicates.
PROJECT_ROOT = os.path.abspath('../')
sys.path[:] = [entry for entry in sys.path if entry != PROJECT_ROOT]
sys.path.insert(0, PROJECT_ROOT)

In [7]:
import numpy as np 
import torch
import torchvision
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict

import datasets.mnist as mnist
import datasets.cifar10 as cifar10
import constants
from configuration import Configuration, DEFAULT_DICT
from methods.moe.MixtureOfExperts import SimpleMoE
from methods.mcdropout.MCDropout import MCDropout
from methods.BaseTrainer import StatisticsTracker
from util import *


In [8]:
# Seed the RNGs before any model is built so that weight initialisation and
# data shuffling are repeatable across Restart & Run All.
# NOTE(review): nondeterministic CUDA kernels can still cause small run-to-run
# differences.
SEED = 1
torch.manual_seed(SEED)  # also seeds all CUDA devices
np.random.seed(SEED)

# Experiment configuration: an MoE of 5 LeNet experts trained with Adam, using
# the 'simple' gating option, the 'fixed' MoE type and gated predictions.
args = Configuration(DEFAULT_DICT)
args.moe_gating = 'simple'
args.method = 'moe'
args.n = 5
args.model = 'lenet'
args.optimizer = 'adam'
# args.cpu = True  # presumably forces CPU execution -- uncomment to use it
args.moe_type = 'fixed'
args.predict_gated = True


In [9]:
# Build the trainer for the configuration in `args` on the GPU.
TRAIN_DEVICE = 'cuda'
t = get_trainer(args, TRAIN_DEVICE)

Using a simple gate


In [10]:
# Switch on the model's `gate_by_class` flag before the gate is fitted.
# NOTE(review): presumably enables per-class gating -- confirm in SimpleMoE.
setattr(t.model, 'gate_by_class', True)

In [11]:
# Train/validation loaders from the MNIST training data; random_seed is fixed,
# presumably so that the split is the same on every run.
train_loader, valid_loader = mnist.get_mnist_train_valid_loader(
    args.data_dir,
    args.batch_size,
    random_seed=1,
)

In [13]:
# Uncorrupted MNIST test loader.
# A corrupted variant is available through extra keyword arguments, e.g.
# intensity=<level>, corruption='rotation' (presumably together with
# corrupted=True -- confirm in datasets.mnist.get_test_loader).
test_loader = mnist.get_test_loader(args.data_dir, args.batch_size, corrupted=False)

In [14]:
def _nll(probs, labels):
    """Cross-entropy from metrics.basic_cross_entropy, unwrapped to a Python scalar."""
    return metrics.basic_cross_entropy(probs, labels).item()


# Metrics reported by the trainer's test(): name -> callable. ECE uses 20 bins.
metric_dict = dict(
    NLL=_nll,
    ECE=metrics.wrap_ece(bins=20),
    Brier=metrics.wrap_brier(),
)

In [27]:
from methods.moe.gate_models import SimpleConvGate, GateWrapper
from methods.moe.laplace_gating import get_adjusted_loader
import methods.moe.gate_train as gt
from importlib import reload

In [28]:
# Re-import gate_train so edits to its source are picked up without restarting
# the kernel; redundant right after the import on a fresh kernel.
gt = reload(gt)

In [29]:
# Build the gate's training/validation loaders from the experts via
# get_adjusted_loader (return_original=True presumably keeps the original
# samples alongside -- confirm in methods.moe.laplace_gating).
# NOTE(review): the model is moved to the CPU for this pass -- confirm why.
# The try/finally guarantees the model returns to the trainer's device even if
# building a loader raises, so a failed cell cannot strand it on the CPU.
t.model.to('cpu')
try:
    gate_train_loader = get_adjusted_loader(t.model.experts, train_loader, return_original=True)
    gate_val_loader = get_adjusted_loader(t.model.experts, valid_loader, return_original=True)
finally:
    t.model.to(t.device)

In [30]:
# Fit a convolutional gate over the experts using the gate loaders.
# NOTE(review): 28 looks like the MNIST image side and 5 like the number of
# experts (args.n) -- confirm against SimpleConvGate's signature.
GATE_EPOCHS = 10  # number of gate-training epochs passed to fit_gating
# Use the trainer's device rather than a hard-coded 'cuda', so the gate, the
# experts and the trainer always live on the same device.
sg = GateWrapper(SimpleConvGate(28, 5))
sg.to(t.device)
exps, g = gt.fit_gating(t.model.experts, sg, gate_train_loader, gate_val_loader,
                        gt.loss_sum_criterion, t.device, GATE_EPOCHS)


  0%|          | 1/216 [00:00<00:39,  5.48batch/s, loss=9.25]Using a simple convolutional gate
Epoch 1
100%|██████████| 216/216 [00:07<00:00, 29.54batch/s, loss=0.61]
 17%|█▋        | 4/24 [00:00<00:00, 33.63batch/s, loss=0.288]
Training
--------------
Ensemble accuracy 0.9435555555555556
Gate oracle accuracy 0.9478148148148148
Loss 0.8493824156208171
100%|██████████| 24/24 [00:00<00:00, 29.82batch/s, loss=0.506]
  0%|          | 1/216 [00:00<00:37,  5.80batch/s, loss=0.235]
Validation
--------------
Ensemble accuracy 0.9646666666666667
Gate oracle accuracy 0.9716666666666667
Loss 0.39983570016920567
Epoch 2
100%|██████████| 216/216 [00:07<00:00, 28.87batch/s, loss=0.154]
 17%|█▋        | 4/24 [00:00<00:00, 36.12batch/s, loss=0.525]
Training
--------------
Ensemble accuracy 0.9743333333333334
Gate oracle accuracy 0.9795555555555555
Loss 0.29242939843485755
100%|██████████| 24/24 [00:00<00:00, 33.15batch/s, loss=0.427]
  1%|▏         | 3/216 [00:00<00:08, 25.63batch/s, loss=0.199]
Valid

In [32]:
# Configuration for a second trainer that uses the 'dense' MoE type; every
# other setting matches `args`. Its model's experts and gating network are
# replaced by the separately trained ones before testing.
dummy_args = Configuration(DEFAULT_DICT)
dummy_args.moe_gating = 'simple'
dummy_args.method = 'moe'
dummy_args.n = 5
dummy_args.model = 'lenet'
dummy_args.optimizer = 'adam'
# dummy_args.cpu = True  # presumably forces CPU execution -- uncomment to use it
dummy_args.moe_type = 'dense'
dummy_args.predict_gated = True

In [35]:
# Trainer for the 'dense' configuration; its model's experts and gating network
# are overwritten with the separately trained ones before testing.
dummy_trainer = get_trainer(dummy_args, 'cuda')

Using a simple gate


In [37]:
# Install the trained gate and experts (from fit_gating) into the dense model;
# targets are assigned left to right, gating_network first, then experts.
dummy_trainer.model.gating_network, dummy_trainer.model.experts = g, exps

In [38]:
# Evaluate the combined model (trained experts + trained gate) on the
# uncorrupted test loader with the metrics in metric_dict; the trailing ';'
# suppresses display of the return value.
dummy_trainer.test(test_loader, metric_dict);

  0%|          | 0/40 [00:00<?, ?batch/s]
Testing
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 40/40 [00:00<00:00, 63.39batch/s]Results: 
Accuracy: 0.9833
NLL: 0.08413592033321038
ECE: 0.015074500668048853
Brier: 0.028405361668774276

