In [1]:
# Move up from the notebook's folder to the project root (…/ppc_experiments, per the
# output below) so the project-local imports and relative paths such as
# "experiments/ppc_mnist_config.json" used in later cells resolve.
%cd ..

/home/eli/AnacondaProjects/ppc_experiments


In [2]:
import argparse
import collections
import random

import numpy as np
import pyro
import torch

import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
import trainer.trainer as module_trainer

In [3]:
# pyro.enable_validation(True)
# torch.autograd.set_detect_anomaly(True)

In [4]:
# Fix random seeds for reproducibility. Seed every RNG the run might touch:
# Python's `random`, NumPy's legacy global RNG, and torch
# (torch.manual_seed seeds the generator on all devices, CPU and CUDA).
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Trade cuDNN autotuning speed for run-to-run determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
from utils import read_json

# Read the MNIST PPC experiment's JSON config and wrap it in the project's
# ConfigParser in one step, so `config` is only ever bound to the parser.
config = ConfigParser(read_json("experiments/ppc_mnist_config.json"))

In [6]:
# Build everything the Trainer needs from the parsed config.
logger = config.get_logger('train')

# Data: the training loader plus the validation loader split off from it.
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()

# Model architecture; log its module tree.
model = config.init_obj('arch', module_arch)
logger.info(model)

# Metric functions, looked up by name in model.metric.
metrics = [getattr(module_metric, met) for met in config['metrics']]

# Optimizer. When an LR scheduler is configured, the pyro.optim scheduler class
# is given the torch optimizer class and its optim_args, and the resulting
# object is used as BOTH `optimizer` and `lr_scheduler`.
if "lr_scheduler" in config:
    scheduler_cls = getattr(pyro.optim, config["lr_scheduler"]["type"])
    lr_scheduler = optimizer = scheduler_cls({
        "optimizer": getattr(torch.optim, config["optimizer"]["type"]),
        "optim_args": config["optimizer"]["args"]["optim_args"],
        **config["lr_scheduler"]["args"],
    })
else:
    optimizer = config.init_obj('optimizer', pyro.optim)
    lr_scheduler = None

# Trainer: constructor type/args come from config['trainer']; the model,
# metrics and optimizer are passed positionally, the rest by keyword.
trainer = config.init_obj('trainer', module_trainer, model, metrics, optimizer,
                          config=config, data_loader=data_loader,
                          valid_data_loader=valid_data_loader,
                          lr_scheduler=lr_scheduler)

MnistPpc(
  (prior): GaussianPrior()
  (decoder1): ConditionalGaussian(
    (decoder): Sequential(
      (0): ReLU()
      (1): Linear(in_features=20, out_features=128, bias=True)
    )
  )
  (decoder2): ConditionalGaussian(
    (decoder): Sequential(
      (0): ReLU()
      (1): Linear(in_features=128, out_features=256, bias=True)
    )
  )
  (likelihood): MlpBernoulliLikelihood(
    (decoder): Sequential(
      (0): ReLU()
      (1): Linear(in_features=256, out_features=784, bias=True)
    )
  )
  (graph): PpcGraphicalModel()
)
Trainable parameters: 319540
Initialize particles: train batch 0
Initialize particles: train batch 1
Initialize particles: train batch 2
Initialize particles: train batch 3
Initialize particles: train batch 4
Initialize particles: train batch 5
Initialize particles: train batch 6
Initialize particles: train batch 7
Initialize particles: train batch 8
Initialize particles: train batch 9
Initialize particles: train batch 10
Initialize particles: train batch 11
I

In [7]:
# Log this run's log directory (from the trainer's config) so the run's
# outputs can be located later.
logger.info(trainer.config.log_dir)

saved/log/Mnist_Ppc/0327_194126


