In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
import time
from copy import deepcopy

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from tqdm.notebook import tqdm

import torch
torch.set_default_tensor_type(torch.DoubleTensor)

from spatial_scene_grammars.nodes import *
from spatial_scene_grammars.rules import *
from spatial_scene_grammars.scene_grammar import *
from spatial_scene_grammars.visualization import *
from spatial_scene_grammars_examples.planar_clusters_gaussians.grammar import *
from spatial_scene_grammars.parsing import *
from spatial_scene_grammars.sampling import *

import meshcat
import meshcat.geometry as meshcat_geom

In [50]:
from IPython.display import HTML

# Keep one meshcat visualizer alive across re-runs of this cell; only the
# previously drawn contents are wiped.
if 'vis' not in globals():
    vis = meshcat.Visualizer()
vis.delete()

# Address the viewer through the local host, reusing everything after the
# last ":" of the URL reported by the visualizer.
base_url = "http://127.0.0.1"
meshcat_url = ":".join((base_url, vis.url().rsplit(":", 1)[-1]))
print("Meshcat url: ", meshcat_url)

# The HTML object is the cell's last expression, so the notebook renders the
# embedded viewer as this cell's output.
HTML("""
    <div style="height: 400px; width: 100%; overflow-x: auto; overflow-y: hidden; resize: both">
    <iframe src="{url}" style="width: 100%; height: 100%; border: none"></iframe>
</div>
""".format(url=meshcat_url))

Meshcat url:  http://127.0.0.1:7013/static/


In [52]:
# Build (or reload from the on-disk cache) a dataset of scenes sampled from
# the grammar with its default parameters, then draw the first scene.
# The seed is set before anything else in this cell touches the grammar.
torch.random.manual_seed(2)

ground_truth_grammar = SpatialSceneGrammar(
    root_node_type=Desk,
    root_node_tf=torch.eye(4),
)

N_samples = 1
RESAMPLE = True  # When True, the cache file is regenerated even if it exists.
scenes_file = "sampled_scenes_%d.dat" % N_samples

if RESAMPLE or not os.path.exists(scenes_file):
    # Each dataset entry pairs a sampled tree with its observed nodes.
    samples = []
    for k in range(N_samples):
        tree = ground_truth_grammar.sample_tree(detach=True)
        observed_nodes = tree.get_observed_nodes()
        samples.append((tree, observed_nodes))

    with open(scenes_file, "wb") as f:
        pickle.dump(samples, f)

# The dataset is always read back from disk, so freshly drawn and cached
# samples both go through the same pickle round trip.
with open(scenes_file, "rb") as f:
    samples = pickle.load(f)
print("Loaded %d scenes." % len(samples))
observed_node_sets = [entry[1] for entry in samples]

draw_scene_tree_contents_meshcat(
    samples[0][0],
    zmq_url=vis.window.zmq_url,
    prefix="sample",
)

Loaded 1 scenes.


In [4]:
# Grammar to be fit: constructed with sample_params_from_prior=False, then
# given deliberately wide parameter guesses below.
grammar = SpatialSceneGrammar(
    root_node_type=Desk,
    root_node_tf=torch.eye(4),
    sample_params_from_prior=False,
)
# Widen the rule parameter guesses: every "width" entry found among a rule's
# xyz parameters is overwritten with a same-shaped tensor filled with 5.
# TODO: Make this a grammar method.
for node_type in grammar.all_types:
    for xyz_param_dict, rot_param_dict in grammar.rule_params_by_node_type[node_type.__name__]:
        # Rules without a "width" xyz parameter are left untouched.
        if "width" not in xyz_param_dict.keys():
            continue
        xyz_param_dict["width"].set(
            5. * torch.ones_like(xyz_param_dict["width"]())
        )
        

