In [34]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import os
import pickle
import time
from tqdm.notebook import tqdm

import torch
torch.set_default_dtype(torch.float64)  # all new float tensors are float64 (replaces deprecated set_default_tensor_type)

from spatial_scene_grammars.constraints import *
from spatial_scene_grammars.nodes import *
from spatial_scene_grammars.rules import *
from spatial_scene_grammars.scene_grammar import *
from spatial_scene_grammars.visualization import *
from spatial_scene_grammars_examples.table.grammar import *
from spatial_scene_grammars.parsing import *
from spatial_scene_grammars.sampling import *
from spatial_scene_grammars.parameter_estimation import *
from spatial_scene_grammars.dataset import *

import meshcat
import meshcat.geometry as meshcat_geom

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [39]:
# Create the meshcat visualizer only once, so re-running this cell keeps the
# existing browser connection; then clear the scene.
if 'vis' not in globals():
    vis = meshcat.Visualizer()
vis.delete()

# Rebuild the viewer URL against localhost, keeping only the port/path suffix
# that meshcat reports (e.g. "7007/static/").
base_url = "http://127.0.0.1"
meshcat_url = base_url + ":" + vis.url().split(":")[-1]
print("Meshcat url: ", meshcat_url)

# Set to True to embed the meshcat viewer inline below this cell instead of
# opening the URL printed above in a separate browser tab.
SHOW_INLINE_VIEWER = False
if SHOW_INLINE_VIEWER:
    from IPython.display import HTML, display
    display(HTML("""
    <div style="height: 400px; width: 100%; overflow-x: auto; overflow-y: hidden; resize: both">
    <iframe src="{url}" style="width: 100%; height: 100%; border: none"></iframe>
</div>
""".format(url=meshcat_url)))

Meshcat url:  http://127.0.0.1:7007/static/


'\nfrom IPython.display import HTML\nHTML("""\n    <div style="height: 400px; width: 100%; overflow-x: auto; overflow-y: hidden; resize: both">\n    <iframe src="{url}" style="width: 100%; height: 100%; border: none"></iframe>\n</div>\n""".format(url=meshcat_url))\n'

In [64]:
# Build the table-scene grammar. The root node is a Table placed at a +0.8
# offset along z.
grammar = SpatialSceneGrammar(
    root_node_type=Table,
    root_node_tf=drake_tf_to_torch_tf(RigidTransform(p=[0.0, 0., 0.8])),
)

# Draw one seeded, unconstrained sample. Show both its geometry and its
# tree structure in meshcat.
torch.manual_seed(44)
tree = grammar.sample_tree(detach=True)
draw_scene_tree_contents_meshcat(
    tree,
    zmq_url=vis.window.zmq_url,
    prefix="sample/contents",
)
draw_scene_tree_structure_meshcat(
    tree,
    zmq_url=vis.window.zmq_url,
    prefix="sample/structure",
    node_sphere_size=0.01,
    linewidth=5,
)

In [30]:
# Constraints
class TablesInRoomConstraint(Constraint):
    """Keep every Table node's XY position inside the box [-4, 4] x [-4, 4].

    `eval` returns an (N, 2) tensor with one (x, y) row per Table in the scene
    tree, or an empty (0, 2) tensor when there are no tables. The rows are
    presumably compared elementwise against the bounds given to `Constraint`.
    """
    def __init__(self):
        # NOTE(review): an earlier comment described the room as [-5, 5], but
        # the bounds actually enforced are [-4, 4]. This is presumably a margin
        # so tables stay clear of the walls; confirm the intended room size.
        lb = torch.tensor([-4., -4.])
        ub = torch.tensor([4., 4.])
        super().__init__(
            lower_bound=lb,
            upper_bound=ub
        )

    def eval(self, scene_tree):
        """Return the stacked XY translations of all Table nodes, shape (N, 2)."""
        # Exact type match (not isinstance), so subclasses of Table are skipped.
        xys = [node.translation[:2] for node in scene_tree.nodes if type(node) is Table]
        if len(xys) > 0:
            return torch.stack(xys, axis=0)
        else:
            return torch.empty(size=(0, 2))

    def add_to_ik_prog(self, scene_tree, ik, mbp, mbp_context, node_to_free_body_ids_map):
        """Not implemented; in this notebook the constraint is only used as an HMC penalty."""
        raise NotImplementedError()

class TableSpacingConstraint(Constraint):
    """Require every pair of Table nodes to be at least 2 units apart in XY.

    The comparison is done on *squared* Euclidean distance: the lower bound is
    2**2 and the upper bound is +inf. Taking a square root likely yields NaN
    gradients. The diagonal self-distances are zero, where sqrt's derivative is
    infinite, and autograd then multiplies inf by 0.
    """
    def __init__(self):
        super().__init__(
            lower_bound=torch.tensor([2.]).square(),
            upper_bound=torch.tensor([np.inf]),
        )

    def eval(self, scene_tree):
        """Return squared XY distances between distinct tables, shape (N*(N-1)/2, 1)."""
        table_xys = [n.translation[:2] for n in scene_tree.nodes if type(n) is Table]
        if len(table_xys) < 2:
            return torch.empty(size=(0, 1))

        points = torch.stack(table_xys, axis=0)  # N x 2
        count = points.shape[0]
        # deltas[i, j] = points[i] - points[j], via broadcasting -> N x N x 2
        deltas = points[:, None, :] - points[None, :, :]
        sq_dists = deltas.square().sum(axis=-1)  # N x N
        # Keep each unordered pair once: the strictly lower-triangular entries.
        row_idx, col_idx = torch.tril_indices(count, count, -1)
        return sq_dists[row_idx, col_idx].reshape(-1, 1)

    def add_to_ik_prog(self, scene_tree, ik, mbp, mbp_context, node_to_free_body_ids_map):
        raise NotImplementedError()

