In [1]:
%load_ext autoreload
%autoreload 2

import logging
import os
import pickle
import time

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from tqdm.notebook import tqdm

import torch
torch.set_default_tensor_type(torch.DoubleTensor)

from spatial_scene_grammars.constraints import *
from spatial_scene_grammars.nodes import *
from spatial_scene_grammars.rules import *
from spatial_scene_grammars.scene_grammar import *
from spatial_scene_grammars.visualization import *
from spatial_scene_grammars_examples.table.grammar import *
from spatial_scene_grammars.parsing import *
from spatial_scene_grammars.sampling import *
from spatial_scene_grammars.parameter_estimation import *
from spatial_scene_grammars.dataset import *

import meshcat
import meshcat.geometry as meshcat_geom

In [2]:
# Connect to a meshcat visualizer, reusing the existing `vis` if this cell is
# re-run so we don't spin up a new connection every time; then clear the scene.
if 'vis' not in globals():
    vis = meshcat.Visualizer()
vis.delete()
# vis.url() looks like "http://127.0.0.1:7001/static/" (see the cell output);
# keep only the "port/path" part and re-host it on localhost.
base_url = "http://127.0.0.1"
meshcat_url = base_url + ":" + vis.url().split(":")[-1]
print("Meshcat url: ", meshcat_url)

# Set to True to embed the visualizer inline below this cell instead of opening
# `meshcat_url` in a separate browser tab. (Previously this lived in a dead
# string literal that Jupyter echoed as the cell's output.)
EMBED_MESHCAT_INLINE = False
if EMBED_MESHCAT_INLINE:
    from IPython.display import HTML, display
    # display() is required: only a cell's final top-level expression is
    # rendered automatically, not an expression inside an `if`.
    display(HTML("""
    <div style="height: 400px; width: 100%; overflow-x: auto; overflow-y: hidden; resize: both">
    <iframe src="{url}" style="width: 100%; height: 100%; border: none"></iframe>
</div>
""".format(url=meshcat_url)))

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/
Meshcat url:  http://127.0.0.1:7001/static/


'\nfrom IPython.display import HTML\nHTML("""\n    <div style="height: 400px; width: 100%; overflow-x: auto; overflow-y: hidden; resize: both">\n    <iframe src="{url}" style="width: 100%; height: 100%; border: none"></iframe>\n</div>\n""".format(url=meshcat_url))\n'

In [3]:
# Table-scene grammar: the root Table node is placed at p = (0, 0, 0.8) in the
# world frame.
grammar = SpatialSceneGrammar(
    root_node_type=Table,
    root_node_tf=drake_tf_to_torch_tf(
        RigidTransform(p=[0.0, 0., 0.8])
    ),
)
# Constraints every accepted scene must satisfy. sample_realistic_scene splits
# them into topology and continuous constraints.
constraints = [ObjectsOnTableConstraint(), ObjectSpacingConstraint()]