In [None]:
# Run the training loop configured above. Per-epoch metrics (loss, ESS,
# log_marginal and their val_* counterparts) are logged as it runs.
# NOTE(review): the saved output stops partway through epoch 4 and the cell has
# no execution count, so this run looks interrupted. Re-run it before relying
# on the cells below.
trainer.train()

    epoch          : 1
    loss           : -647.0783520468038
    ess            : 3.8803834474482244
    log_marginal   : 647.1240250646221
    val_loss       : -805.2631200154623
    val_ess        : 3.8806121249993644
    val_log_marginal: 805.3075307210287
    epoch          : 2
    loss           : -875.4098186764107
    ess            : 3.8780529860636634
    log_marginal   : 875.4556199204865
    val_loss       : -923.991091410319
    val_ess        : 3.861002564430237
    val_log_marginal: 924.0440673828125
    epoch          : 3
    loss           : -967.4436850886774
    ess            : 3.8696132723189076
    log_marginal   : 967.4922303113892
    val_loss       : -994.2424341837565
    val_ess        : 3.8775921165943146
    val_log_marginal: 994.2884928385416
    epoch          : 4
    loss           : -1022.7254132456125
    ess            : 3.8629318108490858
    log_marginal   : 1022.7773385432095
    val_loss       : -1038.3155314127605
    val_ess        : 3.86201964

In [None]:
# Switch the model to inference mode, then bring the model, the trainer and
# both particle stores onto the CPU for the sampling/plotting cells below.
# NOTE(review): nn.Module.cpu() moves in place, but Tensor.cpu() returns a copy.
# If train_particles / valid_particles are plain tensors, these calls have no
# effect without reassignment. Confirm their type.
trainer.model.eval().cpu()
trainer.cpu()
trainer.train_particles.cpu()
trainer.valid_particles.cpu()

In [None]:
# Empty the model's graph, then presumably refill its node values from stored
# particles (the next cells read nodes[site]['value']).
trainer.model.graph.clear()
# NOTE(review): this calls a private Trainer method with indices 0..batch_size-1.
# The meaning of the `False` flag is not visible here; it presumably selects
# validation rather than training particles. Confirm against
# Trainer._load_particles.
trainer._load_particles(range(trainer.data_loader.batch_size), False)

In [None]:
# A graph site counts as observed exactly when it currently holds a value.
node_view = trainer.model.graph.nodes
for site in node_view:
    attrs = node_view[site]
    attrs['is_observed'] = attrs['value'] is not None

In [None]:
import utils

In [None]:
# Condition the trained model on every graph site's stored value and draw a
# forward sample for each (particle, batch item) pair.
# The conditioned wrapper gets its own name so the global `model` (the
# architecture built in the setup cell) is not shadowed.
# NOTE(review): sites whose stored value is None are passed to pyro.condition
# as None. Confirm that Pyro treats them as unobserved and samples them.
conditioned_model = pyro.condition(
    trainer.model,
    data={site: attrs['value'] for site, attrs in trainer.model.graph.nodes.items()},
)
with pyro.plate_stack("forward", (trainer.num_particles, trainer.data_loader.batch_size)):
    xs = conditioned_model()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Show the particle-averaged sample for the first few batch items.
N_IMAGES_TO_SHOW = 10
# The likelihood decoder outputs 784 = 28 * 28 features. Reshape instead of
# squeeze so a flat 784-vector still renders as an image.
# NOTE(review): assumes each sample has exactly 784 pixels. Confirm the shape of xs.
IMAGE_SHAPE = (28, 28)

# Average over dim 0, presumably the particle plate (the first size given to
# plate_stack). Do this once, outside the loop.
mean_images = xs.mean(dim=0).detach().cpu()
for i in range(min(N_IMAGES_TO_SHOW, len(mean_images))):
    fig, ax = plt.subplots()
    ax.imshow(mean_images[i].reshape(IMAGE_SHAPE).numpy(), cmap="gray")
    ax.set_title(f"Batch item {i}: mean over particles")
    ax.axis("off")
    plt.show()