In [1]:
%load_ext autoreload
%autoreload 2

import pydrake
import torch
import pyro
from pyro import poutine
import time

import scene_generation.data.dataset_utils as dataset_utils
from scene_generation.models.planar_multi_object_multi_class_2 import MultiObjectMultiClassModel

In [123]:
import numpy as np
#DATA_FILE = "/home/gizatt/projects/scene_generation/data/planar_bin/planar_bin_static_scenes_geometric.yaml"
DATA_FILE = "/home/gizatt/projects/scene_generation/data/planar_bin/planar_bin_static_scenes.yaml"
scenes_dataset_yaml = dataset_utils.ScenesDataset(DATA_FILE)
params_by_object_class = {}
for env_i in range(len(scenes_dataset_yaml)):
    env = scenes_dataset_yaml[env_i]
    for obj_i in range(env["n_objects"]):
        obj_yaml = env["obj_%04d" % obj_i]
        class_name = obj_yaml["class"]
        if class_name not in params_by_object_class.keys():
            params_by_object_class[class_name] = []
        params_by_object_class[class_name].append(obj_yaml["pose"] + obj_yaml["params"])

for object_name in params_by_object_class.keys():
    print object_name, ": "
    params = np.stack(params_by_object_class[object_name])
    print params.shape
    print "means: ", np.mean(params, axis=0)
    print "vars: ", np.std(params, axis=0)

2d_box : 
(2479, 5)
means:  [ 0.05841386  0.19564861 -0.0062935   0.20078573  0.2000771 ]
vars:  [0.48295732 0.1329204  1.05602027 0.05787565 0.05763485]
2d_sphere : 
(2472, 4)
means:  [ 0.04950154  0.18883129 -0.19078293  0.09987665]
vars:  [0.50612466 0.13069154 2.06112832 0.02919288]


In [124]:
# Re-open the same DATA_FILE through ScenesDatasetVectorized. `scenes_dataset`
# is what MultiObjectMultiClassModel is constructed from further down, and
# `data` (the full dataset) is what each svi.step call receives.
scenes_dataset = dataset_utils.ScenesDatasetVectorized(DATA_FILE)
data = scenes_dataset.get_full_dataset()

In [125]:
# Rig for SVI, running with AutoDelta, which converges fairly reliably but
# confuses the variances
from collections import defaultdict
from torch.distributions import constraints
from pyro.infer import Trace_ELBO, SVI
from pyro.contrib.autoguide import AutoDelta, AutoDiagonalNormal, AutoMultivariateNormal, AutoGuideList
import datetime
from tensorboardX import SummaryWriter

pyro.enable_validation(True)

log_dir = "/home/gizatt/projects/scene_generation/models/runs/pmomc2/" + datetime.datetime.now().strftime(
    "%Y-%m-%d-%H-%m-%s")
writer = SummaryWriter(log_dir)
def write_np_array(writer, name, x, i):
    for yi, y in enumerate(x):
        writer.add_scalar(name + "/%d" % yi, y, i)
        
print "All params: ", pyro.get_param_store().get_all_param_names()
interesting_params = ["keep_going_weights",
                      "new_class_weights",
                      "params_means_0", "params_means_1",
                      "params_vars_0", "params_vars_1"]
    
model = MultiObjectMultiClassModel(scenes_dataset)
pyro.clear_param_store()
guide = AutoDelta(poutine.block(model.model, hide=["obs"]))

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = Trace_ELBO(max_plate_nesting=1)
svi = SVI(model.model, guide, optim, loss=elbo)
losses = []
        
