In [16]:
from enum import Enum
import numpy as np
from pydrake.all import (StartMeshcat, MeshcatVisualizer, DiagramBuilder, Parser, ConstantVectorSource,
                        Simulator, LeafSystem, RigidTransform, RotationMatrix, ProcessModelDirectives,
                        UniformlyRandomRotationMatrix, RandomGenerator, Rgba, LoadModelDirectivesFromString,
                        PiecewisePose, TrajectorySource, Integrator, JacobianWrtVariable, AbstractValue, 
                        PointCloud, Concatenate, AddMultibodyPlantSceneGraph, MeshcatVisualizerParams, Role,
                        RollPitchYaw)
# from manipulation.clutter import GenerateAntipodalGraspCandidate
from manipulation.meshcat_utils import AddMeshcatTriad
from manipulation import FindResource
from manipulation.scenarios import (AddIiwaDifferentialIK, MakeManipulationStation, AddPackagePaths)

# For diagram viz
import matplotlib.pyplot as plt
from IPython.display import SVG
import pydot

In [3]:
# Start the Meshcat server. `meshcat` is read as a module-level global by
# GraspSelector.SelectGrasp and IIWA.__init__ further down.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7000


In [67]:
# Seed for the randomized trash scenes. `None` keeps the previous behavior
# (fresh entropy every run); set an int to make scenes / datasets reproducible.
SEED = None
if SEED is not None:
    # RandomizeTrash samples positions through the legacy global np.random API.
    np.random.seed(SEED)
rs = np.random.RandomState(SEED)  # Python-side RNG; only used here to seed Drake's generator
generator = RandomGenerator(rs.randint(1000))  # Drake (C++) RNG, used by UniformlyRandomRotationMatrix

# Initial iiwa joint positions (7 joints); IIWA uses this as the integrator's initial value.
q0 = [-1.57, 0.1, 0, -1.2, 0, 1.6, 0]

# Model directives for the trash items (no placeholders, so a plain string).
trash = """
        directives:
        - add_model:
            name: can
            file: package://RoboticRecycling/models/can/cola_can.sdf
            
        - add_model:
            name: bottle
            file: package://RoboticRecycling/models/bottle/bottle.sdf
            
        - add_model:
            name: banana
            file: package://RoboticRecycling/models/banana/Banana.sdf
            
        - add_model:
            name: orange
            file: package://RoboticRecycling/models/orange/Orange.sdf
            
        - add_model:
            name: coffee
            file: package://RoboticRecycling/models/coffee/coffee.sdf
        """

# Body names of the trash items above; used to pick them out of the plant.
item_names = ["cola_can_body_link", "bottle_body_link", 
              "Banana_body_link", "Orange_body_link", "coffee_body_link"]


class PlannerState(Enum):
    """Modes of the Planner state machine; Planner.Update dispatches on these."""

    DEBUG = -1        # idle: Update() returns immediately
    WAIT = 0          # wait 2 s, then move on to SELECT_GRASP
    CHOOSE = 1        # reserved for a vision stage; not handled by Update yet
    SELECT_GRASP = 2  # read grasp_selection; plan a pick or fall back to WAIT
    PICK = 3          # follow the picking trajectory to the grasp pose
    PLACE = 4         # carry the item to the bin chosen by the GarbageType
    HOME = 5          # return to the home pose
    CLOSE = 6         # close the gripper, then start placing
    OPEN = 7          # open the gripper, then head HOME once it is open
    
    
class GarbageType(Enum):
    """Destination bin for a picked item; Planner.MakePlacingTraj maps each to a bin pose."""

    TRASH = 0    # -> Planner.X_WTrash
    RECYCLE = 1  # -> Planner.X_WRecycle
    ORGANIC = 2  # -> Planner.X_WOrganic

    
def make_internal_model():
    """Build the planner's internal ManipulationStation (iiwa, bins, table, floor).

    It holds no trash; GraspSelector uses it to score grasp candidates against the
    static scene. Package paths come from ./package.xml.
    """
    # NOTE(review): absolute, machine-specific path; presumably FindResource leaves
    # absolute paths unchanged. Consider a configurable models directory.
    internal_model_file = "/home/jared/Desktop/RoboticRecycling/models/internal_model.dmd.yaml"
    return MakeManipulationStation(
        filename=FindResource(internal_model_file),
        package_xmls=["./package.xml"])


