In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import os
import pickle
import time
from tqdm.notebook import tqdm

import torch
# Make float64 the default floating-point dtype for newly created tensors.
# torch.set_default_dtype is the supported replacement for the deprecated
# torch.set_default_tensor_type(torch.DoubleTensor) call.
torch.set_default_dtype(torch.float64)

from spatial_scene_grammars.nodes import *
from spatial_scene_grammars.rules import *
from spatial_scene_grammars.scene_grammar import *
from spatial_scene_grammars.visualization import *
from spatial_scene_grammars_examples.planar_clusters_gaussians.grammar import *
from spatial_scene_grammars.parsing import *
from spatial_scene_grammars.sampling import *
from spatial_scene_grammars.parameter_estimation import *

import meshcat
import meshcat.geometry as meshcat_geom

In [2]:
# Set up the meshcat visualizer and embed it in the notebook.
# The visualizer is only constructed when no `vis` global exists yet, so
# re-running this cell reuses the existing one and just clears its contents.
from IPython.display import HTML

if "vis" not in globals():
    vis = meshcat.Visualizer()
vis.delete()

# Rebuild the viewer URL against the local loopback address, keeping only the
# trailing (port + path) part of the URL that meshcat reports.
base_url = "http://127.0.0.1"
meshcat_port_and_path = vis.url().split(":")[-1]
meshcat_url = ":".join((base_url, meshcat_port_and_path))
print("Meshcat url: ", meshcat_url)

viewer_template = """
    <div style="height: 400px; width: 100%; overflow-x: auto; overflow-y: hidden; resize: both">
    <iframe src="{url}" style="width: 100%; height: 100%; border: none"></iframe>
</div>
"""
# Last expression of the cell: rendered by the notebook as an inline iframe.
HTML(viewer_template.format(url=meshcat_url))

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7015/static/
Meshcat url:  http://127.0.0.1:7015/static/


In [3]:
# Build a dataset of scenes sampled from the grammar with its default
# parameters, pickle it to disk, and draw the first sample in meshcat.
# With RESAMPLE set to False, an existing pickle file is reused instead of
# sampling again.
torch.random.manual_seed(5)
N_samples = 50
RESAMPLE = True
scenes_file = "sampled_scenes_%d.dat" % N_samples

ground_truth_grammar = SpatialSceneGrammar(
    root_node_type=Desk,
    root_node_tf=torch.eye(4),
)

if RESAMPLE or not os.path.exists(scenes_file):
    samples = []
    for scene_idx in tqdm(range(N_samples)):
        tree = ground_truth_grammar.sample_tree(detach=True)
        # Each dataset entry pairs the full tree with its observed nodes.
        samples.append((tree, tree.get_observed_nodes()))

    print("Saving...")
    with open(scenes_file, "wb") as scenes_out:
        pickle.dump(samples, scenes_out)

# The dataset is always read back from disk, whether or not it was just
# regenerated above.
print("Loading...")
with open(scenes_file, "rb") as scenes_in:
    samples = pickle.load(scenes_in)
print("Loaded %d scenes." % len(samples))
observed_node_sets = [entry[1] for entry in samples]

first_tree = samples[0][0]
draw_scene_tree_contents_meshcat(
    first_tree, zmq_url=vis.window.zmq_url, prefix="sample/contents"
)
draw_scene_tree_structure_meshcat(
    first_tree, zmq_url=vis.window.zmq_url, prefix="sample/structure"
)



  0%|          | 0/50 [00:00<?, ?it/s]

Saving...
Loading...




Loaded 50 scenes.


In [5]:
# Parsing demo on the first sampled scene: a MIP solve for the MLE tree,
# followed by NLP refinement. Both trees are drawn in meshcat, and the
# orthonormality error of every node rotation is printed for each tree.
def _print_rotation_deviation(scene_tree, label):
    """Print, for each node in `scene_tree`, the mean absolute entry of R^T R - I."""
    for node in scene_tree:
        err = node.rotation.transpose(0, 1) @ node.rotation - torch.eye(3)
        print(label, err.abs().mean())