class ObjectsOnTableConstraint(Constraint):
    """Keep each object's XY offset from its table within [-0.35, 0.35] per axis.

    Objects are the ObjectModel descendants of each Table node. Each object's
    offset from the table origin is multiplied by the transpose of the table's
    rotation, i.e. expressed in the table frame, before its XY part is taken.
    """
    def __init__(self):
        super().__init__(
            lower_bound=torch.tensor([-0.35, -0.35]),
            upper_bound=torch.tensor([0.35, 0.35]),
        )

    def eval(self, scene_tree):
        """Return table-frame XY offsets, one row per object, shape (M, 2)."""
        offsets = []
        for table in scene_tree.find_nodes_by_type(Table):
            descendants = scene_tree.get_children_recursive(table)
            for obj in (d for d in descendants if isinstance(d, ObjectModel)):
                world_offset = obj.translation - table.translation
                offsets.append((table.rotation.T @ world_offset)[:2])
        if not offsets:
            return torch.empty(size=(0, 2))
        return torch.stack(offsets, axis=0)

    def add_to_ik_prog(self, scene_tree, ik, mbp, mbp_context, node_to_free_body_ids_map):
        raise NotImplementedError()

class ObjectSpacingConstraint(Constraint):
    """Require objects that share a table to be at least 0.05 apart in XY.

    Distances are squared: the lower bound is 0.05**2 and the upper bound is
    +inf. Only ObjectModel descendants of the same Table are compared with each
    other, and their world-frame XY translations are used.
    """
    def __init__(self):
        super().__init__(
            lower_bound=torch.tensor([0.05]).square(),
            upper_bound=torch.tensor([np.inf]),
        )

    def eval(self, scene_tree):
        """Return squared pairwise XY distances for all same-table object pairs, shape (K, 1)."""
        per_table = []
        for table in scene_tree.find_nodes_by_type(Table):
            objs = [n for n in scene_tree.get_children_recursive(table) if isinstance(n, ObjectModel)]
            if len(objs) < 2:
                continue  # no pairs to compare on this table
            points = torch.stack([o.translation[:2] for o in objs], axis=0)  # N x 2
            count = points.shape[0]
            # sq_dists[i, j] = |points[i] - points[j]|^2 via broadcasting -> N x N
            sq_dists = (points[:, None, :] - points[None, :, :]).square().sum(axis=-1)
            # Strictly lower-triangular entries: each unordered pair once.
            row_idx, col_idx = torch.tril_indices(count, count, -1)
            per_table.append(sq_dists[row_idx, col_idx].reshape(-1, 1))
        if not per_table:
            return torch.empty(size=(0, 1))
        return torch.cat(per_table, axis=0)

    def add_to_ik_prog(self, scene_tree, ik, mbp, mbp_context, node_to_free_body_ids_map):
        raise NotImplementedError()



# Resample the continuous parameters of `tree` with its structure held fixed,
# using the constraints above as penalty terms in an HMC/NUTS sampler.
from copy import deepcopy  # explicit import; don't rely on a star import leaking it

constraints = [
    TablesInRoomConstraint(),
    TableSpacingConstraint(),
    ObjectsOnTableConstraint(),
    ObjectSpacingConstraint()
]
# Run on a copy so the unconstrained `tree` drawn earlier is left untouched
# (presumably the sampler modifies the tree it is given).
hmc_tree = deepcopy(tree)
samples = do_fixed_structure_hmc_with_constraint_penalties(
    grammar, hmc_tree, num_samples=20, subsample_step=5,
    with_nonpenetration=False, zmq_url=None,
    constraints=constraints,
    kernel_type="NUTS", max_tree_depth=6, target_accept_prob=0.65
    # Alternative Langevin-ish kernel:
    #kernel_type="HMC", num_steps=1, step_size=1E-2
)

Warmup:   0%|          | 0/30 [00:00, ?it/s]

Initial trace log prob:  tensor(60.7916)


Sample: 100%|██████████| 30/30 [00:24,  1.23it/s, step size=6.69e-03, acc. prob=0.803]


In [None]:
# Draw up to the 4 most recent HMC samples; k == 0 is the newest one.
# Index from the end with -1 - k so k == 0 maps to samples[-1], not samples[0].
num_to_draw = min(4, len(samples))
for k in range(num_to_draw):
    sample = samples[-1 - k]
    draw_scene_tree_contents_meshcat(sample, zmq_url=vis.window.zmq_url,  prefix="contents/hmc_sample_%d" % k)
    draw_scene_tree_structure_meshcat(sample, zmq_url=vis.window.zmq_url,  prefix="structure/hmc_sample_%d" % k)