def make_trash_model():
    """Build a standalone diagram containing only the trash items and the picnic table.

    The diagram holds a MultibodyPlant ("plant", 1 ms time step) and its SceneGraph
    ("scene_graph"). It is used only to test random trash poses for collisions.
    """
    # NOTE(review): `trash` indents its YAML entries by 8 spaces while this entry
    # is indented by 4; mixed indentation is fragile YAML. Confirm the table is
    # actually loaded (e.g. check plant.HasModelInstanceNamed("table")).
    table_directive = """
    - add_model:
        name: table
        file: package://RoboticRecycling/models/picnic_table/garden_table.sdf
    """
    builder = DiagramBuilder()
    plant, _scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.001)

    parser = Parser(plant)
    parser.package_map().AddPackageXml('./package.xml')
    AddPackagePaths(parser)
    ProcessModelDirectives(LoadModelDirectivesFromString(trash + table_directive), parser)

    plant.Finalize()
    return builder.Build()
    

class PseudoInverseController(LeafSystem):
    """Turns a desired gripper spatial velocity into iiwa joint velocities.

    Input ports:
        V_WG: 6-vector desired spatial velocity of the body named "body",
            measured and expressed in the world frame.
        iiwa_position: 7-vector of iiwa joint positions.
    Output port:
        iiwa_velocity: 7-vector v = pinv(J_G) @ V_WG, where J_G is the iiwa
            columns of the body's spatial-velocity Jacobian (w.r.t. v).
    """

    def __init__(self, plant):
        LeafSystem.__init__(self)
        self._plant = plant
        # Private scratch context; only the iiwa positions are written into it.
        self._plant_context = plant.CreateDefaultContext()
        self._iiwa = plant.GetModelInstanceByName("iiwa")
        self._G = plant.GetBodyByName("body").body_frame()
        self._W = plant.world_frame()

        self.V_G_port = self.DeclareVectorInputPort("V_WG", 6)
        self.q_port = self.DeclareVectorInputPort("iiwa_position", 7)
        self.DeclareVectorOutputPort("iiwa_velocity", 7, self.CalcOutput)

        # First and last velocity indices of the iiwa joints; the Jacobian columns
        # in between are assumed to be contiguous.
        self.iiwa_start = plant.GetJointByName("iiwa_joint_1").velocity_start()
        self.iiwa_end = plant.GetJointByName("iiwa_joint_7").velocity_start()

    def CalcOutput(self, context, output):
        V_WG_desired = self.V_G_port.Eval(context)
        q_iiwa = self.q_port.Eval(context)
        self._plant.SetPositions(self._plant_context, self._iiwa, q_iiwa)

        J_full = self._plant.CalcJacobianSpatialVelocity(
            self._plant_context, JacobianWrtVariable.kV,
            self._G, [0, 0, 0], self._W, self._W)
        iiwa_columns = slice(self.iiwa_start, self.iiwa_end + 1)
        J_iiwa = J_full[:, iiwa_columns]

        output.SetFromVector(np.linalg.pinv(J_iiwa) @ V_WG_desired)


