In [1]:
%cd ..
%env TORCH_CUDNN_SDPA_ENABLED=1

/workspace/shai_hulud/ppc_experiments
env: TORCH_CUDNN_SDPA_ENABLED=1


In [2]:
import argparse
import collections
import numpy as np
import pyro
import torch
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
import trainer.trainer as module_trainer

In [3]:
# pyro.enable_validation(True)
# torch.autograd.set_detect_anomaly(True)

In [4]:
# fix random seeds for reproducibility
SEED = 123
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)

In [5]:
from utils import read_json

config = read_json("experiments/ppc_flowers_config.json")
config = ConfigParser(config)

In [6]:
logger = config.get_logger('train')

# setup data_loader instances
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()

# build model architecture, then print to console
model = config.init_obj('arch', module_arch)
logger.info(model)

# get function handles of metrics
metrics = [getattr(module_metric, met) for met in config['metrics']]

# build optimizer.
if "lr_scheduler" in config:
    lr_scheduler = getattr(pyro.optim, config["lr_scheduler"]["type"])
    lr_scheduler = optimizer = lr_scheduler({
        "optimizer": getattr(torch.optim, config["optimizer"]["type"]),
        "optim_args": config["optimizer"]["args"]["optim_args"],
        **config["lr_scheduler"]["args"]
    })
else:
    optimizer = config.init_obj('optimizer', pyro.optim)
    lr_scheduler = None

# build trainer
# kwargs = config['trainer'].pop('args')
trainer = config.init_obj('trainer', module_trainer, model, metrics, optimizer,
                          config=config, data_loader=data_loader,
                          valid_data_loader=valid_data_loader,
                          lr_scheduler=lr_scheduler)

DiffusionPpc(
  (diffusion): DiffusionStep(
    (unet): ScoreNetwork0(
      (_convs): ModuleList(
        (0): Sequential(
          (0): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): LogSigmoid()
        )
        (1): Sequential(
          (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): LogSigmoid()
        )
        (2): Sequential(
          (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): LogSigmoid()
        )
        (3): Sequential(
          (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): LogSigmoid()
        )
        (4): Sequential(
          (0): MaxPool2d(ker

In [7]:
logger.info(trainer.config.log_dir)

saved/log/FlowersDiffusion_Ppc/0502_011314


In [8]:
trainer.train()

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


    epoch          : 1
    loss           : 63886066.747826084
    ess            : 2.4758910469386888
    log_marginal   : -63886064.48695652
    val_loss       : 62639128.0
    val_ess        : 3.5768340550936184
    val_log_marginal: -62639125.84615385
    epoch          : 2
    loss           : 57436416.8
    ess            : 4.7479864700980805
    log_marginal   : -57436415.16521739
    val_loss       : 57618113.23076923
    val_ess        : 7.356190754817082
    val_log_marginal: -57618111.692307696
    epoch          : 3
    loss           : 52755137.73913044
    ess            : 6.534591430166493
    log_marginal   : -52755135.82608695
    val_loss       : 53510175.384615384
    val_ess        : 8.043346515068642
    val_log_marginal: -53510173.84615385
    epoch          : 4
    loss           : 48810138.88695652
    ess            : 8.60696031736291
    log_marginal   : -48810137.14782609
    val_loss       : 50563503.384615384
    val_ess        : 11.650914265559269
    val_

KeyboardInterrupt: 

In [None]:
trainer.model.eval()
trainer.model.cpu()
trainer.cpu()
trainer.train_particles.cpu()
trainer.valid_particles.cpu()

In [None]:
trainer.model.graph.clear()
trainer._load_particles(range(trainer.data_loader.batch_size), False)

In [None]:
for site in trainer.model.graph.nodes:
    trainer.model.graph.nodes[site]['is_observed'] = trainer.model.graph.nodes[site]['value'] is not None

In [None]:
import utils

In [None]:
with pyro.plate_stack("forward", (trainer.num_particles, trainer.data_loader.batch_size)):
    model = pyro.condition(trainer.model, data={k: v['value'] for k, v in trainer.model.graph.nodes.items()})
    xs = model()

In [None]:
import matplotlib.pyplot as plt

In [None]:
for i in range(10):
    plt.imshow(xs.mean(dim=0)[i].squeeze().detach().cpu().numpy())
    plt.show()