In [1]:
%load_ext autoreload
%autoreload 2

import time
import networkx as nx
import numpy as np
import logging
from copy import deepcopy
from collections import namedtuple
import torch
import pyro

from spatial_scene_grammars.nodes import *
from spatial_scene_grammars.rules import *
from spatial_scene_grammars.scene_grammar import *
from spatial_scene_grammars.parsing import *

In [37]:
from spatial_scene_grammars_examples.singles_pairs.grammar_constituency import *
#from spatial_scene_grammars_examples.singles_pairs.grammar_dependency import *
pyro.set_rng_seed(42)

grammar = SpatialSceneGrammar(
    root_node_type = Root,
    root_node_tf = torch.eye(4)
)
ground_truth_tree = grammar.sample_tree(detach=True)
observed_nodes = ground_truth_tree.get_observed_nodes()
print("Observed %d objects" % len(observed_nodes))
print("Objects: ", observed_nodes)

Observed 8 objects
Objects:  [<spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Root object at 0x7ff239f052b0>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7ff239f05be0>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7ff23a13eb00>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7ff23b6b5470>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7ff23b6b53c8>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7ff23b6b52b0>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7ff23b6b5320>, <spatial_scene_grammars_examples.singles_pairs.grammar_constituency.Object object at 0x7ff23b6b5f98>]


In [40]:
# Build up a big set of candidate intermediate nodes with both
# top-down and bottom-up generation.
candidate_intermediate_nodes = generate_candidate_intermediate_nodes(
    grammar, observed_nodes, max_recursion_depth=10, verbose=True
)
for node in candidate_intermediate_nodes:
    print(node.name, node.tf)

Expanding  Root_281
Expanding  Object_470
New candidate:  Singles_319  at  tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
New candidate:  Pair_355  at  tensor([[-0.1789, -0.5387, -0.8233,  1.1561],
        [ 0.4725, -0.7810,  0.4083,  0.3965],
        [-0.8630, -0.3160,  0.3943, -2.4661],
        [ 0.0000,  0.0000,  0.0000,  1.0000]])
New candidate:  Pair_355  at  tensor([[-0.1789, -0.5387, -0.8233,  1.1561],
        [ 0.4725, -0.7810,  0.4083,  0.3965],
        [-0.8630, -0.3160,  0.3943, -2.4661],
        [ 0.0000,  0.0000,  0.0000,  1.0000]])
Expanding  Object_471
New candidate:  Singles_320  at  tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
New candidate:  Pair_356  at  tensor([[-0.5703, -0.4390,  0.6943,  0.1985],
        [-0.0829,  0.8716,  0.4831,  0.4377],
        [-0.8172,  0.2180, -0.5335,  0.2162],
        [ 0.0000,  0.0000,  0.0000,  1.0000]])
New candidate: 

In [41]:
parse_trees = infer_mle_tree_with_mip_from_proposals(
    grammar, observed_nodes, candidate_intermediate_nodes, verbose=True, N_solutions=1
)

Candidate edge Root_281 --(0.000000, 0)-> Singles_317
Candidate edge Root_281 --(0.000000, 0)-> Singles_320
Candidate edge Root_281 --(0.000000, 0)-> Singles_321
Candidate edge Root_281 --(0.000000, 0)-> Singles_322
Candidate edge Root_281 --(0.000000, 0)-> Singles_319
Candidate edge Root_281 --(0.000000, 0)-> Singles_324
Candidate edge Root_281 --(0.000000, 0)-> Singles_325
Candidate edge Root_281 --(0.000000, 0)-> Singles_323
Candidate edge Root_281 --(0.000000, 1)-> Pairs_317
Candidate edge Root_281 --(0.000000, 1)-> Pairs_334
Candidate edge Root_281 --(0.000000, 1)-> Pairs_336
Candidate edge Root_281 --(0.000000, 1)-> Pairs_328
Candidate edge Root_281 --(0.000000, 1)-> Pairs_345
Candidate edge Root_281 --(0.000000, 1)-> Pairs_330
Candidate edge Root_281 --(0.000000, 1)-> Pairs_337
Candidate edge Root_281 --(0.000000, 1)-> Pairs_339
Candidate edge Root_281 --(0.000000, 1)-> Pairs_340
Candidate edge Root_281 --(0.000000, 1)-> Pairs_331
Candidate edge Root_281 --(0.000000, 1)-> Pairs_

Candidate edge Pairs_336 --(-6.843780, 2)-> Pair_359
Candidate edge Pairs_336 --(-6.066871, 2)-> Pair_357
Candidate edge Pairs_328 --(-6.643922, 0)-> Pair_351
Candidate edge Pairs_328 --(-8.182841, 0)-> Pair_352
Candidate edge Pairs_328 --(-8.483766, 0)-> Pair_353
Candidate edge Pairs_328 --(-9.552021, 0)-> Pair_355
Candidate edge Pairs_328 --(-8.372935, 0)-> Pair_361
Candidate edge Pairs_328 --(-5.893601, 0)-> Pair_356
Candidate edge Pairs_328 --(-6.742184, 0)-> Pair_358
Candidate edge Pairs_328 --(-8.167213, 0)-> Pair_360
Candidate edge Pairs_328 --(-6.843780, 0)-> Pair_359
Candidate edge Pairs_328 --(-6.066871, 0)-> Pair_357
Candidate edge Pairs_328 --(-6.643922, 1)-> Pair_351
Candidate edge Pairs_328 --(-8.182841, 1)-> Pair_352
Candidate edge Pairs_328 --(-8.483766, 1)-> Pair_353
Candidate edge Pairs_328 --(-9.552021, 1)-> Pair_355
Candidate edge Pairs_328 --(-8.372935, 1)-> Pair_361
Candidate edge Pairs_328 --(-5.893601, 1)-> Pair_356
Candidate edge Pairs_328 --(-6.742184, 1)-> Pa