class Planner(LeafSystem):
    """Finite-state pick-and-place planner for the iiwa and WSG gripper.

    Input ports:
        body_poses (port 0): list of RigidTransform indexed by body index.
        grasp_selection: ``(cost, X_WG)`` tuple; ``cost == np.inf`` means no grasp.
        wsg_state: 2-vector from the gripper. Element 0 is compared against an
            "open" threshold (presumably the measured opening — confirm).
        iiwa_position: 7-vector; declared but not read by this system.
    Output ports:
        wsg_position: 1-vector commanded gripper opening (``self.wsg_des``).
        V_WG: 6-vector spatial velocity of the current gripper trajectory.

    The mode, timers and trajectory are plain Python attributes (not Drake state).
    A periodic unrestricted update advances them every 0.1 s. The planner starts
    in ``PlannerState.DEBUG``, in which ``Update`` does nothing.
    """

    def __init__(self, plant):
        LeafSystem.__init__(self)

        self._gripper_body_index = plant.GetBodyByName("body").index()
        self.DeclareAbstractInputPort("body_poses", AbstractValue.Make([RigidTransform()]))
        self._grasp_index = self.DeclareAbstractInputPort("grasp_selection", 
            AbstractValue.Make((np.inf, RigidTransform()))).get_index()
        self._wsg_state_index = self.DeclareVectorInputPort("wsg_state", 2).get_index()
        self.DeclareVectorInputPort("iiwa_position", 7)

        # Intermediate gripper pose used on the way to the bins and after placing.
        self.X_WHome = RigidTransform(
            RotationMatrix([
                [1, 0, 0],
                [0, 0, 1],
                [0, -1, 0]
            ]),
            [0, -0.45, 0.65])
        # Drop-off poses above the three bins share one orientation; only y differs.
        R_WBin = RotationMatrix([
            [0, 0, -1],
            [1, 0, 0],
            [0, -1, 0]
        ])
        self.X_WRecycle = RigidTransform(R_WBin, [0.475, -0.2, 0.725])
        self.X_WTrash = RigidTransform(R_WBin, [0.475, 0, 0.725])
        self.X_WOrganic = RigidTransform(R_WBin, [0.475, 0.2, 0.725])

        self.eps = 1e-4  # tolerance for "gripper is at the trajectory's final pose"
        self.init = True
        self.prev_time = None  # start time of the current timed wait (WAIT / CLOSE)
        self.mode = PlannerState.DEBUG
        self.wsg_des = np.array([0.107])  # commanded gripper opening; starts open
        # Placeholder (identity-pose, zero-velocity) trajectory until a plan exists.
        self.traj_WG = PiecewisePose.MakeLinear([0, np.inf], [RigidTransform(), RigidTransform()])
        self.garbage_type = GarbageType.RECYCLE

        self.DeclareVectorOutputPort("wsg_position", 1,
            lambda context, output: output.SetFromVector(self.wsg_des))
        self.DeclareVectorOutputPort("V_WG", 6, self.SendTrajV)

        self.DeclarePeriodicUnrestrictedUpdateEvent(0.1, 0.0, self.Update)

    def Update(self, context, state):
        """Advance the state machine by one tick (periodic, every 0.1 s).

        The mode checks below run in sequence, so one tick can pass through
        several modes. ``state`` is unused; all planner state lives on ``self``.
        """
        if self.mode == PlannerState.DEBUG:
            return

        now = context.get_time()
        if self.init:
            # First active tick: start the WAIT timer.
            self.prev_time = now
            self.init = False
            return

        if self.mode == PlannerState.WAIT:
            if now - self.prev_time > 2.:
                # TODO: go to PlannerState.CHOOSE (vision stage) once it exists.
                self.mode = PlannerState.SELECT_GRASP

        if self.mode == PlannerState.SELECT_GRASP:
            print("Selecting grasp")
            [cost, X_WD] = self.get_input_port(self._grasp_index).Eval(context)
            print("Cost: ", cost, "\n", "TF: ", X_WD)
            if cost == np.inf:
                # No viable grasp: wait, then retry.
                self.mode = PlannerState.WAIT
                self.prev_time = now
            else:
                self.traj_WG = self.MakePickingTraj(context, X_WD)
                self.mode = PlannerState.PICK

        if self.mode == PlannerState.PICK:
            if self.TrajDone(context):
                self.mode = PlannerState.CLOSE
                self.wsg_des = np.array([0.0])
                # Start the CLOSE timer now. Without this reset, prev_time still
                # held the first-tick / last-WAIT time, which is always more than
                # 2 s earlier because the picking trajectory alone lasts 4 s. The
                # CLOSE check below then fired on this same tick, and placing
                # started without waiting for the gripper.
                self.prev_time = now
            # TODO: handle trajectory failure cases.

        if self.mode == PlannerState.CLOSE:
            # TODO: detect closure from wsg_state instead of a fixed 2 s wait.
            if now - self.prev_time > 2.:
                self.traj_WG = self.MakePlacingTraj(context, self.garbage_type)
                self.mode = PlannerState.PLACE

        if self.mode == PlannerState.PLACE:
            if self.TrajDone(context):
                self.mode = PlannerState.OPEN
                self.wsg_des = np.array([0.107])
            # TODO: handle trajectory failure cases.

        if self.mode == PlannerState.OPEN:
            wsg_state = self.get_input_port(self._wsg_state_index).Eval(context)
            if wsg_state[0] > 0.09:
                self.traj_WG = self.MakeHomeTraj(context)
                self.mode = PlannerState.HOME

        if self.mode == PlannerState.HOME:
            if self.TrajDone(context):
                # TODO: loop back to PlannerState.CHOOSE once vision exists.
                print("Done")
                return

    def _GripperPose(self, context):
        """Current gripper pose X_WG, read from the body_poses input (port 0)."""
        return self.get_input_port(0).Eval(context)[int(self._gripper_body_index)]

    def TrajDone(self, context):
        """Return True once the gripper is within ``self.eps`` of the trajectory's end pose.

        The end pose is the trajectory evaluated at t = inf. If the trajectory's
        time has run out and the gripper has not reached that pose, the trajectory
        is replaced by a 0.25 s linear correction toward it, and False is returned.
        """
        time = context.get_time()
        X_WG = self._GripperPose(context)
        X_WEnd = self.traj_WG.GetPose(np.inf)
        if X_WG.IsNearlyEqualTo(X_WEnd, self.eps):
            return True
        if self.traj_WG.end_time() < time:
            self.traj_WG = PiecewisePose.MakeLinear([time, time + 0.25], [X_WG, X_WEnd])
        return False

    def MakePickingTraj(self, context, X_WD):
        """Linear trajectory: current pose -> 0.1 m above X_WD (2 s) -> X_WD (2 s)."""
        time = context.get_time()
        X_WG = self._GripperPose(context)
        # Pre-grasp pose: the grasp pose shifted 0.1 m up along world z.
        X_WPrepick = RigidTransform(RotationMatrix(), [0, 0, 0.1]) @ X_WD
        return PiecewisePose.MakeLinear([time, time + 2, time + 4], [X_WG, X_WPrepick, X_WD])

    def MakePlacingTraj(self, context, garbage_type):
        """Linear trajectory: current pose -> home (2 s) -> bin for ``garbage_type`` (4 s).

        Raises:
            ValueError: if ``garbage_type`` is not a known GarbageType.
        """
        bin_poses = {
            GarbageType.TRASH: self.X_WTrash,
            GarbageType.RECYCLE: self.X_WRecycle,
            GarbageType.ORGANIC: self.X_WOrganic,
        }
        if garbage_type not in bin_poses:
            raise ValueError(f"Unknown garbage type: {garbage_type!r}")
        time = context.get_time()
        X_WG = self._GripperPose(context)
        return PiecewisePose.MakeLinear(
            [time, time + 2, time + 6], [X_WG, self.X_WHome, bin_poses[garbage_type]])

    def MakeHomeTraj(self, context):
        """Linear 4 s trajectory from the current gripper pose to ``self.X_WHome``."""
        time = context.get_time()
        X_WG = self._GripperPose(context)
        return PiecewisePose.MakeLinear([time, time + 4], [X_WG, self.X_WHome])

    def SendTrajV(self, context, output):
        """Output callback for V_WG: the current trajectory's velocity at context time."""
        time = context.get_time()
        V_WG = self.traj_WG.GetVelocity(time)
        output.SetFromVector(V_WG)
    

