# Summary of results

Here we report typical statistics for validation/test sets, e.g., ELBO, NLL, BPD, distortion (D) and rate (R).

In [1]:
import torch
torch.__version__

'1.8.1+cu102'

In [2]:
import numpy as np
import torch
import torch.distributions as td
import matplotlib.pyplot as plt
import torch.nn as nn

In [3]:
from collections import namedtuple, OrderedDict, defaultdict
from tqdm.auto import tqdm
from itertools import chain
from tabulate import tabulate

In [4]:
import sys
sys.path.append("../")

In [5]:
from components import GenerativeModel, InferenceModel, VAE
from data import load_mnist
from hparams import load_cfg, make_args
from main import make_state, get_batcher, validate

In [6]:
from analysis import probe_prior, compare_marginals, compare_samples

In [7]:
import pathlib

In [8]:
import random
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
rng = np.random.RandomState(0)

In [9]:
import pickle
knn_model = pickle.load(open('knnclassifier.pickle', 'rb'))

# Helper code

In [10]:
from analysis import collect_samples

# Load model and data

* Load hyperparameters
* Load model state
* Load MNIST data

In [11]:
train_loader, valid_loader, test_loader = load_mnist(
    batch_size=100, 
    save_to='../tmp', 
    height=28, 
    width=28
)

In [12]:
num_samples_test = 1000

In [13]:
ls ../trained_models/