Candidate edge Pairs_339 --(-9.552021, 0)-> Pair_355
Candidate edge Pairs_339 --(-8.372935, 0)-> Pair_361
Candidate edge Pairs_339 --(-5.893601, 0)-> Pair_356
Candidate edge Pairs_339 --(-6.742184, 0)-> Pair_358
Candidate edge Pairs_339 --(-8.167213, 0)-> Pair_360
Candidate edge Pairs_339 --(-6.843780, 0)-> Pair_359
Candidate edge Pairs_339 --(-6.066871, 0)-> Pair_357
Candidate edge Pairs_339 --(-6.643922, 1)-> Pair_351
Candidate edge Pairs_339 --(-8.182841, 1)-> Pair_352
Candidate edge Pairs_339 --(-8.483766, 1)-> Pair_353
Candidate edge Pairs_339 --(-9.552021, 1)-> Pair_355
Candidate edge Pairs_339 --(-8.372935, 1)-> Pair_361
Candidate edge Pairs_339 --(-5.893601, 1)-> Pair_356
Candidate edge Pairs_339 --(-6.742184, 1)-> Pair_358
Candidate edge Pairs_339 --(-8.167213, 1)-> Pair_360
Candidate edge Pairs_339 --(-6.843780, 1)-> Pair_359
Candidate edge Pairs_339 --(-6.066871, 1)-> Pair_357
Candidate edge Pairs_339 --(-6.643922, 2)-> Pair_351
Candidate edge Pairs_339 --(-8.182841, 2)-> Pa

Candidate edge Pair_358 --(-131.839598, 1)-> Object_471
Candidate edge Pair_358 --(-112.205863, 1)-> Object_472
Candidate edge Pair_358 --(8.654081, 1)-> Object_473
Candidate edge Pair_358 --(7.777871, 1)-> Object_474
Candidate edge Pair_358 --(-485.211569, 1)-> Object_475
Candidate edge Pair_358 --(-495.561490, 1)-> Object_476
Candidate edge Pair_360 --(-1101.228843, 0)-> Object_470
Candidate edge Pair_360 --(-400.485651, 0)-> Object_471
Candidate edge Pair_360 --(-455.057559, 0)-> Object_472
Candidate edge Pair_360 --(-485.211569, 0)-> Object_473
Candidate edge Pair_360 --(-493.355847, 0)-> Object_474
Candidate edge Pair_360 --(8.654081, 0)-> Object_475
Candidate edge Pair_360 --(2.522238, 0)-> Object_476
Candidate edge Pair_360 --(-1101.228843, 1)-> Object_470
Candidate edge Pair_360 --(-400.485651, 1)-> Object_471
Candidate edge Pair_360 --(-455.057559, 1)-> Object_472
Candidate edge Pair_360 --(-485.211569, 1)-> Object_473
Candidate edge Pair_360 --(-493.355847, 1)-> Object_474
Ca

Candidate edge Pairs_342 --(-8.483766, 0)-> Pair_353
Candidate edge Pairs_342 --(-9.552021, 0)-> Pair_355
Candidate edge Pairs_342 --(-8.372935, 0)-> Pair_361
Candidate edge Pairs_342 --(-5.893601, 0)-> Pair_356
Candidate edge Pairs_342 --(-6.742184, 0)-> Pair_358
Candidate edge Pairs_342 --(-8.167213, 0)-> Pair_360
Candidate edge Pairs_342 --(-6.843780, 0)-> Pair_359
Candidate edge Pairs_342 --(-6.066871, 0)-> Pair_357
Candidate edge Pairs_342 --(-6.643922, 1)-> Pair_351
Candidate edge Pairs_342 --(-8.182841, 1)-> Pair_352
Candidate edge Pairs_342 --(-8.483766, 1)-> Pair_353
Candidate edge Pairs_342 --(-9.552021, 1)-> Pair_355
Candidate edge Pairs_342 --(-8.372935, 1)-> Pair_361
Candidate edge Pairs_342 --(-5.893601, 1)-> Pair_356
Candidate edge Pairs_342 --(-6.742184, 1)-> Pair_358
Candidate edge Pairs_342 --(-8.167213, 1)-> Pair_360
Candidate edge Pairs_342 --(-6.843780, 1)-> Pair_359
Candidate edge Pairs_342 --(-6.066871, 1)-> Pair_357
Candidate edge Pairs_342 --(-6.643922, 2)-> Pa