# Taken from manipulation/clutter.py - slightly modified
def GraspCandidateCost(diagram, context, cloud, wsg_body_index=None, plant_system_name="plant",
                       scene_graph_system_name="scene_graph", adjust_X_G=False):
    """Score the gripper's current free-body pose in ``diagram`` as a grasp on ``cloud``.

    Lower is better. Returns ``np.inf`` in two cases:
    - the scene graph reports any collision;
    - any cloud point is within the signed-distance query threshold (0.0) of scene geometry.
    Otherwise returns minus the sum of squared gripper-x components of the normals
    of the cloud points inside the finger box.

    Args:
        diagram: diagram holding the plant and scene graph named below.
        context: root context of ``diagram``.
        cloud: point cloud with normals.
        wsg_body_index: gripper body index; if falsy, the body named "body" is used.
        plant_system_name, scene_graph_system_name: subsystem names in ``diagram``.
        adjust_X_G: if True and some points lie in the finger box, first shift the
            gripper along its x axis to center it on them. The shifted pose is
            written back into the plant context.
    """
    plant = diagram.GetSubsystemByName(plant_system_name)
    plant_context = plant.GetMyMutableContextFromRoot(context)
    scene_graph = diagram.GetSubsystemByName(scene_graph_system_name)
    scene_graph_context = scene_graph.GetMyMutableContextFromRoot(context)
    wsg = plant.get_body(wsg_body_index) if wsg_body_index else plant.GetBodyByName("body")

    X_G = plant.GetFreeBodyPose(plant_context, wsg)
    X_GW = X_G.inverse()
    p_GC = X_GW @ cloud.xyzs()  # cloud points expressed in the gripper frame

    # Axis-aligned box (gripper frame) between the fingers.
    finger_box_lower = np.array([-.05, 0.1, -0.00625])[:, np.newaxis]
    finger_box_upper = np.array([.05, 0.1125, 0.00625])[:, np.newaxis]
    between_fingers = np.all((finger_box_lower <= p_GC) & (p_GC <= finger_box_upper), axis=0)

    if adjust_X_G and np.sum(between_fingers) > 0:
        x_between = p_GC[0, between_fingers]
        x_center = (x_between.min() + x_between.max()) / 2.0
        X_G.set_translation(X_G @ np.array([x_center, 0, 0]))
        plant.SetFreeBodyPose(plant_context, wsg, X_G)
        X_GW = X_G.inverse()

    query_object = scene_graph.get_query_output_port().Eval(scene_graph_context)

    # Any collision among the scene geometry (e.g. gripper vs. bins/table) rejects the pose.
    if query_object.HasCollisions():
        return np.inf

    # Reject if any cloud point touches geometry. `margin` must be smaller than
    # the margin used in the point cloud preprocessing.
    margin = 0.0
    touches_cloud = any(
        query_object.ComputeSignedDistanceToPoint(cloud.xyz(i), threshold=margin)
        for i in range(cloud.size()))
    if touches_cloud:
        return np.inf

    # Reward sum |dot product of normals with gripper x|^2
    n_GC = X_GW.rotation().multiply(cloud.normals()[:, between_fingers])
    return -np.sum(n_GC[0, :]**2)


