# Using one big MIP to do parameter fitting

Given a dataset of observed object sets drawn from the grammar, build one big mixed-integer program (MIP) that simultaneously optimizes the grammar parameters and finds a parse for each scene. This requires that the grammar use only the subset of rules that admit convex formulations of the log joint probability in both the continuous variables and the parameters.

In [1]:
%load_ext autoreload
%autoreload 2

import logging
import os
import pickle
import time
from collections import namedtuple

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

import torch
torch.set_default_tensor_type(torch.DoubleTensor)

from spatial_scene_grammars.nodes import *
from spatial_scene_grammars.rules import *
from spatial_scene_grammars.scene_grammar import *
from spatial_scene_grammars.visualization import *
from spatial_scene_grammars_examples.planar_clusters_no_rotation.grammar import *
from spatial_scene_grammars.parsing import *
from spatial_scene_grammars.sampling import *

import meshcat
import meshcat.geometry as meshcat_geom

In [2]:
from IPython.display import HTML

# Keep a single visualizer alive across re-runs of this cell: only create
# one when the notebook namespace does not have `vis` yet.
if 'vis' not in globals():
    vis = meshcat.Visualizer()

# Build the URL from the loopback address plus whatever follows the last
# ":" in the URL reported by the visualizer.
base_url = "http://127.0.0.1"
url_tail = vis.url().rsplit(":", 1)[-1]
meshcat_url = ":".join([base_url, url_tail])
print("Meshcat url: ", meshcat_url)

# Embed the viewer in the cell output; the HTML object must be the cell's
# last expression so that the notebook renders it.
viewer_markup = """
    <div style="height: 400px; width: 100%; overflow-x: auto; overflow-y: hidden; resize: both">
    <iframe src="{url}" style="width: 100%; height: 100%; border: none"></iframe>
</div>
""".format(url=meshcat_url)
HTML(viewer_markup)

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/
Meshcat url:  http://127.0.0.1:7000/static/


In [3]:
# Build the dataset of scenes used for fitting: each entry pairs a tree
# sampled from the ground-truth grammar (default parameters) with that
# tree's observed nodes. The samples are written to a pickle file named
# after N_samples and are always read back from that file afterwards.
torch.random.manual_seed(2)
N_samples = 10
RESAMPLE = True  # True: regenerate the scenes even if the file already exists.
scenes_file = "sampled_scenes_%d.dat" % N_samples

ground_truth_grammar = SpatialSceneGrammar(root_node_type=Desk, root_node_tf=torch.eye(4))

have_cached_scenes = os.path.exists(scenes_file)
if RESAMPLE or not have_cached_scenes:
    samples = []
    for k in range(N_samples):
        tree = ground_truth_grammar.sample_tree(detach=True)
        observed_nodes = tree.get_observed_nodes()
        samples.append((tree, observed_nodes))

    with open(scenes_file, "wb") as f:
        pickle.dump(samples, f)

# Load from disk in both cases, so a fresh sample and a cached one follow
# the same path from here on.
with open(scenes_file, "rb") as f:
    samples = pickle.load(f)
print("Loaded %d scenes." % len(samples))

# Only the observed nodes (second element of each pair) are used for fitting.
observed_node_sets = [entry[1] for entry in samples]

Loaded 100 scenes.


In [None]:
import pydrake
from pydrake.all import (
    CommonSolverOption,
    MathematicalProgram,
    MakeSolver,
    MixedIntegerBranchAndBound,
    RigidTransform,
    RollPitchYaw,
    RotationMatrix,
    GurobiSolver,
    SnoptSolver,
    OsqpSolver,
    Solve,
    SolverOptions,
    SolutionResult,
    VPolytope,
    MixedIntegerRotationConstraintGenerator,
    IntervalBinning,
    Variable
)


# Grammar to be fitted: same root node type and root transform as the
# ground-truth grammar, but with its parameters sampled from the prior.
# Those random initial guesses are placeholders that will be thrown out.
grammar = SpatialSceneGrammar(
    root_node_type=Desk,
    root_node_tf=torch.eye(4),
    sample_params_from_prior=True,
)