[0m[01;34mcategorical[0m/  [01;34mgaussian[0m/           [01;34mmixed-maxent[0m/
[01;34mdirichlet[0m/    [01;34mgaussiansp-maxent[0m/  [01;34monehotcat[0m/


In [14]:
valid_results = defaultdict(list)
test_results = defaultdict(list)

In [15]:
dirs = [('gaussian', d, True) for d in pathlib.Path('../trained_models/gaussian/').iterdir() if d.is_dir()]
dirs += [('dirichlet', d, True) for d in pathlib.Path('../trained_models/dirichlet/').iterdir() if d.is_dir()]
dirs += [('categorical', d, True) for d in pathlib.Path('../trained_models/categorical').iterdir() if d.is_dir()]
dirs += [('onehotcat', d, True) for d in pathlib.Path('../trained_models/onehotcat').iterdir() if d.is_dir()]
dirs += [('mixed-maxent', d, True) for d in pathlib.Path('../trained_models/mixed-maxent/').iterdir() if d.is_dir()]

In [16]:
dirs

[('gaussian', PosixPath('../trained_models/gaussian/bumbling-haze-15'), True),
 ('gaussian', PosixPath('../trained_models/gaussian/crimson-yogurt-13'), True),
 ('gaussian',
  PosixPath('../trained_models/gaussian/glamorous-music-12'),
  True),
 ('gaussian', PosixPath('../trained_models/gaussian/noble-water-11'), True),
 ('gaussian', PosixPath('../trained_models/gaussian/dainty-wood-14'), True),
 ('dirichlet',
  PosixPath('../trained_models/dirichlet/gallant-flower-19'),
  True),
 ('dirichlet', PosixPath('../trained_models/dirichlet/dark-meadow-17'), True),
 ('dirichlet', PosixPath('../trained_models/dirichlet/firm-water-16'), True),
 ('dirichlet', PosixPath('../trained_models/dirichlet/icy-dust-18'), True),
 ('dirichlet', PosixPath('../trained_models/dirichlet/jumping-bee-20'), True),
 ('categorical',
  PosixPath('../trained_models/categorical/chocolate-voice-30'),
  True),
 ('categorical',
  PosixPath('../trained_models/categorical/lively-smoke-26'),
  True),
 ('categorical',
  PosixP

In [17]:
for cls, directory, redo in tqdm(dirs):
    if not redo:
        continue
    args = make_args(
        load_cfg(
            f"{directory}/cfg.json", 
            # use this to specify a decide for analysis
            device='cuda:0',
            # use this to change paths if you need
            data_dir='../tmp',
            # you don't really need to change the output_dir
        )
    )
    experiment = directory.name
    print(f"Experiment: {cls}/{experiment}")

    state = make_state(
        args, 
        device=args.device, 
        ckpt_path=f"{directory}/ckpt.last"
    )
        
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    rng = np.random.RandomState(0) 
        
    print('Validating...')
    val_metrics = validate(
        state.vae, get_batcher(valid_loader, args), 
        num_samples=num_samples_test, 
        compute_DR=True,
        progressbar=True,
    )            
    
    print()
                
    r = [
        val_metrics[0].numpy(),  # NLL
        val_metrics[1].numpy(),  # BPD
        val_metrics[2]['ELBO'].mean(),  # ELBO
        val_metrics[2]['D'].mean(),  # D
        val_metrics[2]['R'].mean(),  # R
        val_metrics[2].get('R_F', np.zeros(1)).mean(),  # R
        val_metrics[2].get('R_Y|f', np.zeros(1)).mean(),  # R
        val_metrics[2].get('R_Y', np.zeros(1)).mean(),  # R
        val_metrics[2].get('R_Z', np.zeros(1)).mean(),  # R
    ]
    valid_results[cls].append(r)
    
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    rng = np.random.RandomState(0) 
    

    print('Testing...')
    test_metrics = validate(
        state.vae, get_batcher(test_loader, args), 
        num_samples=num_samples_test, 
        compute_DR=True,
        progressbar=True,
    )            

    print()

    r = [
        test_metrics[0].numpy(),  # NLL
        test_metrics[1].numpy(),  # BPD
        test_metrics[2]['ELBO'].mean(),  # ELBO
        test_metrics[2]['D'].mean(),  # D
        test_metrics[2]['R'].mean(),  # R
        test_metrics[2].get('R_F', np.zeros(1)).mean(),  # R
        test_metrics[2].get('R_Y|f', np.zeros(1)).mean(),  # R
        test_metrics[2].get('R_Y', np.zeros(1)).mean(),  # R
        test_metrics[2].get('R_Z', np.zeros(1)).mean(),  # R
    ]
    test_results[cls].append(r)
    

  0%|          | 0/25 [00:00<?, ?it/s]

Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: gaussian/bumbling-haze-15
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: gaussian/crimson-yogurt-13
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: gaussian/glamorous-music-12
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: gaussian/noble-water-11
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: gaussian/dainty-wood-14
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: dirichlet/gallant-flower-19
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: dirichlet/dark-meadow-17
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: dirichlet/firm-water-16
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: dirichlet/icy-dust-18
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: dirichlet/jumping-bee-20
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: categorical/chocolate-voice-30
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: categorical/lively-smoke-26
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: categorical/earthy-frog-29
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: categorical/iconic-valley-28
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: categorical/splendid-glade-27
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: onehotcat/gallant-salad-22
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: onehotcat/peach-valley-25
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: onehotcat/pleasant-disco-24
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: onehotcat/fluent-violet-23
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: onehotcat/confused-sponge-21
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: mixed-maxent/electric-firefly-5
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: mixed-maxent/robust-wildflower-4
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: mixed-maxent/leafy-sky-3
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: mixed-maxent/lucky-grass-1
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]


Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: mixed-maxent/smooth-armadillo-2
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]




In [18]:
from tabulate import tabulate

In [19]:
headers = ['NLL', 'BPD', 'ELBO', 'D', 'R', 'R_F', 'R_Y|f', 'R_Y', 'R_Z']

In [20]:
print("Validation - Gaussian")
print(tabulate(valid_results['gaussian'], headers=headers, floatfmt='.2f'))
print("Validation - Dirichlet")
print(tabulate(valid_results['dirichlet'], headers=headers, floatfmt='.2f'))
print("Validation - Categorical")
print(tabulate(valid_results['categorical'], headers=headers, floatfmt='.2f'))
print("Validation - GS-ST")
print(tabulate(valid_results['onehotcat'], headers=headers, floatfmt='.2f'))
print("Validation - Mixed Dir")
print(tabulate(valid_results['mixed-maxent'], headers=headers, floatfmt='.2f'))

Validation - Gaussian
  NLL    BPD    ELBO      D      R    R_F    R_Y|f    R_Y    R_Z
-----  -----  ------  -----  -----  -----  -------  -----  -----
91.67  13.22  -96.93  77.03  19.90   0.00     0.00   0.00  19.90
91.69  13.23  -96.90  76.90  20.00   0.00     0.00   0.00  20.00
91.69  13.23  -97.06  77.16  19.90   0.00     0.00   0.00  19.90
91.72  13.23  -96.94  76.98  19.96   0.00     0.00   0.00  19.96
91.63  13.22  -97.03  77.17  19.86   0.00     0.00   0.00  19.86
Validation - Dirichlet
  NLL    BPD    ELBO      D      R    R_F    R_Y|f    R_Y    R_Z
-----  -----  ------  -----  -----  -----  -------  -----  -----
94.12  13.58  -98.70  78.58  20.12   0.00     0.00   0.00  20.12
94.74  13.67  -99.45  79.56  19.89   0.00     0.00   0.00  19.89
94.47  13.63  -99.66  79.29  20.38   0.00     0.00   0.00  20.38
94.65  13.66  -99.37  79.26  20.11   0.00     0.00   0.00  20.11
94.39  13.62  -99.24  79.07  20.17   0.00     0.00   0.00  20.17
Validation - Categorical
   NLL    BPD     EL

In [21]:
# print('mixed-dir')
# print(tabulate(
#     [
#         ['mean'] + [x for x in np.mean(valid_results['mixed-maxent'], 0)],
#         ['std'] + [x for x in np.std(valid_results['mixed-maxent'], 0)],
#         ['min'] + [x for x in np.min(valid_results['mixed-maxent'], 0)],
#         ['max'] + [x for x in np.max(valid_results['mixed-maxent'], 0)]
#     ], 
#     headers=headers, floatfmt='.2f'))

In [22]:
idx = np.array([3, 4, 0], dtype=int)

print(tabulate(
    [
        ['gaussian', 'valid', 'mean'] + [x for x in np.array(valid_results['gaussian'])[:,idx].mean(0)],        
        ['dirichlet', 'valid', 'mean'] + [x for x in np.array(valid_results['dirichlet'])[:,idx].mean(0)],
        ['categorical', 'valid', 'mean'] + [x for x in np.array(valid_results['categorical'])[:,idx].mean(0)],
        ['onehotcat', 'valid', 'mean'] + [x for x in np.array(valid_results['onehotcat'])[:,idx].mean(0)],
        ['mixed-dir', 'valid', 'mean'] + [x for x in np.array(valid_results['mixed-maxent'])[:,idx].mean(0)],        

        ['gaussian', 'valid', 'stddev'] + [x for x in np.array(valid_results['gaussian'])[:,idx].std(0)],        
        ['dirichlet', 'valid', 'stddev'] + [x for x in np.array(valid_results['dirichlet'])[:,idx].std(0)],
        ['categorical', 'valid', 'stddev'] + [x for x in np.array(valid_results['categorical'])[:,idx].std(0)],
        ['onehotcat', 'valid', 'stddev'] + [x for x in np.array(valid_results['onehotcat'])[:,idx].std(0)],
        ['mixed-dir', 'valid', 'stddev'] + [x for x in np.array(valid_results['mixed-maxent'])[:,idx].std(0)],        

        ['gaussian', 'test', 'mean'] + [x for x in np.array(test_results['gaussian'])[:,idx].mean(0)],
        ['dirichlet', 'test', 'mean'] + [x for x in np.array(test_results['dirichlet'])[:,idx].mean(0)],
        ['categorical', 'test', 'mean'] + [x for x in np.array(test_results['categorical'])[:,idx].mean(0)],
        ['onehotcat', 'test', 'mean'] + [x for x in np.array(test_results['onehotcat'])[:,idx].mean(0)],
        ['mixed-dir', 'test', 'mean'] + [x for x in np.array(test_results['mixed-maxent'])[:,idx].mean(0)],
        
        ['gaussian', 'test', 'stddev'] + [x for x in np.array(test_results['gaussian'])[:,idx].std(0)],
        ['dirichlet', 'test', 'stddev'] + [x for x in np.array(test_results['dirichlet'])[:,idx].std(0)],
        ['categorical', 'test', 'stddev'] + [x for x in np.array(test_results['categorical'])[:,idx].std(0)],
        ['onehotcat', 'test', 'stddev'] + [x for x in np.array(test_results['onehotcat'])[:,idx].std(0)],
        ['mixed-dir', 'test', 'stddev'] + [x for x in np.array(test_results['mixed-maxent'])[:,idx].std(0)],
    ], 
    headers=['Model', 'Dataset', 'Statistic', 'D', 'R', 'NLL'], floatfmt='.2f'))

Model        Dataset    Statistic         D      R     NLL
-----------  ---------  -----------  ------  -----  ------
gaussian     valid      mean          77.05  19.93   91.68
dirichlet    valid      mean          79.15  20.13   94.48
categorical  valid      mean         163.90   2.28  166.14
onehotcat    valid      mean         171.02   1.73  167.83
mixed-dir    valid      mean          90.97  19.16  107.12
gaussian     valid      stddev         0.10   0.05    0.03
dirichlet    valid      stddev         0.33   0.15    0.22
categorical  valid      stddev         0.08   0.00    0.08
onehotcat    valid      stddev         0.26   0.01    0.25
mixed-dir    valid      stddev         0.75   0.38    0.48
gaussian     test       mean          76.67  19.94   91.12
dirichlet    test       mean          78.62  19.94   93.81
categorical  test       mean         164.72   2.28  166.95
onehotcat    test       mean         171.76   1.70  168.50
mixed-dir    test       mean          90.34  19.39  106.