def GenerateAntipodalGraspCandidate(diagram, context, cloud, rng, wsg_body_index=None,
                                    plant_system_name="plant", scene_graph_system_name="scene_graph"):
    """Sample one grasp candidate from ``cloud`` and score it.

    How the candidate is built:
    - Pick a random cloud point S.
    - Orient the gripper with its x axis along the horizontal (x/y) projection of
      S's normal and its y axis along world -z.
    - Place the gripper so that S sits at ``p_GS_G`` in the gripper frame.
    - Try five vertical offsets in [-0.02, 0.02] m in order; the first with a
      finite ``GraspCandidateCost`` wins.

    Args:
        diagram: diagram with the plant and scene graph used for scoring.
        context: root context of ``diagram``; the gripper's free-body pose in it
            is overwritten.
        cloud: point cloud with unit normals.
        rng: ``numpy.random.Generator`` used to pick the sample point.
        wsg_body_index: gripper body index; if falsy, the body named "body" is used.
        plant_system_name, scene_graph_system_name: subsystem names in ``diagram``;
            these, and the gripper index, are forwarded to ``GraspCandidateCost``.

    Returns:
        ``(cost, X_G)`` for the first finite-cost offset. Returns ``(np.inf, None)``
        if the cloud is empty, the sampled normal is (near-)vertical, or no
        offset gives a finite cost.
    """
    plant = diagram.GetSubsystemByName(plant_system_name)
    plant_context = plant.GetMyMutableContextFromRoot(context)
    if wsg_body_index:
        wsg = plant.get_body(wsg_body_index)
    else:
        wsg = plant.GetBodyByName("body")
        wsg_body_index = wsg.index()

    if cloud.size() < 1:
        return np.inf, None

    # Generator.integers excludes `high`. The previous `cloud.size() - 1` never
    # sampled the last point and raised ValueError for a one-point cloud.
    index = rng.integers(0, cloud.size())

    # Use S for sample point/frame.
    p_WS = cloud.xyz(index)
    n_WS = cloud.normal(index)

    assert np.isclose(np.linalg.norm(n_WS),
                      1.0), f"Normal has magnitude: {np.linalg.norm(n_WS)}"

    # Keep gripper y aligned with world -z; gripper x is the normal projected onto XY.
    Gx = np.array([n_WS[0], n_WS[1], 0.0])
    Gx_norm = np.linalg.norm(Gx)
    if Gx_norm < 1e-6:
        # (Near-)vertical normal, e.g. a point on top of an object. There is no
        # horizontal approach direction, and normalizing would divide by ~0.
        return np.inf, None
    Gx = Gx / Gx_norm
    Gy = np.array([0.0, 0.0, -1.0])  # world downward
    Gz = np.cross(Gx, Gy)
    R_WG = RotationMatrix(np.vstack((Gx, Gy, Gz)).T)
    p_GS_G = [0.054 - 0.01, 0.1, 0]  # desired position of the sample point in the gripper frame
    p_WG = p_WS - R_WG @ p_GS_G

    # Try vertical perturbations, lowest first.
    for z in np.linspace(-0.02, 0.02, num=5):
        X_G = RigidTransform(R_WG, p_WG + np.array([0, 0, z]))
        plant.SetFreeBodyPose(plant_context, wsg, X_G)
        cost = GraspCandidateCost(diagram, context, cloud,
                                  wsg_body_index=wsg_body_index,
                                  plant_system_name=plant_system_name,
                                  scene_graph_system_name=scene_graph_system_name,
                                  adjust_X_G=True)
        # GraspCandidateCost may re-center the gripper; read back the final pose.
        X_G = plant.GetFreeBodyPose(plant_context, wsg)
        if np.isfinite(cost):
            return cost, X_G

    return np.inf, None
    