# Record returned by fit_grammar_params_to_observed_nodes: the solver, its
# optimization result, the grammar, and the per-scene trees / observations.
GrammarInferenceResults = namedtuple(
    typename="GrammarInferenceResults",
    field_names=(
        "solver",
        "optim_result",
        "grammar",
        "all_supertrees",
        "observed_node_sets",
    ),
)
def fit_grammar_params_to_observed_nodes(grammar, observed_node_sets, verbose=False,
                                         mip_gap=0.05, logfile="/tmp/gurobi.log"):
    """Fit grammar parameters and parse every observed scene in one joint MIP.

    A single MathematicalProgram is created, the grammar is registered with
    it through prepare_grammar_for_mip_parsing(..., optimize_parameters=True),
    and add_mle_tree_parsing_to_prog is called once per observed node set.
    The program is then solved with Gurobi. On success, the solved rule
    parameter values are written back into the grammar returned by
    prepare_grammar_for_mip_parsing; on failure the parameters are untouched.

    Args:
        grammar: Grammar to fit.
        observed_node_sets: Iterable of observed-node collections, one per
            scene.
        verbose: If truthy, print setup/solve timings, problem size, solver
            success and the contents of the Gurobi log file.
        mip_gap: Value passed to Gurobi's "MIPGap" option.
        logfile: Path passed to Gurobi's "LogFile" option. An existing file at
            this path is removed before solving.

    Returns:
        GrammarInferenceResults(solver, optim_result, grammar, all_supertrees,
        observed_node_sets). The last two fields hold, per scene, the two
        values returned by add_mle_tree_parsing_to_prog.
    """
    start_time = time.time()

    # Setup fit of grammar parameters + parse of each tree.
    prog = MathematicalProgram()
    grammar = prepare_grammar_for_mip_parsing(prog, grammar, optimize_parameters=True)
    all_trees = []
    all_obs = []
    for observed_nodes in observed_node_sets:
        tree, obs = add_mle_tree_parsing_to_prog(
            prog, grammar, observed_nodes, verbose=False
        )
        all_trees.append(tree)
        all_obs.append(obs)

    setup_time = time.time()
    if verbose:
        print("Setup time: ", setup_time - start_time)
        print("Num vars: ", prog.num_vars())
        # flush=True so the problem size is visible before the (long) solve.
        print(
            "Num constraints: ",
            sum(c.evaluator().num_constraints() for c in prog.GetAllConstraints()),
            flush=True,
        )

    solver = GurobiSolver()
    options = SolverOptions()
    # Drop any log left over from an earlier solve so that what is printed
    # below belongs to this run only. Failure to remove it is not fatal.
    try:
        os.remove(logfile)
    except FileNotFoundError:
        pass
    except OSError as err:
        logging.warning("Could not remove old solver log %s: %s", logfile, err)
    options.SetOption(solver.id(), "LogFile", logfile)
    options.SetOption(solver.id(), "MIPGap", mip_gap)
    # "%%" is a literal percent sign in logging's printf-style formatting.
    logging.info("MIP gap set to %g%%", mip_gap * 100.)
    result = solver.Solve(prog, None, options)
    # Take the timestamp right after the solve so that printing the log
    # below is not counted as solve time.
    solve_time = time.time()

    if verbose:
        print("Optimization success?: ", result.is_success())
        print("Logfile: ")
        if os.path.exists(logfile):
            with open(logfile) as f:
                print(f.read())
        else:
            print("(no log file found at %s)" % logfile)
        print("Solve time: ", solve_time - setup_time)
        print("Total time: ", solve_time - start_time)

    # If successful, go fill the grammar ConstrainedParameter values back in.
    if result.is_success():
        logging.info("TODO: Node parameters")
        for node_type in grammar.all_types:
            type_name = node_type.__name__
            # Pair each rule's (xyz, rot) parameter dicts with the matching
            # dicts of optimization variables, then copy the solved values over.
            for ((xyz_params, rot_params), (fit_xyz_params, fit_rot_params)) in zip(
                    grammar.rule_params_by_node_type[type_name],
                    grammar.rule_params_by_node_type_optim[type_name]):
                for key in xyz_params.keys():
                    xyz_params[key].set(torch.tensor(result.GetSolution(fit_xyz_params[key])))
                for key in rot_params.keys():
                    rot_params[key].set(torch.tensor(result.GetSolution(fit_rot_params[key])))
    else:
        logging.warning("Parameter fitting optimization failed; grammar params left alone.")

    return GrammarInferenceResults(solver, result, grammar, all_trees, all_obs)
    
# Run the joint fit over all sampled scenes, with progress/timing printed.
results = fit_grammar_params_to_observed_nodes(
    grammar,
    observed_node_sets,
    verbose=True,
)

Setup time:  104.38579440116882
Num vars:  101390
Num constraints:  352739


In [None]:
# Compare solved-out grammar params to original grammar params.
# Name -> parameter lookup built from the (name, parameter) pairs that the
# fitted grammar's named_parameters() yields.
fit_params = dict(results.grammar.named_parameters())
def print_compare(name, orig, fit):
    """Print an original and a fitted parameter value under a common heading.

    Args:
        name: Heading text; printed followed by a colon.
        orig: Zero-argument callable returning the original value; the result
            must support ``.detach().numpy()``.
        fit: Zero-argument callable returning the fitted value, same
            requirements as ``orig``.
    """
    print(name + ":")
    # Evaluate and print one row at a time, original first, then fitted.
    rows = (
        ("\t Orig: %s", orig),
        ("\t Fit:  %s", fit),
    )
    for template, getter in rows:
        value = getter().detach().numpy()
        # Wrap in a 1-tuple so the array is always treated as a single argument.
        print(template % (value,))
    
# For every node type and every one of its rules, print the ground-truth
# value next to the fitted value of each translation (xyz) and rotation
# parameter. Fitted values are read from results.grammar, which is the
# grammar object the fitting routine wrote the solved parameters into.
for node_type in ground_truth_grammar.all_types:
    type_name = node_type.__name__
    rule_param_pairs = zip(
        ground_truth_grammar.rule_params_by_node_type[type_name],
        results.grammar.rule_params_by_node_type[type_name],
    )
    for k, ((xyz_params, rot_params), (fit_xyz_params, fit_rot_params)) in enumerate(rule_param_pairs):
        prefix = "%s:%d:" % (type_name, k)
        for key in xyz_params.keys():
            print_compare(prefix + key, xyz_params[key], fit_xyz_params[key])
        for key in rot_params.keys():
            # Compare against the fitted rotation params (the original line
            # referenced an undefined name here and raised NameError).
            print_compare(prefix + key, rot_params[key], fit_rot_params[key])