In [3]:
%load_ext autoreload
%autoreload 2

import logging
import os
import time
from collections import namedtuple
from copy import deepcopy
from datetime import datetime

import networkx as nx
import numpy as np
import pyro
import torch

from spatial_scene_grammars.nodes import *
from spatial_scene_grammars.rules import *
from spatial_scene_grammars.scene_grammar import *
from spatial_scene_grammars.parsing import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
from spatial_scene_grammars_examples.singles_pairs.grammar_constituency import *
#from spatial_scene_grammars_examples.singles_pairs.grammar_dependency import *
# Seed pyro before anything is sampled so the ground-truth scene is repeatable.
RNG_SEED = 42
pyro.set_rng_seed(RNG_SEED)

# Build the grammar rooted at the identity pose and draw one detached tree
# from it; the nodes flagged as observed are the input to the parser below.
grammar = SpatialSceneGrammar(root_node_type=Root, root_node_tf=torch.eye(4))
ground_truth_tree = grammar.sample_tree(detach=True)
observed_nodes = ground_truth_tree.get_observed_nodes()

num_observed = len(observed_nodes)
print("Observed %d objects" % num_observed)
print("Objects: ", observed_nodes)

Observed 8 objects
Objects:  [<spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Root object at 0x7f29be3456a0>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7f29be852b00>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7f29be7eaf28>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7f29be7ea6d8>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7f29be7ead30>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7f29be7ea7f0>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7f29be7ea828>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7f29be7eadd8>]


In [32]:
# Collect candidate intermediate (unobserved) nodes from two sources --
# top-down expansion of the grammar and bottom-up inversion of its rules --
# then drop candidates that duplicate an earlier one.

max_recursion_depth = 10

top_down_candidate_intermediate_nodes = generate_top_down_intermediate_nodes_by_supertree(
    grammar, observed_nodes, max_recursion_depth=max_recursion_depth
)
bottom_up_candidate_intermediate_nodes = generate_bottom_up_intermediate_nodes_by_inverting_rules(
    grammar, observed_nodes
)
candidate_counts = (
    len(top_down_candidate_intermediate_nodes),
    len(bottom_up_candidate_intermediate_nodes),
)
print("%d top-down, %d bottom-up candidates." % candidate_counts)

candidate_intermediate_nodes = (
    top_down_candidate_intermediate_nodes + bottom_up_candidate_intermediate_nodes
)
# Candidates are, by construction, never observed nodes.
assert all(not candidate.observed for candidate in candidate_intermediate_nodes)

print("Before pruning: %d candidates" % len(candidate_intermediate_nodes))


def _same_type_and_pose(node_a, node_b):
    """Return True when both nodes have the exact same type and (numerically) the same tf."""
    return type(node_a) is type(node_b) and torch.allclose(node_a.tf, node_b.tf)


# Keep the first occurrence of each (type, pose) pair, preserving order.
pruned_intermediate_nodes = []
for candidate in candidate_intermediate_nodes:
    already_kept = any(
        _same_type_and_pose(candidate, kept) for kept in pruned_intermediate_nodes
    )
    if not already_kept:
        pruned_intermediate_nodes.append(candidate)

candidate_intermediate_nodes = pruned_intermediate_nodes
print("After pruning: %d candidates" % len(candidate_intermediate_nodes))
for node in candidate_intermediate_nodes:
    pose_as_array = node.tf.detach().numpy()
    print("%s: %s" % (node.name, pose_as_array))

Candidate <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Singles object at 0x7f29be864eb8> with pose tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
Candidate <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Singles object at 0x7f29be8080b8> with pose tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
Candidate <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Singles object at 0x7f29be8087b8> with pose tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
Candidate <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Singles object at 0x7f29be808898> with pose tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
Candidate <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Singles object

In [33]:
# Set up a MIP to try to glue a valid tree together.
# We can build a graph where each node is an observed
# or candidate intermediate node, and there is a directed
# edge for every legal rule, with a weight corresponding to its
# probability. We add one binary variable per edge indicating
# its activation.
#  - Every observed node except the root needs exactly one active incoming
#    edge. The root needs exactly zero. (It shouldn't have any incoming
#    edges anyway by assumptions about construction of the grammar.)
#    Unobserved nodes have outgoing edges iff they have an active
#    incoming edge.
#  - The score of (and the constraints on) a node's set of outgoing
#    edges depend on the node type. This includes symmetry breaking
#    where appropriate.
#  - Maximize the total score.

from pydrake.all import (
    MathematicalProgram,
    GurobiSolver,
    SolverOptions
)

prog = MathematicalProgram()

# Find the root among the observed nodes (if several match, the last one is
# used). If the root was not observed, create one at the grammar's root pose
# and append it to the observed set.
observed_roots = [
    candidate for candidate in observed_nodes
    if isinstance(candidate, grammar.root_node_type)
]
if observed_roots:
    root = observed_roots[-1]