snapshots = {}
for i in range(101):
    # Guesses on important things:
    # Big subsamples appear really important -- I had major loss of
    # convergence when using smaller subsample sizes (like ~50).
    # Also important: prior on the variance must be REALLY low.
    # Otherwise long_box_mean diverges to negative... :(
    # I think there's a fundamental problem with variance estimation
    # under this guide / with this system -- see the single-box-dataset
    # estimates that don't capture the x vs y variance.
    loss = svi.step(data, subsample_size=250)
    losses.append(loss)
    writer.add_scalar('loss', loss, i)

    for p in pyro.get_param_store().keys():
        if p not in snapshots.keys():
            snapshots[p] = []
        snapshots[p].append(pyro.param(p).cpu().detach().numpy().copy())
    for p in interesting_params:
        write_np_array(writer, p, snapshots[p][-1], i)
    if (i % 10 == 0):
        print ".",
    if (i % 50 == 0):
        print "\n"
        for p in interesting_params:
            print p, ": ", pyro.param(p).detach().numpy()
print "Done"

All params:  ['keep_going_weights', 'context_updater_module$$$weight_hh_l0', 'new_class_weights', 'context_updater_module$$$bias_hh_l0', 'class_encoder_module_1$$$4.bias', 'params_vars_1', 'params_vars_0', 'context_updater_module$$$weight_ih_l0', 'class_encoder_module_0$$$2.weight', 'class_encoder_module_1$$$4.weight', 'class_encoder_module_0$$$4.weight', 'context_updater_module$$$bias_ih_l0', 'class_encoder_module_0$$$0.weight', 'class_encoder_module_0$$$4.bias', 'class_encoder_module_0$$$0.bias', 'class_encoder_module_1$$$0.weight', 'class_encoder_module_1$$$2.weight', 'params_means_0', 'params_means_1', 'class_encoder_module_1$$$2.bias', 'class_encoder_module_0$$$2.bias', 'class_encoder_module_1$$$0.bias']
. 

keep_going_weights :  [0.5249792 0.5249792 0.5249792 0.5249792 0.5249792 0.5249792 0.5249792
 0.5249792 0.5249792 0.5249792 0.4750208 0.4750208 0.4750208 0.4750208
 0.4750208 0.4750208 0.4750208 0.4750208 0.4750208]
new_class_weights :  [0.54983395 0.450166  ]
params_means_0 :

In [127]:
# Convert that data back to a YAML environment, which is easier to
# handle.
scene_with_most_objects = None
for k in range(1):
    generated_data, generated_encodings, generated_contexts = model.model()
    scene_yaml = scenes_dataset.convert_vectorized_environment_to_yaml(
        generated_data)
    if scene_with_most_objects is None or scene_yaml[0]["n_objects"] > scene_with_most_objects["n_objects"]:
        scene_with_most_objects = scene_yaml[0]
        
print scene_with_most_objects
dataset_utils.DrawYamlEnvironment(scene_with_most_objects, "planar_bin")

{'obj_0002': {'color': [0.6200758153988021, 1.0, 1.0, 0.5], 'pose': [0.27748194336891174, 0.26269716024398804, -0.7031012177467346], 'params': [0.1910230815410614, 0.25044190883636475], 'class': '2d_box', 'params_names': ['height', 'length']}, 'obj_0003': {'color': [0.7809528565765493, 1.0, 1.0, 0.5], 'pose': [0.2374354898929596, 0.16561977565288544, -4.372330665588379], 'params': [0.11553522944450378], 'class': '2d_sphere', 'params_names': ['radius']}, 'obj_0000': {'color': [0.6349652299591473, 1.0, 1.0, 0.5], 'pose': [-0.403425931930542, 0.08972030133008957, -3.092280864715576], 'params': [0.12652935087680817], 'class': '2d_sphere', 'params_names': ['radius']}, 'obj_0001': {'color': [0.6801744155299632, 1.0, 1.0, 0.5], 'pose': [0.4625639319419861, 0.4472746253013611, -0.5235490202903748], 'params': [0.26720577478408813, 0.21856629848480225], 'class': '2d_box', 'params_names': ['height', 'length']}, 'obj_0006': {'color': [0.6295699069295791, 1.0, 1.0, 0.5], 'pose': [0.9539144039154053