def sample_realistic_scene(seed=None, num_hmc_samples=25):
    """Sample one scene tree from `grammar` that satisfies every entry of `constraints`.

    Procedure:
      1. If there are topology constraints, rejection-sample a tree structure
         that satisfies them; otherwise draw an unconstrained tree.
      2. Run NUTS over that fixed structure with the continuous constraints as
         penalties, streaming progress to the meshcat instance `vis`.
      3. Walk the HMC samples from last to first and keep the first one whose
         total constraint violation (over all `constraints`) is <= 0.
      4. Project that tree to feasibility with a forward simulation
         (timestep=0.001, T=1.).

    Args:
        seed: If not None, seeds torch's global RNG first so the result is
            reproducible.
        num_hmc_samples: Number of HMC samples to request (default 25, the
            previously hard-coded value).

    Returns:
        The projected scene tree, or None (after logging an error) if no
        feasible structure or no constraint-satisfying HMC sample was found.
    """
    if seed is not None:
        torch.random.manual_seed(seed)

    topology_constraints, continuous_constraints = split_constraints(constraints)

    # Step 1: choose a tree structure.
    if len(topology_constraints) > 0:
        tree, success = rejection_sample_under_constraints(grammar, topology_constraints, 1000)
        if not success:
            logging.error("Couldn't rejection sample a feasible tree config.")
            return None
    else:
        tree = grammar.sample_tree(detach=True)

    # Step 2: sample the continuous variables of that fixed structure.
    samples = do_fixed_structure_hmc_with_constraint_penalties(
        grammar, tree, num_samples=num_hmc_samples, subsample_step=5,
        with_nonpenetration=False, zmq_url=vis.window.zmq_url,
        constraints=continuous_constraints,
        kernel_type="NUTS", max_tree_depth=6, target_accept_prob=0.8, adapt_step_size=True
    )

    # Step 3: search from the end of the chain (presumably the best-mixed
    # samples) for the first tree that satisfies all constraints.
    good_tree = next(
        (candidate_tree for candidate_tree in samples[::-1]
         if eval_total_constraint_set_violation(candidate_tree, constraints) <= 0.),
        None,
    )
    # Identity check: `==` may be overloaded by the tree type.
    if good_tree is None:
        logging.error("No tree in samples satisfied constraints.")
        return None

    # Step 4: final cleanup of the chosen tree.
    return project_tree_to_feasibility(good_tree, do_forward_sim=True, timestep=0.001, T=1.)



In [4]:
# Draw one seeded (reproducible) sample. Its geometry and its tree structure go
# to meshcat under separate prefixes. A None result means sampling failed; the
# reason has already been logged.
test_tree = sample_realistic_scene(seed=42)
if test_tree is not None:
    draw_scene_tree_contents_meshcat(
        test_tree,
        zmq_url=vis.window.zmq_url,
        prefix="test_tree/contents",
    )
    draw_scene_tree_structure_meshcat(
        test_tree,
        zmq_url=vis.window.zmq_url,
        prefix="test_tree/structure",
    )



Initial trace log prob:  tensor(-14030.3788)


Warmup:   0%|          | 0/37 [00:00, ?it/s]

Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6001...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/
Connected to meshcat-server.


Sample: 100%|██████████| 37/37 [00:52,  1.42s/it, step size=5.32e-04, acc. prob=0.979]


In [None]:
# Save the parameters of the grammar that generates this dataset, so they can
# be reloaded alongside the examples later.
state_dict_file = "target_dataset_grammar_state_dict.torch"
print("Saving state dict to ", state_dict_file)
torch.save(grammar.state_dict(), state_dict_file)

# Collect N feasible scene trees. Each one is appended to `dataset_save_file`
# as its own pickle record as soon as it is produced, so finished samples
# survive an interrupted run. To read the file back, call pickle.load
# repeatedly until it raises EOFError.
dataset_save_file = "target_dataset_examples.pickle"
N = 100
# Start from an empty file. Records are appended below, so without this a
# re-run would stack a second dataset on top of the previous run's examples.
open(dataset_save_file, "wb").close()
k = 0
num_failed = 0
with tqdm(total=N, desc="Samples") as pbar:
    while k < N:
        tree = sample_realistic_scene()
        if tree is None:
            # sample_realistic_scene already logged the reason; retry.
            num_failed += 1
            pbar.set_postfix(failed=num_failed)
            continue
        with open(dataset_save_file, "ab") as f:
            pickle.dump(tree, f)
        k += 1
        # tqdm.update() takes an increment, not an absolute count.
        pbar.update(1)

Saving state dict to  target_dataset_grammar_state_dict.torch


Samples:   0%|          | 0/100 [00:00<?, ?it/s]

Initial trace log prob:  tensor(-10358.7043)


Warmup:   0%|          | 0/37 [00:00, ?it/s]

Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6001...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/
Connected to meshcat-server.


Sample: 100%|██████████| 37/37 [00:56,  1.52s/it, step size=7.41e-04, acc. prob=0.962]


Initial trace log prob:  tensor(-964.4689)
Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6001...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7001/static/
Connected to meshcat-server.


Sample: 100%|██████████| 37/37 [01:06,  1.81s/it, step size=1.02e-03, acc. prob=0.944]