observed_nodes = samples[0][1]
inference_results = infer_mle_tree_with_mip(
    ground_truth_grammar,
    observed_nodes,
    verbose=True,
)
mip_optimized_tree = get_optimized_tree_from_mip_results(inference_results)
draw_scene_tree_contents_meshcat(
    mip_optimized_tree, zmq_url=vis.window.zmq_url, prefix="parsing/mip/contents"
)
draw_scene_tree_structure_meshcat(
    mip_optimized_tree, zmq_url=vis.window.zmq_url, prefix="parsing/mip/structure"
)

# Refine the MIP solution with the NLP optimizer.
refinement_results = optimize_scene_tree_with_nlp(
    ground_truth_grammar, mip_optimized_tree, verbose=True
)
refined_tree = refinement_results.refined_tree
draw_scene_tree_contents_meshcat(
    refined_tree, zmq_url=vis.window.zmq_url, prefix="parsing/nlp/contents"
)
draw_scene_tree_structure_meshcat(
    refined_tree, zmq_url=vis.window.zmq_url, prefix="parsing/nlp/structure"
)

_print_rotation_deviation(
    mip_optimized_tree, "Avg elementwise deviation from R^T R = I: "
)
_print_rotation_deviation(
    refined_tree, "Post-refinement avg elementwise deviation from R^T R = I: "
)

Starting setup.
Activation vars allocated.
Continuous variables and SO(3) constraints allocated for all equivalence sets.
Setup time:  1.1676464080810547
Num vars:  7592
Num constraints:  22349




Optimization success?:  True
Logfile: 

Gurobi 9.0.2 (linux64) logging started Tue Oct 12 22:03:16 2021

Gurobi Optimizer version 9.0.2 build v9.0.2rc0 (linux64)
Optimize a model with 16264 rows, 7592 columns and 116448 nonzeros
Model fingerprint: 0x43509e1d
Model has 108 quadratic objective terms
Variable types: 7248 continuous, 344 integer (344 binary)
Coefficient statistics:
  Matrix range     [1e-01, 2e+01]
  Objective range  [2e-14, 1e+03]
  QObjective range [5e+00, 2e+02]
  Bounds range     [1e+00, 1e+01]
  RHS range        [1e+00, 2e+01]
Presolve removed 9131 rows and 2235 columns
Presolve time: 1.88s
Presolved: 7133 rows, 5357 columns, 66340 nonzeros
Presolved model has 84 quadratic objective terms
Variable types: 5143 continuous, 214 integer (214 binary)

Root relaxation: objective -4.707072e+04, 8058 iterations, 0.22 seconds

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time

 

In [None]:
# Create the grammar that will be fitted, starting from wide parameter guesses
# (its parameters are requested via sample_params_from_prior=True).
grammar = SpatialSceneGrammar(
    root_node_type=Desk,
    root_node_tf=torch.eye(4),
    sample_params_from_prior=True,
)


def do_vis(tree):
    """Draw the structure of `tree` in meshcat under the "sampled_in_progress" prefix."""
    draw_scene_tree_structure_meshcat(
        tree,
        zmq_url=vis.window.zmq_url,
        prefix="sampled_in_progress",
    )


# Disabled: collect posterior sample sets once and draw every tree in the last
# entry of the result, each under its own "guesses/<index>" prefix.
if False:
    posterior_sample_sets = collect_posterior_sample_sets(grammar, observed_node_sets)
    for guess_idx, guess_tree in enumerate(posterior_sample_sets[-1]):
        draw_scene_tree_structure_meshcat(
            guess_tree,
            zmq_url=vis.window.zmq_url,
            prefix="guesses/%d" % guess_idx,
        )

In [None]:
# Round-trip the current grammar through torch.save / torch.load; the reloaded
# copy is kept as `orig_grammar`.
saved_grammar_path = "/tmp/test_saved_grammar.torch"
torch.save(grammar, saved_grammar_path)
orig_grammar = torch.load(saved_grammar_path)

In [None]:
def _print_desk_params(banner, which_grammar):
    """Print `banner`, then the "Desk" node parameters of `which_grammar`."""
    print(banner)
    which_grammar.print_params(node_names=["Desk"])


