In [1]:
# Run from the project root (parent of the notebook directory) so the relative
# config/module paths used below resolve. The root is remembered on the first
# run, so re-executing this cell does not climb one more directory each time
# (which a bare `%cd ..` would do).
import os

if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)
print(PROJECT_ROOT)

# Set before the next cell imports torch, matching the original ordering.
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"

/workspace/shai_hulud/ppc_experiments
env: TORCH_CUDNN_SDPA_ENABLED=1


In [2]:
import argparse
import collections
import random

import numpy as np
import pyro
import torch

import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
import trainer.trainer as module_trainer

In [3]:
# Slow debugging aids, off by default. Flip to True to enable Pyro's
# validation checks and autograd anomaly detection (e.g. when chasing NaNs
# or shape errors).
DEBUG_CHECKS = False
if DEBUG_CHECKS:
    pyro.enable_validation(True)
    torch.autograd.set_detect_anomaly(True)

In [4]:
# Fix random seeds for reproducibility. Every global RNG that could feed the
# data pipeline or sampling is seeded: Python's `random`, NumPy's legacy
# global generator, and torch (`torch.manual_seed` seeds all devices).
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Trade cuDNN autotuning speed for run-to-run determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
from utils import read_json

# Load the experiment configuration (path is relative to the project root set
# in the first cell). The raw JSON dict and the parsed ConfigParser get
# distinct names so that `config` always refers to the parser object.
CONFIG_PATH = "experiments/ppc_flowers_config.json"
config_dict = read_json(CONFIG_PATH)
config = ConfigParser(config_dict)

In [6]:
logger = config.get_logger('train')

# setup data_loader instances
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()

# build model architecture, then print to console
model = config.init_obj('arch', module_arch)
logger.info(model)

# get function handles of metrics
metrics = [getattr(module_metric, met) for met in config['metrics']]


def build_optimizer(config):
    """Build the Pyro optimizer and optional LR scheduler described by ``config``.

    With an ``"lr_scheduler"`` section, the scheduler class named by
    ``config["lr_scheduler"]["type"]`` is looked up in ``pyro.optim``. It wraps
    the ``torch.optim`` class named by ``config["optimizer"]["type"]``, using
    ``config["optimizer"]["args"]["optim_args"]`` plus the scheduler's own
    ``args``. That single scheduler object is returned as BOTH the optimizer
    and the scheduler. Without that section, ``config.init_obj('optimizer',
    pyro.optim)`` builds a plain optimizer and the scheduler is ``None``.

    Returns:
        ``(optimizer, lr_scheduler)``.
    """
    if "lr_scheduler" not in config:
        return config.init_obj('optimizer', pyro.optim), None

    scheduler_cls = getattr(pyro.optim, config["lr_scheduler"]["type"])
    scheduler = scheduler_cls({
        "optimizer": getattr(torch.optim, config["optimizer"]["type"]),
        "optim_args": config["optimizer"]["args"]["optim_args"],
        **config["lr_scheduler"]["args"],
    })
    # Pyro's LR schedulers subclass PyroOptim, so the same object can drive
    # optimisation directly. NOTE(review): presumably the trainer also calls
    # it as a scheduler; confirm in trainer.trainer.
    return scheduler, scheduler


optimizer, lr_scheduler = build_optimizer(config)

# build trainer
trainer = config.init_obj('trainer', module_trainer, model, metrics, optimizer,
                          config=config, data_loader=data_loader,
                          valid_data_loader=valid_data_loader,
                          lr_scheduler=lr_scheduler)