def do_vis(tree):
    """Draw the structure of `tree` in meshcat under the "sampled_in_progress" prefix.

    Passed as `vis_callback` to the MCMC sampler; reads the notebook-level
    `vis` visualizer to find the meshcat server.
    """
    draw_scene_tree_structure_meshcat(
        tree,
        zmq_url=vis.window.zmq_url,
        prefix="sampled_in_progress",
    )
    
def get_posterior_tree_samples_from_observation(grammar, observed_nodes, num_mcmc_steps=15, subsample_rate=3, verbose=0):
    """Sample scene trees that explain one set of observed nodes.

    Pipeline: draw the observation, get a tree structure from the MIP,
    refine it with the NLP, then run fixed-structure MCMC around the
    refined tree. Intermediate results are drawn in meshcat through the
    notebook-level `vis` visualizer.

    Args:
        grammar: Grammar handed to the MIP and to the MCMC sampler.
        observed_nodes: Observed nodes of a single scene.
        num_mcmc_steps: Forwarded as `num_samples` to the MCMC sampler.
        subsample_rate: Slice stride applied to the MCMC output.
        verbose: Forwarded unchanged to the MCMC sampler; the MIP and NLP
            steps are only run verbosely when it exceeds 1.

    Returns:
        Every `subsample_rate`-th tree of the MCMC output, or None when the
        tree extracted from the MIP results is falsy.
    """
    meshcat_zmq_url = vis.window.zmq_url
    solver_verbose = verbose > 1

    # Show what was observed before any inference happens.
    observed_only_tree = SceneTree.make_from_observed_nodes(observed_nodes)
    draw_scene_tree_contents_meshcat(
        observed_only_tree, zmq_url=meshcat_zmq_url, prefix="observed"
    )

    # Step 1: MIP for the MAP tree structure.
    mip_solution = infer_mle_tree_with_mip(
        grammar, observed_nodes, verbose=solver_verbose, max_scene_extent_in_any_dir=10.
    )
    mip_tree = get_optimized_tree_from_mip_results(mip_solution)
    if not mip_tree:
        return None
    draw_scene_tree_structure_meshcat(
        mip_tree, zmq_url=meshcat_zmq_url, prefix="mip_refined"
    )

    # Step 2: NLP refinement of that tree to a MAP estimate.
    nlp_solution = optimize_scene_tree_with_nlp(mip_tree, verbose=solver_verbose)
    map_tree = nlp_solution.refined_tree

    # Step 3: MCMC around the MAP estimate, keeping its structure fixed.
    mcmc_trees = do_fixed_structure_mcmc(
        grammar, map_tree, num_samples=num_mcmc_steps, verbose=verbose,
        perturb_in_config_space=True, translation_variance=0.01, rotation_variance=0.1,
        do_hit_and_run_postprocess=False, vis_callback=do_vis
    )

    # Step 4: thin the chain as requested.
    return mcmc_trees[::subsample_rate]

def collect_posterior_sample_sets(grammar, observed_node_sets):
    """Gather posterior tree samples for every observed scene.

    Runs get_posterior_tree_samples_from_observation on each entry of
    `observed_node_sets` (10 MCMC steps, no subsampling, quiet) behind a
    progress bar. Scenes for which that call returns None are skipped, so
    the result may be shorter than the input.

    Returns:
        A list with one collection of sampled trees per successful scene.
    """
    collected_sets = []
    progress = tqdm(observed_node_sets, desc='Collecting posterior samples')
    for scene_nodes in progress:
        scene_samples = get_posterior_tree_samples_from_observation(
            grammar, scene_nodes, num_mcmc_steps=10, subsample_rate=1, verbose=0
        )
        if scene_samples is None:
            continue
        collected_sets.append(scene_samples)
    return collected_sets


posterior_sample_sets = collect_posterior_sample_sets(grammar, observed_node_sets)

Collecting posterior samples:   0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
# Draw every posterior sample of the last collected scene, each under its own
# "guesses/<index>" prefix.
# NOTE: posterior_sample_sets only holds scenes whose inference did not return
# None, so indexing [-1] raises IndexError if no scene succeeded.
for k, tree in enumerate(posterior_sample_sets[-1]):
    draw_scene_tree_structure_meshcat(tree, zmq_url=vis.window.zmq_url, prefix="guesses/%d" % k)