# Disabled: run one fitting step on the posterior sample sets, then print the
# Desk parameters of the saved original, the fitted, and the ground-truth grammar.
if False:
    grammar = fit_grammar_params_to_sample_sets_with_uninformative_prior(
        grammar, posterior_sample_sets
    )
    _print_desk_params(
        "**********************************\n"
        "**********  BEFORE ***************\n"
        "************************************",
        orig_grammar,
    )
    _print_desk_params(
        "**********************************\n"
        "**********  AFTER ***************\n"
        "************************************",
        grammar,
    )
    _print_desk_params(
        "**********************************\n"
        "**********  TRUTH ***************\n"
        "************************************",
        ground_truth_grammar,
    )

In [None]:
# Alternate between collecting posterior sample sets under the current grammar
# and refitting the grammar parameters to them, for 20 major iterations.
# state_dict_history ends up holding a snapshot of the parameters before the
# first iteration plus one snapshot after each completed iteration.
state_dict_history = [deepcopy(grammar.state_dict())]
for iter_k in tqdm(range(20), desc="Major iteration"):
    posterior_sample_sets = collect_posterior_sample_sets(
        grammar,
        observed_node_sets,
        num_workers=8,
        tqdm=tqdm,
    )
    # Sanity check, disabled: fitting directly to the sampled trees themselves
    # should work well. To try it, replace the sample sets with:
    #posterior_sample_sets = [[x[0] for x in samples]]
    grammar = fit_grammar_params_to_sample_sets_with_uninformative_prior(
        grammar, posterior_sample_sets
    )
    state_dict_history.append(deepcopy(grammar.state_dict()))

In [None]:
# Plot the history of a few parameters over the recorded state dicts (solid
# lines), with the ground-truth grammar's values as dashed horizontal lines.
def _rule_param_getter(node_type, part_idx, param_name):
    """Return a getter reading rule_params_by_node_type[node_type][0][part_idx][param_name] as numpy."""
    def _get(some_grammar):
        param = some_grammar.rule_params_by_node_type[node_type][0][part_idx][param_name]
        return param().detach().numpy()
    return _get


def _node_param_getter(node_type):
    """Return a getter reading params_by_node_type[node_type] as numpy."""
    def _get(some_grammar):
        return some_grammar.params_by_node_type[node_type]().detach().numpy()
    return _get


param_getters_of_interest = {
    "Desk child xyz mean: ": _rule_param_getter("Desk", 0, "mean"),
    "Desk child xyz var: ": _rule_param_getter("Desk", 0, "variance"),
    "Desk child rot loc: ": _rule_param_getter("Desk", 1, "loc"),
    "Desk child rot var: ": _rule_param_getter("Desk", 1, "concentration"),
    "Pencil child rot var: ": _rule_param_getter("PencilCluster", 1, "concentration"),
    "Desk child rate: ": _node_param_getter("Desk"),
    "Object cluster child rate: ": _node_param_getter("ObjectCluster"),
    "FoodWasteCluster child rate: ": _node_param_getter("FoodWasteCluster"),
}

colormap = plt.get_cmap("viridis")
for axis_label, read_param in param_getters_of_interest.items():
    plt.figure()

    # Load each recorded state dict into `grammar` in turn and read the
    # parameter back out. Note this overwrites `grammar`'s parameters; after
    # the loop it holds the last recorded state.
    per_iter_values = []
    for snapshot in state_dict_history:
        grammar.load_state_dict(snapshot)
        per_iter_values.append(read_param(grammar).copy().flatten())
    trajectory = np.stack(per_iter_values)

    truth = read_param(ground_truth_grammar).flatten()
    n_entries = len(truth)
    # One color per flattened parameter entry, shared by its history line and
    # its ground-truth line.
    for entry_idx, truth_value in enumerate(truth):
        entry_color = colormap(entry_idx / max(1, n_entries))
        plt.plot(trajectory[:, entry_idx], color=entry_color)
        plt.axhline(truth_value, color=entry_color, linestyle="--")
    plt.xlabel("Iter")
    plt.ylabel(axis_label)