class GraspSelector(LeafSystem):
    """Chooses a grasp from the two cameras' point clouds.

    Input ports:
        cloud0_W, cloud1_W: PointClouds from camera 0 and camera 1.
        body_poses: list of RigidTransform indexed by body index; used to locate
            the cameras when orienting normals.
    Output port:
        grasp_selection: ``(cost, X_G)`` of the lowest-cost candidate out of 25
            samples. If none is viable, ``(np.inf, <fixed fallback pose>)``.
    """

    def __init__(self, camera_body_indices):
        LeafSystem.__init__(self)
        cloud_model = AbstractValue.Make(PointCloud(0))
        for cloud_port_name in ("cloud0_W", "cloud1_W"):
            self.DeclareAbstractInputPort(cloud_port_name, cloud_model)
        self.DeclareAbstractInputPort(
            "body_poses", AbstractValue.Make([RigidTransform()]))

        grasp_port = self.DeclareAbstractOutputPort(
            "grasp_selection",
            lambda: AbstractValue.Make((np.inf, RigidTransform())),
            self.SelectGrasp)
        # Recompute on every Eval (presumably so each query draws fresh random samples).
        grasp_port.disable_caching_by_default()

        # Crop box applied to both clouds before normal estimation.
        self._crop_lower = np.array([-.5, -.75, .39])
        self._crop_upper = np.array([.5, -.45, .475])

        # Trash-free station used to evaluate candidates against the static scene.
        self._internal_model = make_internal_model()
        self._internal_model_context = self._internal_model.CreateDefaultContext()
        self._rng = np.random.default_rng()
        self._camera_body_indices = camera_body_indices

    def SelectGrasp(self, context, output):
        body_poses = self.get_input_port(2).Eval(context)

        cropped_clouds = []
        for camera in range(2):
            cropped = self.get_input_port(camera).Eval(context).Crop(
                self._crop_lower, self._crop_upper)
            cropped_clouds.append(cropped)
            cropped.EstimateNormals(radius=0.1, num_closest=30)
            # Orient each normal toward the camera that observed the point.
            X_WC = body_poses[self._camera_body_indices[camera]]
            cropped.FlipNormalsTowardPoint(X_WC.translation())

        down_sampled_pcd = Concatenate(cropped_clouds).VoxelizedDownSample(voxel_size=0.005)
        meshcat.SetObject("cam0_output", down_sampled_pcd, point_size=0.003, rgba=Rgba(1.0, 0, 0))
        meshcat.SetTransform("cam0_output", RigidTransform())

        # Keep the first candidate with the strictly lowest finite cost.
        best_cost, best_X_G = np.inf, None
        for _ in range(25):
            cost, X_G = GenerateAntipodalGraspCandidate(
                self._internal_model, self._internal_model_context,
                down_sampled_pcd, self._rng)
            if np.isfinite(cost) and cost < best_cost:
                best_cost, best_X_G = cost, X_G

        if np.isfinite(best_cost):
            output.set_value((best_cost, best_X_G))
            return

        # Didn't find a viable grasp candidate; report inf with a fixed pose
        # (the same pose as the planner's home pose).
        X_WG = RigidTransform(
            RotationMatrix([
                [1, 0, 0],
                [0, 0, 1],
                [0, -1, 0]
            ]),
            [0, -0.45, 0.65])
        output.set_value((np.inf, X_WG))

        