In [48]:
# Run fixed-structure HMC on a copy of the first sampled scene tree.
# NOTE: this rebinds the notebook-level names `tree` and `grammar`; the
# grammar here is constructed with sample_params_from_prior=True.
# NOTE(review): no random seed is set in this cell, so results depend on
# whatever RNG state earlier cells left behind -- confirm this is intended.
tree = deepcopy(samples[0][0])
grammar = SpatialSceneGrammar(
    root_node_type = Desk,
    root_node_tf = torch.eye(4),
    sample_params_from_prior=True
)

# Sampler selection. Exactly one configuration runs, so `sampled_trees` is
# always freshly assigned by this cell (never left stale or undefined).
#   "HMC":  Langevin-esque: a single leapfrog step per sample with an
#           adapted step size.
#   "NUTS": defaults except limiting tree depth, to save on calls to the
#           slow model. Much slower, but should theoretically get much more
#           diversity as long as it doesn't diverge.
HMC_KERNEL_TYPE = "NUTS"

if HMC_KERNEL_TYPE == "HMC":
    sampled_trees = do_fixed_structure_hmc_with_constraint_penalties(
        grammar, tree, num_samples=200, subsample_step=5, verbose=1,
        kernel_type="HMC", num_steps=1, step_size=1E-3, adapt_step_size=True
    )
elif HMC_KERNEL_TYPE == "NUTS":
    sampled_trees = do_fixed_structure_hmc_with_constraint_penalties(
        grammar, tree, num_samples=20, subsample_step=1, verbose=1,
        kernel_type="NUTS", max_tree_depth=4
    )
else:
    # Fail loudly instead of silently skipping sampling.
    raise ValueError(
        "Unknown HMC_KERNEL_TYPE %r; expected \"HMC\" or \"NUTS\"." % (HMC_KERNEL_TYPE,)
    )


Warmup:   0%|          | 0/60 [10:22, ?it/s]/it, step size=2.05e-02, acc. prob=0.827]
Sample: 100%|██████████| 30/30 [00:29,  1.02it/s, step size=4.02e-03, acc. prob=0.821]



                                                                     mean       std    median     25.0%     75.0%     n_eff     r_hat
    Desk_2/ObjectCluster_2/AxisAlignedGaussianOffsetRule_xyz[0]      0.80      0.01      0.80      0.79      0.81     15.97      0.99
    Desk_2/ObjectCluster_2/AxisAlignedGaussianOffsetRule_xyz[1]      0.74      0.02      0.75      0.74      0.76      6.79      1.23
    Desk_2/ObjectCluster_2/AxisAlignedGaussianOffsetRule_xyz[2]     -0.00      0.01     -0.00     -0.00      0.00     20.09      1.01
        Desk_2/ObjectCluster_2/GaussianChordOffsetRule_theta[0]     -1.00      0.02     -1.00     -1.00     -0.98      2.74      2.35
    Desk_2/ObjectCluster_3/AxisAlignedGaussianOffsetRule_xyz[0]      0.35      0.02      0.35      0.33      0.36      8.13      0.96
    Desk_2/ObjectCluster_3/AxisAlignedGaussianOffsetRule_xyz[1]      0.40      0.03      0.40      0.39      0.42      5.48      1.37
    Desk_2/ObjectCluster_3/AxisAlignedGaussianOffsetRule_xyz[

In [51]:
# Draw every tree in `sampled_trees`, each under its own "guesses/<index>"
# prefix.
# NOTE: the loop rebinds the notebook-level names `k` and `tree`.
for k, tree in enumerate(sampled_trees):
    draw_scene_tree_structure_meshcat(tree, zmq_url=vis.window.zmq_url, prefix="guesses/%d" % k)