DiffusionPpc(
  (diffusion): DiffusionStep(
    (unet): Unet(
      (init_conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
      (time_mlp): Sequential(
        (0): SinusoidalPosEmb()
        (1): Linear(in_features=64, out_features=256, bias=True)
        (2): GELU(approximate='none')
        (3): Linear(in_features=256, out_features=256, bias=True)
      )
      (downs): ModuleList(
        (0): ModuleList(
          (0-1): 2 x ResnetBlock(
            (mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=256, out_features=128, bias=True)
            )
            (block1): Block(
              (proj): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (proj): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 64, eps=1e-0

In [7]:
# Report where this run's logs are written, to help locate artefacts later.
logger.info("%s", trainer.config.log_dir)

saved/log/FlowersDiffusion_Ppc/0426_014324


In [None]:
# Run the training loop of the trainer built from config['trainer'] in the
# setup cell. Per-epoch metrics (loss, ess, log_marginal and their val_*
# counterparts) are logged as the run progresses.
trainer.train()

    epoch          : 1
    loss           : 1350953.785326087
    ess            : 2.085200812505639
    log_marginal   : -1350953.785326087
    val_loss       : 968653.8605769231
    val_ess        : 1.9216643021656916
    val_log_marginal: -968653.8605769231
    epoch          : 2
    loss           : 860085.5902173913
    ess            : 1.9050397230231244
    log_marginal   : -860085.5842391305
    val_loss       : 807672.5769230769
    val_ess        : 1.879855330173786
    val_log_marginal: -807672.5625


In [None]:
# Put the model in eval mode and move the model, trainer and particle stores
# to the CPU for the inspection cells below.
trainer.model.eval()
trainer.model.cpu()
trainer.cpu()
# NOTE(review): nn.Module.cpu() moves in place, but on a plain torch.Tensor
# .cpu() returns a copy, which would be discarded here. Confirm that the
# particle containers move in place (or reassign the result).
trainer.train_particles.cpu()
trainer.valid_particles.cpu()

In [None]:
# Empty the model's graph, then call the trainer's private `_load_particles`
# for indices 0..batch_size-1. Presumably this repopulates the graph nodes,
# since the next cell reads their 'value' entries.
# NOTE(review): the meaning of the positional `False` is not visible here
# (perhaps train vs. validation particles). Confirm it, and pass it by keyword.
trainer.model.graph.clear()
trainer._load_particles(range(trainer.data_loader.batch_size), False)

In [None]:
# Flag each graph site as observed iff it currently holds a value; sites whose
# value is None are flagged as not observed.
graph_nodes = trainer.model.graph.nodes
for site in graph_nodes:
    node = graph_nodes[site]
    node['is_observed'] = node['value'] is not None

In [None]:
import utils

In [None]:
# Run the trained model with each graph site clamped to its stored value.
# Sites whose stored value is None are passed through as None, which Pyro's
# condition handler treats as unobserved, so those sites are sampled.
# The conditioned callable gets its own name so the architecture object
# `model` from the setup cell is not silently replaced.
observations = {site: attrs['value'] for site, attrs in trainer.model.graph.nodes.items()}
conditioned_model = pyro.condition(trainer.model, data=observations)
with pyro.plate_stack("forward", (trainer.num_particles, trainer.data_loader.batch_size)):
    xs = conditioned_model()

In [None]:
import matplotlib.pyplot as plt

In [None]:
N_IMAGES_TO_SHOW = 10  # upper bound on how many particle-mean images to plot


def to_imshow_array(image):
    """Convert one image tensor into a NumPy array that ``plt.imshow`` accepts.

    ``imshow`` wants (H, W), (H, W, 3) or (H, W, 4). A 3-D tensor whose leading
    dim looks like a channel axis (1, 3 or 4) is therefore moved to
    channel-last. Singleton dims are then squeezed, so (1, H, W) still becomes
    (H, W) as before.
    """
    image = image.detach().cpu()
    if image.dim() == 3 and image.shape[0] in (1, 3, 4):
        image = image.permute(1, 2, 0)
    return image.squeeze().numpy()


# Average over the particle dimension once, not on every iteration.
# NOTE(review): assumes xs is laid out (particles, batch, ...), as the
# plate_stack in the sampling cell suggests. Confirm this.
mean_images = xs.mean(dim=0)
for i in range(min(N_IMAGES_TO_SHOW, mean_images.shape[0])):
    fig, ax = plt.subplots()
    ax.imshow(to_imshow_array(mean_images[i]))
    ax.set_title(f"Mean over particles, batch item {i}")
    ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop