In [1]:
from __future__ import print_function

from collections import namedtuple
from copy import deepcopy
import datetime
import math
import os
import random
import sys
import time
import traceback

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import gaussian_filter1d
import yaml

import pydrake  # MUST BE BEFORE TORCH OR PYRO
from pydrake.common.eigen_geometry import Quaternion, AngleAxis, Isometry3
from pydrake.math import RollPitchYaw, RotationMatrix
import torch
import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import torch.multiprocessing as mp
mp.set_sharing_strategy('file_system')
from multiprocessing.managers import SyncManager
from tensorboardX import SummaryWriter
import networkx

import scene_generation.data.dataset_utils as dataset_utils
from scene_generation.models.probabilistic_scene_grammar_nodes import *
from scene_generation.models.probabilistic_scene_grammar_nodes_mug_shelf import *
from scene_generation.models.probabilistic_scene_grammar_model import *
from scene_generation.models.probabilistic_scene_grammar_fitting import *

torch.set_default_tensor_type(torch.DoubleTensor)

In [2]:
# Seed every RNG used below from the wall clock.  The seed is printed so
# that a particular run can be reproduced later by hard-coding that value.
seed = int(time.time()) % (2**32-1)
print("Random seed: %d" % seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.set_default_tensor_type(torch.DoubleTensor)
pyro.enable_validation(True)
pyro.clear_param_store()

# CONFIGURATION STUFF
# Build the root node and its hyper-expanded parse tree; the tree's global
# variable store is what gets overwritten with parameter values below.
root_node_type = MugShelf
root_node = root_node_type()
hyper_parse_tree = generate_hyperexpanded_parse_tree(root_node, max_iters=8)
guide_gvs = hyper_parse_tree.get_global_variable_store()

train_dataset = dataset_utils.ScenesDataset("../../data/mug_shelf/mug_rack_environments_human_train/")
test_dataset = dataset_utils.ScenesDataset("../../data/mug_shelf/mug_rack_environments_human_test/")
print("%d training examples" % len(train_dataset))
print("%d test examples" % len(test_dataset))

# Populate the (just cleared) param store from a saved checkpoint.
pyro.get_param_store().load("../../data/mug_shelf/icra_runs/1/param_store_best_on_test.pyro")

# Replace each global variable's value with the pyro param "<name>_est".
# pyro.param only uses its second argument as an initial value when the name
# is not already in the param store, so names present in the checkpoint
# loaded above take their stored values.
for var_name in guide_gvs.keys():
    guide_gvs[var_name][0] = pyro.param(var_name + "_est",
                                        guide_gvs[var_name][0],
                                        constraint=guide_gvs[var_name][1].support)

60 training examples
40 test examples


In [4]:
def draw_environment_for_video(yaml_env, parse_tree=None):
    """Draw a mug-shelf scene, optionally with its parse tree on top.

    Args:
        yaml_env: Scene description handed to
            `dataset_utils.DrawYamlEnvironment`.
        parse_tree: Optional parse tree.  When given, the scene is drawn
            with alpha 0.95 instead of 1.0 and the tree is drawn afterwards
            with `draw_parse_tree_meshcat`, colored per node class.
    """
    show_tree = parse_tree is not None
    scene_alpha = 0.95 if show_tree else 1.0
    dataset_utils.DrawYamlEnvironment(
        yaml_env, base_environment_type="mug_shelf", alpha=scene_alpha)
    if not show_tree:
        return

    # RGB triple used for each node class in the overlay.
    colors_by_node_class = {
        "MugShelf": [0., 0., 1.],
        "MugShelfLevel": [0., 1., 0.],
        "MugIntermediate": [1., 0.5, 0.5],
        "Mug": [1., 0., 0.],
    }
    draw_parse_tree_meshcat(
        parse_tree, node_class_to_color_dict=colors_by_node_class, alpha=1.0)

In [9]:
# Preview the test scene at index 1, without a parse-tree overlay.
draw_environment_for_video(
    test_dataset[1],
)

Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6000...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/
Connected to meshcat-server.


In [7]:
# Step through every scene in the test set, pausing on each one so it can be
# looked at (or screen-captured).  Nothing is written to disk by this cell.
for scene_index, scene_env in enumerate(test_dataset):
    print("Showing %d" % (scene_index,))
    draw_environment_for_video(scene_env)
    time.sleep(2.)

Showing 0
Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6000...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/
Connected to meshcat-server.
Showing 1
Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6000...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/
Connected to meshcat-server.
Showing 2
Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6000...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/
Connected to meshcat-server.
Showing 3
Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6000...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/
Connected to meshcat-server.
Showing 4
Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6000...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/
Connected to meshcat-server.
Showing 5
Connecting to meshcat-server a

In [13]:
# Generate N example environments with pre and post projection. Format into
# list and save.
#
# Sampling or projection can raise, in which case the sample is redrawn.
# Only `Exception` is caught (so interrupting the kernel still works), every
# failure is reported on stderr, and the number of retries per sample is
# bounded so a systematic error cannot spin silently forever.
N = 100
MAX_ATTEMPTS_PER_SAMPLE = 100
all_generated_run_info = []
for k in range(N):
    print(k)
    for attempt in range(MAX_ATTEMPTS_PER_SAMPLE):
        try:
            parse_tree = generate_unconditioned_parse_tree(root_node=root_node, initial_gvs=guide_gvs)
            pre_projection_yaml_env = convert_tree_to_yaml_env(parse_tree)
            # Only the last element of the projection result is kept.
            post_projection_yaml_env = ProjectEnvironmentToFeasibility(
                pre_projection_yaml_env, base_environment_type="mug_shelf",
                make_nonpenetrating=True, make_static=True)[-1]
        except Exception as e:
            sys.stderr.write("sample %d, attempt %d failed: %r\n" % (k, attempt, e))
            continue
        all_generated_run_info.append({
            "parse_tree": parse_tree,
            "pre_projection_yaml_env": pre_projection_yaml_env,
            "post_projection_yaml_env": post_projection_yaml_env})
        break
    else:
        raise RuntimeError(
            "Could not generate sample %d after %d attempts"
            % (k, MAX_ATTEMPTS_PER_SAMPLE))
torch.save(all_generated_run_info, "all_generated_run_infos_trained.run_info")

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [7]:
# Load the saved generated samples and preview the first one: the
# pre-projection scene with its parse tree overlaid.  (Pass
# run_info["post_projection_yaml_env"] without a parse tree instead to look
# at the projected scene.)
# NOTE: torch.load unpickles the file's contents, so only load files that
# were produced by a trusted run of the generation cell.
all_generated_run_info = torch.load("all_generated_run_infos_trained.run_info")
for i, run_info in enumerate(all_generated_run_info[:1]):
    print("Drawing %d" % i)
    draw_environment_for_video(run_info["pre_projection_yaml_env"], parse_tree=run_info["parse_tree"])
    time.sleep(1.)


Drawing 0
Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6000...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/
Connected to meshcat-server.
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/


    Please see `help(pydrake.common.deprecation)` for more information.
  mbp.AddForceElement(UniformGravityFieldElement())
    Please see `help(pydrake.common.deprecation)` for more information.
  visualizer._DoPublish(mbp_context, [])


<Figure size 1500x1500 with 0 Axes>

In [None]:
# Sample one environment from the loaded model, project it to feasibility,
# render it with Blender from a fixed camera pose, and show it in meshcat
# with its parse tree overlaid.
parse_tree = generate_unconditioned_parse_tree(root_node=root_node, initial_gvs=guide_gvs)
yaml_env = convert_tree_to_yaml_env(parse_tree)
# Only the last element of the projection result is kept.
yaml_env = ProjectEnvironmentToFeasibility(yaml_env, base_environment_type="mug_shelf",
                                           make_nonpenetrating=True, make_static=True)[-1]

# Camera pose: roll/pitch/yaw given in degrees and converted to radians.
cam_quat_base = RollPitchYaw(
    68.*np.pi/180.,
    0.*np.pi/180,
    80.*np.pi/180.).ToQuaternion()
cam_trans_base = np.array([0.67, -0.11, 0.5])
cam_tf_base = Isometry3(quaternion=cam_quat_base,
                        translation=cam_trans_base)
im = dataset_utils.DrawYamlEnvironmentWithBlender(yaml_env, base_environment_type="mug_shelf", cam_isom=cam_tf_base)

# Same drawing as the helper performs when a parse tree is supplied: scene at
# alpha 0.95, then the tree colored per node class at alpha 1.0.
draw_environment_for_video(yaml_env, parse_tree=parse_tree)


In [None]:
# Guess a parse tree for every scene in the test set, using a detached copy
# of the global variable store.
all_test_parse_trees = guess_parse_trees_batch_async(
    test_dataset,
    root_node_type=root_node_type,
    guide_gvs=guide_gvs.detach(),
)

In [None]:
# Load the adversarial (outlier) scenes, guess a parse tree for each, and
# collect each tree's joint log-probability as a plain Python number.
outlier_dataset = dataset_utils.ScenesDataset(
    "../../data/mug_shelf/mug_rack_environments_human_adversarial.yaml")
outlier_dataset_parse_trees = guess_parse_trees_batch_async(
    outlier_dataset,
    root_node_type=root_node_type,
    guide_gvs=guide_gvs.detach(),
)
# Element [0] of get_total_log_prob() is the total score.
outlier_dataset_parse_trees_scores = [
    tree.get_total_log_prob()[0].item()
    for tree in outlier_dataset_parse_trees
]
print(outlier_dataset_parse_trees_scores)

In [5]:
# Inspect one outlier scene: report its lowest-scoring node and draw the
# scene with the parse tree colored by score.
k = 1
# Element [1] of get_total_log_prob() maps each node to its score.
scores_by_node = outlier_dataset_parse_trees[k].get_total_log_prob()[1]
# min() returns the first key with the smallest score, matching a manual
# scan that only replaces the current best on a strictly smaller value.
min_key = min(scores_by_node.keys(), key=lambda node: scores_by_node[node])
print(min_key, scores_by_node[min_key])
print(scores_by_node)
dataset_utils.DrawYamlEnvironment(outlier_dataset[k], base_environment_type="mug_shelf", alpha=0.75)
draw_parse_tree_meshcat(outlier_dataset_parse_trees[k],color_by_score=True)

NameError: name 'outlier_dataset_parse_trees' is not defined

In [None]:
# Score every test-set parse tree by its joint log-probability, sort scenes,
# trees and scores from lowest to highest score, and plot a histogram.
all_test_parse_trees_scores = []
for parse_tree in all_test_parse_trees:
    # Element [0] of get_total_log_prob() is the total score.  The
    # latents-only score was computed here before but never used, so that
    # second evaluation is no longer performed.
    joint_score = parse_tree.get_total_log_prob()[0]
    all_test_parse_trees_scores.append((joint_score).item())
# np.argsort is ascending, so index 0 of each sorted list is the
# lowest-scoring scene.
order = np.argsort(all_test_parse_trees_scores)
test_yaml_envs_in_sorted_order = [test_dataset[x] for x in order]
all_test_parse_trees_in_sorted_order = [all_test_parse_trees[x] for x in order]
all_test_parse_trees_scores_in_sorted_order = [all_test_parse_trees_scores[x] for x in order]
# Scores outside the [0, 150] bin range are not counted by the histogram.
plt.hist(all_test_parse_trees_scores, density=True, bins=np.linspace(0., 150., 10))
plt.title("Histogram of ELBO over Test Set")
plt.xlabel("ELBO")
plt.ylabel("Relative occurrence")

In [6]:
# Draw the first scene of the score-sorted test set, then its parse tree
# colored by score.
dataset_utils.DrawYamlEnvironment(
    test_yaml_envs_in_sorted_order[0],
    base_environment_type="mug_shelf",
    alpha=1.0,
)
draw_parse_tree_meshcat(
    all_test_parse_trees_in_sorted_order[0],
    color_by_score=True,
)

NameError: name 'test_yaml_envs_in_sorted_order' is not defined

In [None]:
import scipy
import os
from scipy.ndimage import filters
import matplotlib.pyplot as plt

def load_data_from_dir(path):
    """Load one curve from every CSV file in a directory.

    Each file is read with `np.loadtxt`, skipping one header row and
    splitting on commas.  Column 1 is taken as the x-axis index and column 2
    as the curve value.

    Args:
        path: Directory that contains one CSV file per run.

    Returns:
        A tuple `(inds, values)`:
          * `inds`: column 1 of the last file read (files are read in sorted
            name order).
          * `values`: for more than one file, the curves stacked with
            `np.vstack` (one row per run); for exactly one file, a
            single-column array (one row per sample).

    Raises:
        ValueError: if `path` contains no files.
    """
    # Sort so the result does not depend on the arbitrary order of listdir.
    runs = sorted(os.listdir(path))
    if not runs:
        raise ValueError("No run files found in %r" % (path,))
    full_runs = []
    data = None
    for run in runs:
        # ndmin=2 keeps a file with a single data row two-dimensional, so the
        # column indexing below works for it too.
        data = np.loadtxt(os.path.join(path, run), skiprows=1, delimiter=",",
                          ndmin=2)
        full_runs.append(data[:, 2])
    inds = data[:, 1]
    # NOTE(review): the two branches return different orientations
    # (runs x samples vs. samples x 1).  Kept as-is for compatibility, but
    # confirm which one callers that filter/plot along axis 0 expect.
    if len(runs) > 1:
        return inds, np.vstack(full_runs)
    else:
        return inds, np.array(full_runs).T

def apply_filtering(data, sigma):
    """Smooth `data` along axis 0 with a 1-D Gaussian filter.

    Uses `gaussian_filter1d` from the public `scipy.ndimage` namespace
    rather than the deprecated `scipy.ndimage.filters` one.

    Args:
        data: Array to smooth; filtering runs along its first axis.
        sigma: Standard deviation of the Gaussian kernel, in samples.

    Returns:
        The smoothed array, as returned by `gaussian_filter1d`.  Edges are
        handled with mode 'nearest'.
    """
    return gaussian_filter1d(data, sigma, axis=0, mode='nearest')

def make_train_curve_plot(ax):
    """Plot raw and smoothed train/test score curves on `ax`.

    Curves are read from the "train_joint_score" and "test_joint_score"
    directories with `load_data_from_dir`.  Raw curves are drawn faintly,
    and Gaussian-smoothed curves (sigma 20 for train, 3 for test) are drawn
    on top and labelled.  Test curves are dashed.

    Both curves are plotted against the index column stored in their files,
    so they share one x-axis, and the legend is attached to `ax` itself
    rather than to whichever axes pyplot currently considers active.

    Args:
        ax: Matplotlib axes to draw on.
    """
    inds_train, full_runs_train = load_data_from_dir("train_joint_score")
    print(inds_train.shape, full_runs_train.shape)
    full_runs_train_smoothed = apply_filtering(full_runs_train, sigma=20.)

    ax.plot(inds_train, full_runs_train, alpha=0.2, color='darkblue')
    ax.plot(inds_train, full_runs_train_smoothed, color='darkblue', alpha=0.8, label="Train")

    inds_test, full_runs_test = load_data_from_dir("test_joint_score")
    full_runs_test_smoothed = apply_filtering(full_runs_test, sigma=3.)
    ax.plot(inds_test, full_runs_test, alpha=0.2, color='darkblue', linestyle="--")
    ax.plot(inds_test, full_runs_test_smoothed, color='darkblue', linestyle="--", alpha=1.0, label="Test")

    ax.grid(True)
    ax.legend()
    ax.set_ylabel("Mean log p(t)")
    ax.set_xlabel("Epoch")
# Wide, short, high-resolution figure for the training-curve plot.
curve_fig = plt.figure(figsize=(6, 1), dpi=300)
make_train_curve_plot(curve_fig.gca())