In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import sys
sys.path.insert(0, './src')
from data import Dataset
from training import train
from plotting import plot_mfg_classification, plot_mcdo_classification, mfg_regression_inference, mcdo_regression_inference
from metrics import log_likelihood

In [2]:
class dotdict(dict):
    """Dictionary whose keys can also be read, set and deleted as attributes.

    ``d.key`` behaves like ``d.get(key)``, so a missing key reads as ``None``
    instead of raising. ``d.key = v`` stores ``d['key'] = v``, and ``del d.key``
    removes the item (raising ``KeyError`` if it is absent).
    """

    def __getattr__(self, key, default=None):
        # Only reached when normal attribute lookup fails (i.e. not a dict method).
        return self.get(key, default)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        del self[key]

Reference: the `DATASETS` registry, mapping each `args['dataset_name']` value to its data class:

```python
DATASETS = {
    'boston_housing': BostonHousingData,
    'concrete': ConcreteData,
    'energy_efficiency': EnergyEfficiencyData,
    'kin8nm': Kin8nmData,
    'naval_propulsion': NavalPropulsionData,
    'ccpp': CCPPData,
    'protein_structure': ProteinStructureData,
    'red_wine': RedWineData,
    'yacht_hydrodynamics': YachtHydrodynamicsData,
    'year_prediction_msd': YearPredictionMSDData,
    'mnist': MnistData,
    'fashion_mnist': FashionMnistData,
    'cifar_10': Cifar10,
    'svhn': SVHN
}
```

In [3]:
# Experiment configuration, filled in by the next cell and passed to Dataset / train / log_likelihood.
args = dotdict()

In [4]:
# Fix the RNG seeds before any data loading or training so reruns are repeatable.
# GPU kernels may still add small nondeterminism. Python's `random` module is not seeded here.
SEED = 42
torch.manual_seed(SEED)  # seeds the torch RNGs on all devices, including CUDA
np.random.seed(SEED)

args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
args['torchType'] = torch.float32

args['model_type'] = 'mcdo'  # 'mcdo' (Monte-Carlo Dropout) or 'mfg' (Mean-Field Gaussian)
args['dataset_name'] = 'mnist'

# Two hyper-parameter presets: one for names containing 'mnist' ('mnist', 'fashion_mnist'),
# one for every other dataset name.
# NOTE(review): 'cifar_10' and 'svhn' are image datasets too but get the second preset -- confirm intended.
if 'mnist' in args['dataset_name']:
    args['num_epoches'] = 201  # (sic) key name kept as-is; epochs 0..200
    args['print_info'] = 50    # progress (ELBO / validation accuracy) is logged every 50 epochs
    args['n_IS'] = 10000       # presumably the number of importance samples -- TODO confirm

    args['train_batch_size'] = 100
    args['val_dataset'] = 10000  # presumably the validation-set size -- TODO confirm
    args['val_batch_size'] = 100
    args['test_batch_size'] = 100
else:
    args['n_IS'] = 1
    args['num_epoches'] = 10001
    args['print_info'] = 1000
    args['train_batch_size'] = 150
    args['val_dataset'] = 20
    args['val_batch_size'] = 20
    args['test_batch_size'] = 10

In [5]:
# Build the data object for the configured dataset (presumably selected via args['dataset_name']).
# The saved output shows it prints the training-set size.
dataset = Dataset(args)

Train data shape 60000




In [6]:
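# Single-model run for args['model_type'] (disabled; the comparison section below trains both models).
# model, params = train(args, dataset)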

In [7]:
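# Inference plots for the single-model run above (disabled). obj_id is only used by the regression helpers.
# NOTE(review): `args.problem` is not set in this notebook's config cell; presumably `Dataset` fills it in.
# With `dotdict`, an unset key reads as None, which would silently take the regression branch.
# obj_id = 21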

# if args.problem == 'classification':
#     if args.model_type == 'mfg':
#         plot_mfg_classification(args, model, dataset, params)
#     elif args.model_type == 'mcdo':
#         plot_mcdo_classification(args, model, dataset, params)
# else:
#     if args.model_type == 'mfg':
#         mfg_regression_inference(args, model, dataset, params, obj_id)
#     elif args.model_type == 'mcdo':
#         mcdo_regression_inference(args, model, dataset, params, obj_id)

## Comparison of Mean-Field Gaussian (MFG) and Monte-Carlo Dropout (MCDO)

In [8]:
# Mean-Field Gaussian (MFG) model: select it in the shared config, then train.
args.model_type = 'mfg'
model_mfg, params_mfg = train(args, dataset)

  0%|          | 0/201 [00:00<?, ?it/s]

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 14, 14]             416
            Conv2d-2             [-1, 32, 7, 7]          12,832
            Conv2d-3             [-1, 64, 4, 4]          51,264
            Linear-4                  [-1, 256]         262,400
            Linear-5                   [-1, 10]           2,570
Total params: 329,482
Trainable params: 329,482
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 1.26
Estimated Total Size (MB): 1.31
----------------------------------------------------------------


  0%|          | 1/201 [00:02<07:50,  2.35s/it]

ELBO value is -148.69802856445312 on epoch number 0
Mean validation accuracy at epoch number 0 is 0.3958999514579773
Current KL is 1.185486078262329


 25%|██▌       | 51/201 [01:54<05:38,  2.26s/it]

ELBO value is -3.6171488761901855 on epoch number 50
Mean validation accuracy at epoch number 50 is 0.991599977016449
Current KL is 2.617124080657959


 50%|█████     | 101/201 [03:46<03:46,  2.26s/it]

ELBO value is -2.271806240081787 on epoch number 100
Mean validation accuracy at epoch number 100 is 0.991100013256073
Current KL is 2.162872314453125


 75%|███████▌  | 151/201 [05:39<01:54,  2.29s/it]

ELBO value is -11.276079177856445 on epoch number 150
Mean validation accuracy at epoch number 150 is 0.9908000230789185
Current KL is 2.0957090854644775


100%|██████████| 201/201 [07:31<00:00,  2.24s/it]

ELBO value is -3.327393054962158 on epoch number 200
Mean validation accuracy at epoch number 200 is 0.9915000200271606
Current KL is 2.074145793914795





In [9]:
# Monte-Carlo Dropout (MCDO) model: select it in the shared config, then train.
args.model_type = 'mcdo'
model_mcdo, params_mcdo = train(args, dataset)

  0%|          | 0/201 [00:00<?, ?it/s]

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 14, 14]             416
            Conv2d-2             [-1, 32, 7, 7]          12,832
            Conv2d-3             [-1, 64, 4, 4]          51,264
            Linear-4                  [-1, 256]         262,400
            Linear-5                   [-1, 10]           2,570
Total params: 329,482
Trainable params: 329,482
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 1.26
Estimated Total Size (MB): 1.31
----------------------------------------------------------------


  0%|          | 1/201 [00:01<06:13,  1.87s/it]

Log likelihood value is -162.95883178710938 on epoch number 0
Mean validation accuracy at epoch number 0 is 0.7418000102043152


 25%|██▌       | 51/201 [01:30<04:30,  1.80s/it]

Log likelihood value is -59.582763671875 on epoch number 50
Mean validation accuracy at epoch number 50 is 0.9917999505996704


 50%|█████     | 101/201 [03:00<03:05,  1.86s/it]

Log likelihood value is -25.692922592163086 on epoch number 100
Mean validation accuracy at epoch number 100 is 0.9904998540878296


 75%|███████▌  | 151/201 [04:32<01:33,  1.87s/it]

Log likelihood value is -98.06649780273438 on epoch number 150
Mean validation accuracy at epoch number 150 is 0.9925000071525574


100%|██████████| 201/201 [06:03<00:00,  1.81s/it]

Log likelihood value is -28.572059631347656 on epoch number 200
Mean validation accuracy at epoch number 200 is 0.9881000518798828





In [10]:
# Log-likelihood of each trained model. log_likelihood receives the shared `args`,
# so model_type is switched to match the model being evaluated (presumably it branches on it).
# After this cell, args.model_type is left as 'mcdo'.
args.model_type = 'mfg'
mfg_ll = log_likelihood(model_mfg, params_mfg, dataset, args)

args.model_type = 'mcdo'
mcdo_ll = log_likelihood(model_mcdo, params_mcdo, dataset, args)

100it [13:13,  7.93s/it]
100it [03:33,  2.14s/it]


In [11]:
# Mean +/- standard deviation of the values returned by log_likelihood for each model.
summary = '\n'.join([
    f'Mean-Field-Gaussian loglikelihood: {np.mean(mfg_ll)} +/- {np.std(mfg_ll)}',
    f'Monte-Carlo Dropout loglikelihood: {np.mean(mcdo_ll)} +/- {np.std(mcdo_ll)}',
])
print(summary)

Mean-Field-Gaussian loglikelihood: -9.44112829208374 +- 2.0521901661063016
Monte-Carlo Dropout loglikelihood: -0.2614870929718018 +- 0.04902956887986008


## Inference Examples

In [17]:
# The log-likelihood cell leaves the shared `args` with model_type == 'mcdo'.
# Pass a copy pinned to 'mfg' so this MFG plot sees the right model type
# without changing what later cells see in `args`.
plot_mfg_classification(dotdict(args, model_type='mfg'), model_mfg, dataset, params_mfg)

In [18]:
# Pin model_type explicitly (on a copy) so this cell does not depend on which
# model the shared `args` was last switched to by an earlier cell.
plot_mcdo_classification(dotdict(args, model_type='mcdo'), model_mcdo, dataset, params_mcdo)