In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from IPython.display import clear_output
from pydrake.all import (
    AbstractValue,
    AddMultibodyPlantSceneGraph,
    Concatenate,
    DiagramBuilder,
    JointSliders,
    LeafSystem,
    MeshcatPoseSliders,
    MeshcatVisualizer,
    MeshcatVisualizerParams,
    Parser,
    PointCloud,
    RandomGenerator,
    Rgba,
    RigidTransform,
    RotationMatrix,
    Simulator,
    StartMeshcat,
    UniformlyRandomRotationMatrix,
)

from manipulation import running_as_notebook
from manipulation.scenarios import AddFloatingRpyJoint, AddRgbdSensors, ycb
from manipulation.utils import ConfigureParser
import sponana.utils

In [16]:
from sponana.grasping.grasp_generator import get_unified_point_cloud, BananaSystem, ScoreSystem, GenerateAntipodalGraspCandidate

In [4]:
# Start the Meshcat visualizer. The functions below draw into this
# module-level `meshcat` instance; the URL it listens on is printed in the
# cell output.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7006


In [11]:
def make_internal_model():
    """Build a standalone diagram holding the banana-and-gripper scene.

    The returned diagram owns its own MultibodyPlant and SceneGraph (1 ms
    discrete time step). It is used as the "internal model" handed to
    ``ScoreSystem`` and ``GenerateAntipodalGraspCandidate``. Unlike the
    interactive plants built later in this notebook, no floating joint is
    added to the gripper here.

    Returns:
        The built ``Diagram`` containing the finalized plant and its scene
        graph.
    """
    diagram_builder = DiagramBuilder()
    internal_plant, _ = AddMultibodyPlantSceneGraph(diagram_builder, time_step=0.001)
    model_parser = Parser(internal_plant)
    # Presumably registers the `sponana` package so the package:// URL below
    # resolves.
    sponana.utils.configure_parser(model_parser)
    model_parser.AddModelsFromUrl(
        "package://sponana/grasping/banana_and_gripper.dmd.yaml"
    )
    internal_plant.Finalize()
    return diagram_builder.Build()


In [15]:
def grasp_score_inspector():
    """Interactively inspect the grasp score of a hand-posed gripper.

    Builds a diagram containing the banana-and-gripper scene, with a floating
    roll-pitch-yaw joint attached to the gripper's ``body`` frame. It draws
    the banana point cloud in Meshcat and wires the plant's body poses into a
    ``ScoreSystem`` evaluated against a separate internal model. Meshcat joint
    sliders, also bound to the keyboard, then drive the gripper pose.

    ``JointSliders.Run`` is called with no timeout when running as a notebook
    and with a 1.0 timeout otherwise.

    Uses the module-level ``meshcat`` instance. Its scene is cleared on entry.
    The slider controls are removed on exit, even if the slider loop raises
    (for example on a kernel interrupt).
    """
    meshcat.Delete()

    # Build the diagram that drives the interactive visualization.
    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.001)
    parser = Parser(plant)
    sponana.utils.configure_parser(parser)
    parser.AddModelsFromUrl("package://sponana/grasping/banana_and_gripper.dmd.yaml")

    # Give the gripper a floating RPY joint so the sliders can move it.
    AddFloatingRpyJoint(
        plant,
        plant.GetFrameByName("body"),
        plant.GetModelInstanceByName("gripper"),
    )
    plant.Finalize()

    # Remove any controls left over from a previous run before adding sliders.
    meshcat.DeleteAddedControls()
    params = MeshcatVisualizerParams()
    params.prefix = "planning"
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat, params)

    # Point cloud of the banana, obtained from a separate environment system.
    environment = BananaSystem()
    environment_context = environment.CreateDefaultContext()
    cloud = get_unified_point_cloud(
        environment,
        environment_context,
        meshcat=meshcat
    )
    meshcat.SetObject("planning/cloud", cloud, point_size=0.003)

    # Score the gripper pose reported by this plant against the internal model.
    internal_model = make_internal_model()
    score = builder.AddSystem(
        ScoreSystem(internal_model, cloud, plant.GetBodyByName("body").index(), meshcat=meshcat)
    )
    builder.Connect(plant.get_body_poses_output_port(), score.get_input_port())

    # Six slider values for the floating joint. Presumably ordered
    # [x, y, z, roll, pitch, yaw]; TODO confirm against AddFloatingRpyJoint.
    lower_limit = [-1, -1, 0, -np.pi, -np.pi / 4.0, -np.pi / 4.0]
    upper_limit = [1, 1, 1, 0, np.pi / 4.0, np.pi / 4.0]
    q0 = [-0.05, -0.5, 0.25, -np.pi / 2.0, 0, 0]
    # No timeout interactively; a short timeout keeps non-notebook runs finite.
    default_interactive_timeout = None if running_as_notebook else 1.0
    sliders = builder.AddSystem(
        JointSliders(
            meshcat,
            plant,
            initial_value=q0,
            lower_limit=lower_limit,
            upper_limit=upper_limit,
            decrement_keycodes=[
                "KeyQ",
                "KeyS",
                "KeyA",
                "KeyJ",
                "KeyK",
                "KeyU",
            ],
            increment_keycodes=[
                "KeyE",
                "KeyW",
                "KeyD",
                "KeyL",
                "KeyI",
                "KeyO",
            ],
        )
    )
    diagram = builder.Build()
    try:
        sliders.Run(diagram, default_interactive_timeout)
    finally:
        # Always remove the sliders, including when Run is interrupted.
        meshcat.DeleteAddedControls()


grasp_score_inspector()

cost: -14.804734118947348
normal terms: [0.02230194 0.05648695 0.0973478  0.06090133 0.10572309 0.17236773
 0.1877837  0.02573566]


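# Generate and rank grasp candidates

Sample antipodal grasp candidates on the banana point cloud, discard those with a non-finite cost, and draw the five lowest-cost candidates in Meshcat.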

In [17]:
def draw_grasp_candidate(X_G, meshcat, prefix="gripper", draw_frames=True):
    """Draw a single Schunk WSG 50 gripper at pose ``X_G`` in Meshcat.

    Builds a throwaway diagram containing only the welded-finger WSG 50 model,
    with its ``body`` frame welded to the world at ``X_G``. A single forced
    publish pushes the geometry to Meshcat.

    Args:
        X_G: Pose at which the gripper ``body`` frame is welded to the world.
        meshcat: Meshcat instance to draw into.
        prefix: Meshcat path under which this gripper is drawn. Existing
            content under the prefix is not deleted on initialization.
        draw_frames: NOTE(review): currently ignored. No frame triads are
            drawn regardless of its value; it is kept for interface
            compatibility.
    """
    gripper_builder = DiagramBuilder()
    gripper_plant, gripper_scene_graph = AddMultibodyPlantSceneGraph(
        gripper_builder, time_step=0.001
    )
    gripper_parser = Parser(gripper_plant)
    # The WSG model lives in the `manipulation` package, hence this parser config.
    ConfigureParser(gripper_parser)
    gripper_parser.AddModelsFromUrl(
        "package://manipulation/schunk_wsg_50_welded_fingers.sdf"
    )
    gripper_plant.WeldFrames(
        gripper_plant.world_frame(), gripper_plant.GetFrameByName("body"), X_G
    )
    gripper_plant.Finalize()

    vis_params = MeshcatVisualizerParams()
    vis_params.prefix = prefix
    vis_params.delete_prefix_on_initialization_event = False
    MeshcatVisualizer.AddToBuilder(
        gripper_builder, gripper_scene_graph, meshcat, vis_params
    )

    gripper_diagram = gripper_builder.Build()
    gripper_diagram.ForcedPublish(gripper_diagram.CreateDefaultContext())

In [23]:
def sample_grasps_example(seed=None, num_candidates=None, num_to_draw=5):
    """Sample antipodal grasps on the banana and draw the lowest-cost ones.

    Steps:
      1. Publish the banana-and-gripper planning scene to Meshcat, with the
         planning gripper hidden.
      2. Draw the banana point cloud.
      3. Sample candidates with ``GenerateAntipodalGraspCandidate`` against a
         separate internal model.
      4. Draw the best candidates as standalone grippers under the prefixes
         ``"0th best"``, ``"1th best"``, and so on.

    Candidates whose cost is not finite are discarded before ranking.

    Args:
        seed: Seed for ``np.random.default_rng``. The default ``None`` gives
            different candidates on every run; pass an int for a reproducible
            set.
        num_candidates: Number of candidates to sample. The default ``None``
            means 100 when running as a notebook and 2 otherwise.
        num_to_draw: How many of the lowest-cost candidates to draw.
    """
    meshcat.Delete()
    rng = np.random.default_rng(seed)
    if num_candidates is None:
        num_candidates = 100 if running_as_notebook else 2

    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.001)
    parser = Parser(plant)
    sponana.utils.configure_parser(parser)
    parser.AddModelsFromUrl("package://sponana/grasping/banana_and_gripper.dmd.yaml")

    AddFloatingRpyJoint(
        plant,
        plant.GetFrameByName("body"),
        plant.GetModelInstanceByName("gripper"),
    )
    plant.Finalize()

    params = MeshcatVisualizerParams()
    params.prefix = "planning"
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat, params)
    diagram = builder.Build()
    context = diagram.CreateDefaultContext()
    diagram.ForcedPublish(context)

    # Hide the planning gripper; only the drawn candidates should be visible.
    meshcat.SetProperty("planning/gripper", "visible", False)

    environment = BananaSystem()
    environment_context = environment.CreateDefaultContext()
    cloud = get_unified_point_cloud(
        environment,
        environment_context,
        meshcat=meshcat
    )
    meshcat.SetObject("planning/cloud", cloud, point_size=0.003)

    internal_model = make_internal_model()
    internal_model_context = internal_model.CreateDefaultContext()
    costs = []
    X_Gs = []
    for _ in range(num_candidates):
        cost, X_G = GenerateAntipodalGraspCandidate(
            internal_model, internal_model_context, cloud, rng
        )
        # NOTE(review): a non-finite cost presumably marks a rejected
        # candidate; confirm in GenerateAntipodalGraspCandidate.
        if np.isfinite(cost):
            costs.append(cost)
            X_Gs.append(X_G)

    if not costs:
        print("No grasp candidates with a finite cost were found.")
        return

    # Indices of the lowest-cost candidates, best first.
    best_indices = np.argsort(costs)[:num_to_draw]
    for rank, index in enumerate(best_indices):
        draw_grasp_candidate(
            X_Gs[index], meshcat, prefix=f"{rank}th best", draw_frames=False
        )


sample_grasps_example()