else:
    root = grammar.root_node_type(tf=grammar.root_node_tf)
    observed_nodes.append(root)

# The parse graph has one vertex per observed or candidate intermediate node;
# edges are added rule by rule further down.
all_nodes = observed_nodes + candidate_intermediate_nodes
# A MultiDiGraph is required: two nodes may be joined by several edges,
# one for each distinct rule that could produce the same child.
parse_graph = nx.MultiDiGraph()
parse_graph.add_nodes_from(all_nodes)
def add_edges_for_rule(parent, rule, rule_k, rule_activation_expr):
    """Add parse-graph edges for every child that `rule` of `parent` could produce.

    For each node in the module-level `all_nodes` list that is an instance of
    `rule.child_type` (and is not `parent` itself), the child is scored with
    `rule.score_child(parent, node)`. Children with a finite score get:
      * a new binary decision variable in `prog`,
      * an edge in `parse_graph` carrying `active`, `score` and `rule_k`,
      * a linear cost of `-score * active`.
    Finally the activations of all edges created here are constrained to sum
    to `rule_activation_expr`.

    Args:
        parent: The node whose rule is being expanded. Must be a
            GeometricSetNode, AndNode, OrNode or IndependentSetNode.
        rule: The rule to expand. For a GeometricSetNode this must be
            `parent.rules[0]`; otherwise it must be `parent.rules[rule_k]`.
        rule_k: Index of the rule (or, for a GeometricSetNode, of the
            repetition) being expanded.
        rule_activation_expr: Linear expression of decision variables that
            evaluates to 1 when this rule is active and 0 when not.

    Raises:
        ValueError: If `parent` is not one of the supported node types.
    """
    if isinstance(parent, GeometricSetNode):
        # Callers create exactly `max_children` activation variables for a
        # geometric set node, so valid repetition indices are [0, max_children).
        assert 0 <= rule_k < parent.max_children
        assert rule is parent.rules[0]
    elif isinstance(parent, (AndNode, OrNode, IndependentSetNode)):
        assert rule is parent.rules[rule_k]
    else:
        # The failed check is on the parent's type, so report that type.
        raise ValueError(type(parent), "Bad type.")

    all_outgoing_activations = []
    for node in all_nodes:
        if node is parent or not isinstance(node, rule.child_type):
            continue
        score = rule.score_child(parent, node).detach().item()
        var_name = "%s:%s_%d:%s" % (parent.name, type(rule).__name__, rule_k, node.name)
        if not np.isfinite(score):
            # A non-finite score would poison the objective, so this edge is
            # left out of the program entirely.
            logging.warning("Skipping rule %s as it is infeasible", var_name)
            continue
        active = prog.NewBinaryVariables(1, var_name)[0]
        parse_graph.add_edge(
            parent, node, active=active, score=score, rule_k=rule_k
        )
        all_outgoing_activations.append(active)
        # If this edge is active, it adds this score to the total cost.
        prog.AddLinearCost(-score * active)

    if len(all_outgoing_activations) > 0:
        # The number of active edges for this rule must match the rule's
        # activation (0 or 1).
        prog.AddLinearConstraint(sum(all_outgoing_activations) == rule_activation_expr)
    else:
        logging.warning("No outgoing connections for %s:%s_%d" % (parent.name, type(rule), rule_k))
    
def _rule_log_prob(parent, rule_index):
    """Log of the probability that `parent` assigns to entry `rule_index` of its rule_probs."""
    return np.log(parent.rule_probs[rule_index].detach().item())