class IIWA(LeafSystem):
    """Full recycling simulation: station, grasp selection, planning and control.

    Builds a diagram around a ManipulationStation loaded from recycling.dmd.yaml
    plus the module-level ``trash`` directives, and wires these stages:
    - GraspSelector -> Planner -> PseudoInverseController -> Integrator -> station.
    - Planner -> wsg_position.
    It visualizes in the module-level ``meshcat``. The commanded iiwa positions
    start at ``q0``, and the trash items start at random collision-free poses.
    """

    def __init__(self):
        LeafSystem.__init__(self)

        # Setup diagram builder components
        builder = DiagramBuilder()
        # NOTE(review): absolute, machine-specific path.
        self.station = MakeManipulationStation(
            model_directives=trash,
            filename=FindResource("/home/jared/Desktop/RoboticRecycling/models/recycling.dmd.yaml"),
            package_xmls=["./package.xml"])
        builder.AddSystem(self.station)
        self.plant = self.station.GetSubsystemByName("plant")
        self.visualizer = MeshcatVisualizer.AddToBuilder(
            builder, self.station.GetOutputPort("query_object"), meshcat)
        # To also draw collision geometry, add a second visualizer with
        # MeshcatVisualizerParams(role=Role.kProximity, prefix="collision").

        # Grasp selector: both cameras' point clouds plus body poses.
        camera_body_indices = [
            self.plant.GetBodyIndices(self.plant.GetModelInstanceByName(camera))[0]
            for camera in ("camera0", "camera1")]
        grasp_selector = builder.AddNamedSystem(
            "grasp_selector", GraspSelector(camera_body_indices=camera_body_indices))
        builder.Connect(self.station.GetOutputPort("camera0_point_cloud"),
                        grasp_selector.get_input_port(0))
        builder.Connect(self.station.GetOutputPort("camera1_point_cloud"),
                        grasp_selector.get_input_port(1))
        builder.Connect(self.station.GetOutputPort("body_poses"),
                        grasp_selector.GetInputPort("body_poses"))

        # Planner
        planner = builder.AddNamedSystem("planner", Planner(self.plant))
        builder.Connect(self.station.GetOutputPort("body_poses"),
                        planner.GetInputPort("body_poses"))
        builder.Connect(grasp_selector.GetOutputPort("grasp_selection"),
                        planner.GetInputPort("grasp_selection"))
        builder.Connect(self.station.GetOutputPort("wsg_state_measured"),
                        planner.GetInputPort("wsg_state"))
        builder.Connect(self.station.GetOutputPort("iiwa_position_measured"),
                        planner.GetInputPort("iiwa_position"))

        # Pseudo-inverse controller: gripper twist -> joint velocities.
        self.controller = builder.AddSystem(PseudoInverseController(self.plant))
        self.controller.set_name("PseudoInverseController")
        builder.Connect(planner.GetOutputPort("V_WG"), self.controller.GetInputPort("V_WG"))
        # Integrate controller velocity commands to get joint angles
        self.integrator = builder.AddSystem(Integrator(7))
        self.integrator.set_name("integrator")
        builder.Connect(self.controller.get_output_port(),
                        self.integrator.get_input_port())
        builder.Connect(self.integrator.get_output_port(),
                        self.station.GetInputPort("iiwa_position"))
        builder.Connect(self.station.GetOutputPort("iiwa_position_measured"),
                        self.controller.GetInputPort("iiwa_position"))

        # Gripper control
        builder.Connect(planner.GetOutputPort("wsg_position"),
                        self.station.GetInputPort("wsg_position"))

        # Finalize
        self.diagram = builder.Build()
        self.context = self.diagram.CreateDefaultContext()

        # Start the commanded joint positions at q0.
        self.integrator.set_integral_value(
            self.integrator.GetMyMutableContextFromRoot(self.context), q0)

        # Separate trash + table model, used only to find collision-free trash poses.
        self._trash_model = make_trash_model()
        self.RandomizeTrash()

        # Publish context
        self.diagram.Publish(self.context)

    def RandomizeTrash(self, max_attempts=31):
        """Move the trash items to random, collision-free poses in ``self.context``.

        Poses are sampled in the separate trash + table model: a random
        orientation, x in [-0.375, 0.375), y in [-0.68, -0.52), and z = 0.425.
        Sampling repeats until that model reports no collisions. The poses are
        then copied onto the bodies of ``self.plant`` with the same names.

        Args:
            max_attempts: sampling rounds before giving up (31 matches the previous
                ``counter > 30`` cut-off).

        Returns:
            True if a collision-free arrangement was found and applied. False if
            every attempt collided; ``self.context`` is then left unchanged.
        """
        trash_context = self._trash_model.CreateDefaultContext()
        trash_plant = self._trash_model.GetSubsystemByName("plant")
        trash_plant_context = trash_plant.GetMyMutableContextFromRoot(trash_context)
        trash_scene_graph = self._trash_model.GetSubsystemByName("scene_graph")
        trash_scene_graph_context = trash_scene_graph.GetMyMutableContextFromRoot(trash_context)

        trash_bodies = [trash_plant.get_body(body_index)
                        for body_index in trash_plant.GetFloatingBaseBodies()
                        if trash_plant.get_body(body_index).name() in item_names]

        body_tfs = {}
        success = False
        for attempt in range(1, max_attempts + 1):
            for body in trash_bodies:
                tf = RigidTransform(
                        UniformlyRandomRotationMatrix(generator),
                        [0.75*np.random.rand() - 0.375, 0.16*np.random.rand() - 0.08 -.6, .425])
                trash_plant.SetFreeBodyPose(trash_plant_context, body, tf)
                body_tfs[body.name()] = tf
            # Query after moving the bodies instead of reusing a QueryObject
            # obtained before the poses changed.
            query_object = trash_scene_graph.get_query_output_port().Eval(trash_scene_graph_context)
            if not query_object.HasCollisions():
                success = True
                break

        if not success:
            print("Large amount of consecutive failures, stopping...")
            return False

        print(f"Objects randomized successfully after {attempt} tries")
        plant_context = self.plant.GetMyMutableContextFromRoot(self.context)
        for body_index in self.plant.GetFloatingBaseBodies():
            body = self.plant.get_body(body_index)
            if body.name() in item_names:
                self.plant.SetFreeBodyPose(plant_context, body, body_tfs[body.name()])
        return True

    def Simulate(self, t):
        """Simulate in real time up to time ``t`` and publish a Meshcat recording."""
        simulator = Simulator(self.diagram, self.context)
        simulator.set_target_realtime_rate(1.0)
        self.visualizer.StartRecording()
        simulator.AdvanceTo(t)
        self.visualizer.PublishRecording()

    def ShowDiagram(self):
        """Return the system diagram as an SVG.

        The SVG is rendered when the call is a cell's last expression. Previously
        the SVG was built and then discarded, so nothing was displayed.
        """
        return SVG(pydot.graph_from_dot_data(self.diagram.GetGraphvizString())[0].create_svg())

    def _EvalCameraImages(self, image_kind):
        """Return ``[camera0, camera1]`` image data for ports ``camera{i}_{image_kind}``."""
        station_context = self.diagram.GetMutableSubsystemContext(self.station, self.context)
        return [self.station.GetOutputPort(f"camera{i}_{image_kind}").Eval(station_context).data
                for i in range(2)]

    def GetIms(self):
        """Return ``[im0, im1]``, the raw RGB image data of camera 0 and camera 1."""
        return self._EvalCameraImages("rgb_image")

    def GetLabelIms(self):
        """Return ``[im0, im1]``, the label image data of camera 0 and camera 1.

        Still need to do .squeeze() on each image to view it in matplotlib.
        """
        return self._EvalCameraImages("label_image")
    
 
