In [1]:
# Step up one directory from the notebook's location; presumably this is the
# project root needed by the project-local imports — TODO confirm.
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel climbs one more directory each time.
%cd ..

/home/eli/AnacondaProjects/ppc_experiments


In [2]:
import argparse
import collections
import numpy as np
import pyro
import torch
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
import trainer.trainer as module_trainer

In [3]:
# Optional debugging switches, left disabled: uncomment to enable pyro's
# validation checks and torch autograd anomaly detection.
# pyro.enable_validation(True)
# torch.autograd.set_detect_anomaly(True)

In [4]:
# Seed the NumPy and torch RNGs with one shared value so reruns are repeatable.
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)

# Disable cuDNN benchmarking and request deterministic cuDNN behaviour.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [5]:
from utils import read_json

# Read the experiment settings from JSON and wrap them in a ConfigParser in one step.
config = ConfigParser(read_json("experiments/ppc_mnist_config.json"))

In [6]:
logger = config.get_logger('train')

# Data: the configured training loader, plus the validation loader split off from it.
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()

# Model: instantiate the configured architecture and log its structure.
model = config.init_obj('arch', module_arch)
logger.info(model)

# Metrics: resolve each configured metric name to its function handle.
metrics = [
    getattr(module_metric, metric_name)
    for metric_name in config['metrics']
]

# Optimizer: with no scheduler section, use the configured pyro optimizer as-is.
if "lr_scheduler" not in config:
    optimizer = config.init_obj('optimizer', pyro.optim)
    lr_scheduler = None
else:
    # The pyro scheduler is handed the torch optimizer class and its arguments,
    # and the resulting object is used as both optimizer and scheduler.
    scheduler_cls = getattr(pyro.optim, config["lr_scheduler"]["type"])
    scheduler_args = {
        "optimizer": getattr(torch.optim, config["optimizer"]["type"]),
        "optim_args": config["optimizer"]["args"]["optim_args"],
    }
    # Scheduler-specific settings go last so they override on key collisions.
    scheduler_args.update(config["lr_scheduler"]["args"])
    lr_scheduler = optimizer = scheduler_cls(scheduler_args)

# Trainer: built from the 'trainer' config section and the objects created above.
trainer = config.init_obj(
    'trainer', module_trainer, model, metrics, optimizer,
    config=config,
    data_loader=data_loader,
    valid_data_loader=valid_data_loader,
    lr_scheduler=lr_scheduler,
)

MnistPpc(
  (digit_features): DigitFeatures()
  (decoder): DigitDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=10, out_features=200, bias=True)
      (1): ReLU()
      (2): Linear(in_features=200, out_features=400, bias=True)
      (3): ReLU()
      (4): Linear(in_features=400, out_features=784, bias=True)
      (5): Sigmoid()
    )
  )
  (graph): GraphicalModel()
)
Trainable parameters: 396984
Initialize particles: train batch 0
Initialize particles: train batch 1
Initialize particles: train batch 2
Initialize particles: train batch 3
Initialize particles: train batch 4
Initialize particles: train batch 5
Initialize particles: train batch 6
Initialize particles: train batch 7
Initialize particles: train batch 8
Initialize particles: train batch 9
Initialize particles: train batch 10
Initialize particles: train batch 11
Initialize particles: train batch 12
Initialize particles: train batch 13
Initialize particles: train batch 14
Initialize particles: train batch 15
I

In [7]:
# Log the log directory recorded on the trainer's config for this run.
logger.info(trainer.config.log_dir)

saved/log/Mnist_Ppc/0208_133044


In [8]:
# Optional, left disabled: run training under torch.profiler (CPU + CUDA activity,
# memory and stack capture) and write the trace to the run's log directory.
# ACTIVITIES = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
# SCHEDULE = torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1)
# ON_TRACE_READY = torch.profiler.tensorboard_trace_handler(trainer.config.log_dir)
# with torch.profiler.profile(activities=ACTIVITIES, on_trace_ready=ON_TRACE_READY, profile_memory=True, schedule=SCHEDULE, with_stack=True) as profiler:
#     trainer.train(profiler)

In [None]:
# Run training; the logged output reports per-epoch loss, ess and log_marginal
# for both the training and validation sets.
trainer.train()

    epoch          : 1
    loss           : -1371.6550497378944
    ess            : 8.001120202946213
    log_marginal   : 1371.6550520411079
    val_loss       : -1461.624532063802
    val_ess        : 8.00117572148641
    val_log_marginal: 1461.624532063802
    epoch          : 2
    loss           : -1465.0692006237102
    ess            : 8.001165551959344
    log_marginal   : 1465.069201775317
    val_loss       : -1512.1118265787761
    val_ess        : 8.001176516215006
    val_log_marginal: 1512.1118265787761
    epoch          : 3
    loss           : -1505.1728987783756
    ess            : 8.001172641538224
    log_marginal   : 1505.1728987783756
    val_loss       : -1541.2795003255208
    val_ess        : 8.001179218292236
    val_log_marginal: 1541.2795003255208
    epoch          : 4
    loss           : -1544.1310493901092
    ess            : 8.00117652821091
    log_marginal   : 1544.1310493901092
    val_loss       : -1575.1084798177083
    val_ess        : 8.001182

In [None]:
# Put the model in evaluation mode, then move the trainer to the CPU.
trainer.model.eval()
trainer.cpu()
# NOTE(review): the return values below are discarded. If the particle
# containers are plain tensors rather than modules, `.cpu()` returns a copy and
# these two calls have no effect — confirm their type.
trainer.train_particles.cpu()
trainer.valid_particles.cpu()