for node in all_nodes:
    # Each node gets one binary variable saying whether it is part of the parse.
    node.active = prog.NewBinaryVariables(1, node.name + "_active")[0]

    if isinstance(node, TerminalNode):
        # Terminals have no rules, hence no outgoing edges.
        continue

    if isinstance(node, AndNode):
        # Every rule of an AND node fires exactly when the node is active.
        for rule_index, and_rule in enumerate(node.rules):
            add_edges_for_rule(node, and_rule, rule_index, node.active)

    elif isinstance(node, OrNode):
        rule_switches = prog.NewBinaryVariables(len(node.rules), node.name + "_outgoing")
        # Exactly one rule is chosen when the node is active, none otherwise.
        prog.AddLinearConstraint(sum(rule_switches) == node.active)
        for rule_index, (or_rule, switch) in enumerate(zip(node.rules, rule_switches)):
            add_edges_for_rule(node, or_rule, rule_index, switch)
            # Choosing this rule contributes its log-probability to the score.
            prog.AddLinearCost(-switch * _rule_log_prob(node, rule_index))

    elif isinstance(node, IndependentSetNode):
        rule_switches = prog.NewBinaryVariables(len(node.rules), node.name + "_outgoing")
        for rule_index, (set_rule, switch) in enumerate(zip(node.rules, rule_switches)):
            add_edges_for_rule(node, set_rule, rule_index, switch)
            # A rule can only fire if its parent node is active.
            prog.AddLinearConstraint(switch <= node.active)
            # Each fired rule independently contributes its log-probability.
            prog.AddLinearCost(-switch * _rule_log_prob(node, rule_index))

    elif isinstance(node, GeometricSetNode):
        repeat_switches = prog.NewBinaryVariables(node.max_children, node.name + "_outgoing")
        # Symmetry breaking: repetitions switch on in order, so a repetition
        # can only be active if the preceding one is active as well.
        for earlier, later in zip(repeat_switches[:-1], repeat_switches[1:]):
            prog.AddLinearConstraint(later <= earlier)
        # The first repetition is on exactly when the node is active; an
        # inactive node therefore has every repetition switched off.
        prog.AddLinearConstraint(repeat_switches[0] == node.active)
        repeated_rules = [node.rules[0] for _ in range(node.max_children)]
        for rule_index, (repeated_rule, switch) in enumerate(zip(repeated_rules, repeat_switches)):
            add_edges_for_rule(node, repeated_rule, rule_index, switch)
            # Telescoping score: switching on repetition k cancels the score
            # contributed by repetition k-1 and adds the score for k, so the
            # total equals the log-prob of the last active repetition.
            previous_log_prob = _rule_log_prob(node, rule_index - 1) if rule_index > 0 else 0.
            current_log_prob = _rule_log_prob(node, rule_index)
            prog.AddLinearCost(-switch * (current_log_prob - previous_log_prob))

    else:
        raise NotImplementedError(type(node))
    
# Now that the DiGraph is fully formed, go in and constrain node
# activation vars to depend on explanatory incoming edges.

# Observed nodes must always be part of the parse.
for observed in observed_nodes:
    prog.AddLinearConstraint(observed.active == True)

# Every node other than the root is active exactly when one of its incoming
# edges is active.
for node in all_nodes:
    if node is not root:
        incoming_activations = [
            edge_active
            for *_, edge_active in parse_graph.in_edges(nbunch=node, data="active")
        ]
        prog.AddLinearConstraint(sum(incoming_activations) == node.active)

# Configure Gurobi, solve the parse MIP, and dump the solver log.
solver = GurobiSolver()
options = SolverOptions()
logfile = "/tmp/gurobi_%s.log" % datetime.now().strftime("%Y%m%dT%H%M%S")
# Remove any stale log with the same timestamped name. This is best-effort,
# done without shelling out: a missing file is the normal case, and any other
# failure is reported but does not stop the solve.
try:
    os.remove(logfile)
except FileNotFoundError:
    pass
except OSError as err:
    logging.warning("Could not remove stale log file %s: %s", logfile, err)
options.SetOption(solver.id(), "LogFile", logfile)
options.SetOption(solver.id(), "MIPGap", 1E-3)
# Number of solutions to request; the solution-pool options are only set when
# more than one is wanted.
N_solutions = 1
if N_solutions > 1:
    options.SetOption(solver.id(), "PoolSolutions", N_solutions)
    options.SetOption(solver.id(), "PoolSearchMode", 2)

result = solver.Solve(prog, None, options)
# Hacky method getter because `num_suboptimal_solution()` was bound with () in its
# method name. Should fix this upstream!
actual_N_solutions = getattr(result, "num_suboptimal_solution()")()
if actual_N_solutions != N_solutions:
    logging.warning("MIP got %d solutions, but requested %d. ", actual_N_solutions, N_solutions)
print("Optimization success?: ", result.is_success())
print("Logfile: ")
with open(logfile) as f:
    print(f.read())



Optimization success?:  True
Logfile: 

Gurobi 9.0.2 (linux64) logging started Tue Nov 30 03:59:47 2021

Gurobi Optimizer version 9.0.2 build v9.0.2rc0 (linux64)
Optimize a model with 42 rows, 95 columns and 198 nonzeros
Model fingerprint: 0x23c2ae0d
Variable types: 0 continuous, 95 integer (95 binary)
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [2e-01, 8e+02]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 1e+00]
Found heuristic solution: objective 1205.6117474
Presolve removed 19 rows and 17 columns
Presolve time: 0.00s
Presolved: 23 rows, 78 columns, 163 nonzeros
Variable types: 0 continuous, 78 integer (78 binary)

Root relaxation: objective 6.811424e+02, 30 iterations, 0.00 seconds

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time

     0     0  681.14242    0    5 1205.61175  681.14242  43.5%     -    0s
H    0     0                   