# Clear everything currently drawn in Meshcat, then build the simulation.
# IIWA() also randomizes the trash poses and publishes the initial context.
meshcat.Delete()
iiwa = IIWA()

Objects randomized successfully after 2 tries


In [68]:
import json
import matplotlib.pyplot as plt
from PIL import Image
import os
import shutil
import warnings

from manipulation.utils import colorize_labels

# debug=True: preview both camera views inline and write nothing to disk.
debug = False
path = '/tmp/clutter_maskrcnn_data'  # output directory; wiped on every non-debug run
num_images = 10                      # total images (two per generate_image call)

if debug:
    plt.rcParams["figure.figsize"] = (15,30)
    fig_d, (ax1_d, ax2_d) = plt.subplots(1,2)  # RGB views
    fig_l, (ax1_l, ax2_l) = plt.subplots(1,2)  # label views
else:
    # Start from an empty output directory.
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)
    print(f'Creating dataset in {path} with {num_images} images')

def generate_image(image_num):
    """Randomize the trash, render both cameras, and save (or preview) the result.

    Camera 0 writes to ``{2*image_num:04d}``* and camera 1 to
    ``{2*image_num+1:04d}``* under ``path``. Each writes three files:
    - ``.json``: map from perception label id to class name (body-name prefix);
    - ``.png``: the RGB image;
    - ``_mask.npy``: the label image.
    In debug mode the images are drawn on the module-level axes instead, and no
    files are written.
    """
    filename_bases = [os.path.join(path, f"{2*image_num + camera:04d}") for camera in (0, 1)]

    inspector = iiwa.station.GetSubsystemByName("scene_graph").model_inspector()

    # Perception label id -> class name, e.g. "cola_can_body_link" -> "cola".
    instance_id_to_class_name = {}
    for body_index in iiwa.plant.GetFloatingBaseBodies():
        body = iiwa.plant.get_body(body_index)
        if body.name() not in item_names:
            continue
        frame_id = iiwa.plant.GetBodyFrameIdOrThrow(body_index)
        for geom_id in inspector.GetGeometries(frame_id, Role.kPerception):
            label_id = inspector.GetPerceptionProperties(geom_id).GetProperty("label", "id")
            instance_id_to_class_name[int(label_id)] = body.name().split('_')[0]

    if not debug:
        for filename_base in filename_bases:
            with open(filename_base + ".json", "w") as f:
                json.dump(instance_id_to_class_name, f)

    # RandomizeTrash writes the new poses into iiwa.context, which is the context
    # the image getters evaluate. NOTE(review): if randomization fails, the poses
    # are left unchanged and this sample repeats the previous scene.
    iiwa.RandomizeTrash()

    rgb0, rgb1 = iiwa.GetIms()
    label0, label1 = iiwa.GetLabelIms()

    if debug:
        ax1_d.imshow(rgb0)
        ax2_d.imshow(rgb1)
        ax1_l.imshow(colorize_labels(label0))
        ax2_l.imshow(colorize_labels(label1))
        return

    for filename_base, rgb, label in zip(filename_bases, (rgb0, rgb1), (label0, label1)):
        Image.fromarray(rgb).save(f"{filename_base}.png")
        np.save(f"{filename_base}_mask", label)

        
# Each generate_image call saves two images (one per camera), so make
# num_images/2 calls.
for image_num in range(int(num_images/2)):
    generate_image(image_num)


Creating dataset in /tmp/clutter_maskrcnn_data with 10 images
Objects randomized successfully after 1 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 2 tries
Objects randomized successfully after 2 tries
Objects randomized successfully after 1 tries
