In [6]:
from enum import Enum
import numpy as np
import os
from pydrake.all import (StartMeshcat, MeshcatVisualizer, DiagramBuilder, Parser, ConstantVectorSource,
                        Simulator, LeafSystem, RigidTransform, RotationMatrix, ProcessModelDirectives,
                        UniformlyRandomRotationMatrix, RandomGenerator, Rgba, LoadModelDirectivesFromString,
                        PiecewisePose, TrajectorySource, Integrator, JacobianWrtVariable, AbstractValue, 
                        PointCloud, ImageRgba8U, ImageDepth32F, Concatenate, AddMultibodyPlantSceneGraph, 
                        MeshcatVisualizerParams, Role, RollPitchYaw, BaseField, Fields)
# from manipulation.clutter import GenerateAntipodalGraspCandidate
from manipulation.meshcat_utils import AddMeshcatTriad
from manipulation import FindResource
from manipulation.scenarios import (AddIiwaDifferentialIK, MakeManipulationStation, AddPackagePaths)

import torch
import torch.utils.data
import torchvision
import torchvision.transforms.functional as Tf
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights

# For diagram viz
import matplotlib.pyplot as plt
from IPython.display import SVG
import pydot

In [7]:
# Start Meshcat once; this handle is reused below by the IIWA class (visualizer)
# and by the cell that calls meshcat.Delete() before rebuilding the scene.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7001


In [14]:
class GarbageType(Enum):
    """Destination bin category for a detected item."""
    TRASH = 0
    RECYCLE = 1
    ORGANIC = 2


# Random sources. Both are unseeded, so object placement differs between runs.
rs = np.random.RandomState()  # this is for python
generator = RandomGenerator(rs.randint(1000))  # this is for c++

# Model files are resolved relative to the notebook's working directory.
path = os.getcwd()
TRASH_YAML = os.path.join(path, "models", "trash_model.dmd.yaml")
INTERNAL_YAML = os.path.join(path, "models", "internal_model.dmd.yaml")
MODEL_YAML = os.path.join(path, "models", "recycling.dmd.yaml")
MODEL_PATH = 'recycling_maskrcnn_model.pt'

# Initial iiwa joint configuration and the matching gripper "home" pose.
q0 = [-1.57, -0.1, 0, -1.4, 0, 1.6, 0]
X_WHome = RigidTransform(
            RotationMatrix([
                [1, 0, 0],
                [0, 0, 1],
                [0, -1, 0]
            ]),
            [0, -0.5, 0.65])

ITEM_NAMES = ["cola_can", "bottle", "Banana", "Orange", "coffee"]

# Maps every entry of ITEM_NAMES to the bin it belongs in. The enum must be
# defined before this dict (it previously was not, which raised NameError on a
# fresh kernel), and the keys must match ITEM_NAMES exactly ("cola_can", not
# "cola") because the planner looks items up by their ITEM_NAMES string.
get_garbage_type = {"cola_can": GarbageType.RECYCLE, "bottle": GarbageType.RECYCLE,
                    "Banana": GarbageType.ORGANIC, "Orange": GarbageType.ORGANIC,
                    "coffee": GarbageType.TRASH}

# Guard against the two tables drifting apart again.
assert set(ITEM_NAMES) <= set(get_garbage_type), \
    "Every item in ITEM_NAMES needs an entry in get_garbage_type"

    
# Contains iiwa, bins, table, floor
def make_internal_model():
    """Build and return the manipulation station described by INTERNAL_YAML.

    The result is used by GraspSelector as its private model for scoring
    candidate gripper poses.
    """
    return MakeManipulationStation(
        filename=FindResource(INTERNAL_YAML),
        package_xmls=["./package.xml"])


# Contains table & trash
def make_trash_model():
    """Build a stand-alone plant + scene graph diagram loaded from TRASH_YAML.

    Returns the built diagram; its subsystems are later fetched by name
    ("plant", "scene_graph") when randomizing object poses.
    """
    diagram_builder = DiagramBuilder()
    trash_plant, _ = AddMultibodyPlantSceneGraph(diagram_builder, time_step=0.001)

    # Register package paths before parsing so model URIs can be resolved.
    model_parser = Parser(trash_plant)
    model_parser.package_map().AddPackageXml('./package.xml')
    AddPackagePaths(model_parser)
    model_parser.AddAllModelsFromFile(TRASH_YAML)

    trash_plant.Finalize()
    return diagram_builder.Build()
    

class PseudoInverseController(LeafSystem):
    """Maps a desired gripper spatial velocity to iiwa joint velocities.

    Input ports:
        V_WG (6): desired spatial velocity of the gripper body frame.
        iiwa_position (7): current iiwa joint positions.
    Output port:
        iiwa_velocity (7): pseudo-inverse Jacobian solution for the joints.
    """

    def __init__(self, plant):
        LeafSystem.__init__(self)
        self._plant = plant
        self._plant_context = plant.CreateDefaultContext()
        self._iiwa = plant.GetModelInstanceByName("iiwa")
        self._G = plant.GetBodyByName("body").body_frame()
        self._W = plant.world_frame()

        # First and last velocity indices of the iiwa joints; used to pick the
        # iiwa columns out of the full-plant Jacobian.
        self.iiwa_start = plant.GetJointByName("iiwa_joint_1").velocity_start()
        self.iiwa_end = plant.GetJointByName("iiwa_joint_7").velocity_start()

        # Port declaration order is kept: V_WG first, iiwa_position second.
        self.V_G_port = self.DeclareVectorInputPort("V_WG", 6)
        self.q_port = self.DeclareVectorInputPort("iiwa_position", 7)
        self.DeclareVectorOutputPort("iiwa_velocity", 7, self.CalcOutput)

    def CalcOutput(self, context, output):
        """Write pinv(J_G) * V_G for the current joint positions to `output`."""
        desired_V_G = self.V_G_port.Eval(context)
        measured_q = self.q_port.Eval(context)

        # Update the private plant context so the Jacobian is evaluated at q.
        self._plant.SetPositions(self._plant_context, self._iiwa, measured_q)
        full_jacobian = self._plant.CalcJacobianSpatialVelocity(
            self._plant_context, JacobianWrtVariable.kV,
            self._G, [0,0,0], self._W, self._W)

        # Only iiwa terms.
        iiwa_columns = slice(self.iiwa_start, self.iiwa_end + 1)
        joint_velocities = np.linalg.pinv(full_jacobian[:, iiwa_columns]) @ desired_V_G
        output.SetFromVector(joint_velocities)
        
        
class Vision(LeafSystem):
    """Perception system: segments both camera images with Mask R-CNN and
    outputs a world-frame point cloud of the most confidently detected item.

    Input ports (abstract): depth0, depth1, rgb0, rgb1, body_poses.
    Output port (abstract): "point_cloud_W", a tuple
        (average score, item name from ITEM_NAMES, PointCloud).
    """

    def __init__(self, station, camera_body_indices):
        """
        station: manipulation station diagram containing 'camera0'/'camera1'.
        camera_body_indices: body indices of the two cameras, in camera order,
            used to look up each camera pose in `body_poses`.
        """
        LeafSystem.__init__(self)

        rgb_image = AbstractValue.Make(ImageRgba8U(640,480))
        depth_image = AbstractValue.Make(ImageDepth32F(640,480))
        point_cloud = AbstractValue.Make(PointCloud(0))
        # Port order matters: callers connect the first four ports by index.
        self.DeclareAbstractInputPort("depth0", depth_image)
        self.DeclareAbstractInputPort("depth1", depth_image)
        self.DeclareAbstractInputPort("rgb0", rgb_image)
        self.DeclareAbstractInputPort("rgb1", rgb_image)
        self.DeclareAbstractInputPort(
            "body_poses", AbstractValue.Make([RigidTransform()]))

        self.DeclareAbstractOutputPort("point_cloud_W",
            lambda: AbstractValue.Make((0, ITEM_NAMES[0], point_cloud)),
            self.SendSegmentedCloud)

        # Crop box for area of interest
        self._crop_lower = np.array([-.5, -.75, .39])
        self._crop_upper = np.array([.5, -.45, .475])

        self._camera_body_indices = camera_body_indices
        # Each camera gets its own intrinsics (camera1 previously read camera0's).
        self.cam_info_0 = station.GetSubsystemByName('camera0').depth_camera_info()
        self.cam_info_1 = station.GetSubsystemByName('camera1').depth_camera_info()

        ## Load model
        num_classes = len(ITEM_NAMES) + 1
        self.model = self.load_model(num_classes)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
        self.model.eval()
        self.model.to(self.device)

    def load_model(self, num_classes):
        """Return a Mask R-CNN whose box and mask heads predict `num_classes` classes."""
        # load an instance segmentation model pre-trained on COCO
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(
            weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT, progress=False)

        # get the number of input features for the classifier
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # replace the pre-trained head with a new one
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # now get the number of input features for the mask classifier
        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        # and replace the mask predictor with a new one
        model.roi_heads.mask_predictor = MaskRCNNPredictor(
            in_features_mask, hidden_layer, num_classes)

        return model

    def SendSegmentedCloud(self, context, output):
        """Output callback: run segmentation on both cameras and publish the
        cropped, merged cloud of the selected item."""
        body_poses = self.GetInputPort("body_poses").Eval(context)
        rgb_ims = [self.GetInputPort(name).Eval(context).data
                   for name in ("rgb0", "rgb1")]
        depth_ims = [self.GetInputPort(name).Eval(context).data
                     for name in ("depth0", "depth1")]

        # Put through deep model (alpha channel dropped)
        predictions = []
        with torch.no_grad():
            for rgb_im in rgb_ims:
                predictions.append(
                    self.model([Tf.to_tensor(rgb_im[:, :, :3]).to(self.device)]))

        # Move results to numpy; masks are rescaled to 0..255 bytes so they can
        # be compared against `mask_threshold`.
        for prediction in predictions:
            for k in prediction[0].keys():
                if k == "masks":
                    prediction[0][k] = prediction[0][k].mul(
                        255).byte().cpu().numpy()
                else:
                    prediction[0][k] = prediction[0][k].cpu().numpy()

        X_WCs = [body_poses[idx] for idx in self._camera_body_indices]

        score, obj_idx, cloud = self.get_merged_masked_pcd(
            predictions, rgb_ims, depth_ims,
            [self.project_depth_to_pC, self.project_depth_to_pC],
            X_WCs, [self.cam_info_0, self.cam_info_1])

        # Crop() returns a new cloud rather than modifying in place, so the
        # result must be kept (it was previously discarded).
        cloud = cloud.Crop(self._crop_lower, self._crop_upper)

        # When nothing was detected obj_idx is -1, which selects the last item
        # name as a placeholder; the accompanying score of 0 marks it invalid.
        output.set_value((score, ITEM_NAMES[obj_idx], cloud))

    def get_merged_masked_pcd(self, predictions, rgb_ims, depth_ims,
          project_depth_to_pC_funcs, X_WCs, cam_infos, mask_threshold=150):
        """
        Build a merged, down-sampled world-frame cloud of the best-scoring item.

        predictions: The output of the trained network (one for each camera),
            already converted to numpy with masks scaled to 0..255
        rgb_ims: RGBA images from each camera
        depth_ims: Depth images from each camera
        project_depth_to_pC_funcs: Functions that perform the pinhole camera
            operations to convert (row, col, depth) pixels into camera-frame points
        X_WCs: Poses of the cameras in the world frame
        cam_infos: Camera intrinsics, one per camera
        mask_threshold: Minimum mask value (0..255) for a pixel to be kept

        Returns (average score over cameras, item index, PointCloud), or
        (0, -1, PointCloud(0)) when no usable detection exists.
        """
        # Let's focus on the maximal confidence object
        # Limitation: Assumes object uniqueness
        # NOTE(review): labels are used directly as ITEM_NAMES indices, while the
        # model is built with len(ITEM_NAMES) + 1 classes -- confirm that this
        # matches the label mapping used in training.
        scores = {}
        for obj_idx in range(len(ITEM_NAMES)):
            combined_score = 0
            for p in predictions:
                label_hits = p[0]['labels'] == obj_idx
                if not np.any(label_hits):
                    continue
                combined_score += np.max(p[0]['scores'][label_hits])
            if combined_score > 0:
                scores[obj_idx] = combined_score

        print(scores)
        if not bool(scores):
            return 0, -1, PointCloud(0)

        obj_idx = max(scores, key=scores.get)
        print("Decided to pick up ", ITEM_NAMES[obj_idx])

        pcd = []
        for prediction, rgb_im, depth_im, project_depth_to_pC_func, X_WC, cam_info in \
                zip(predictions, rgb_ims, depth_ims, project_depth_to_pC_funcs, X_WCs, cam_infos):

            obj_masks = prediction[0]['masks'][prediction[0]['labels'] == obj_idx]
            # The item may have been seen by only one camera; skip the others
            # instead of indexing into an empty mask array.
            if len(obj_masks) == 0:
                continue
            mask = obj_masks[0,0]
            idx = np.where(mask >= mask_threshold)
            if idx[0].size == 0:
                # No pixel passed the threshold; nothing to add from this view.
                continue
            depth_pts = np.column_stack((idx[0], idx[1], depth_im[idx[0], idx[1]]))
            p_C_obj = project_depth_to_pC_func(depth_pts, cam_info)
            spatial_points = X_WC @ p_C_obj.T
            rgb_points = rgb_im[idx[0], idx[1], 0:3].T

            # You get an unhelpful RunTime error if your arrays are the wrong
            # shape, so we'll check beforehand that they're the correct shapes.
            assert len(spatial_points.shape
                      ) == 2, "Spatial points is the wrong size -- should be 3 x N"
            assert spatial_points.shape[
                0] == 3, "Spatial points is the wrong size -- should be 3 x N"
            assert len(rgb_points.shape
                      ) == 2, "RGB points is the wrong size -- should be 3 x N"
            assert rgb_points.shape[
                0] == 3, "RGB points is the wrong size -- should be 3 x N"
            assert rgb_points.shape[1] == spatial_points.shape[1]

            N = spatial_points.shape[1]
            pcd.append(PointCloud(N, Fields(BaseField.kXYZs | BaseField.kRGBs)))
            pcd[-1].mutable_xyzs()[:] = spatial_points
            pcd[-1].mutable_rgbs()[:] = rgb_points
            # Estimate normals
            pcd[-1].EstimateNormals(radius=0.1, num_closest=30)
            # Flip normals toward camera
            pcd[-1].FlipNormalsTowardPoint(X_WC.translation())

        if not pcd:
            # Detections existed but produced no usable points.
            return 0, -1, PointCloud(0)

        # Merge point clouds.
        merged_pcd = Concatenate(pcd)

        # Get the prediction score
        avg_score = scores[obj_idx] / len(rgb_ims)

        # Voxelize down-sample.  (Note that the normals still look reasonable)
        return avg_score, obj_idx, merged_pcd.VoxelizedDownSample(voxel_size=0.005)

    def project_depth_to_pC(self, depth_pixel, cam_info):
        """
        project depth pixels to points in camera frame
        using pinhole camera model
        Input:
            depth_pixels: numpy array of (nx3) or (3,), rows are (row, col, depth)
            cam_info: camera intrinsics providing center_x/center_y/focal_x/focal_y
        Output:
            pC: 3D point in camera frame, numpy array of (nx3)
        """
        # Promote a single (3,) pixel to (1, 3) so both documented shapes work.
        depth_pixel = np.atleast_2d(depth_pixel)
        # switch u,v due to python convention
        v = depth_pixel[:,0]
        u = depth_pixel[:,1]
        Z = depth_pixel[:,2]
        cx = cam_info.center_x()
        cy = cam_info.center_y()
        fx = cam_info.focal_x()
        fy = cam_info.focal_y()
        X = (u-cx) * Z/fx
        Y = (v-cy) * Z/fy
        pC = np.c_[X,Y,Z]
        return pC
        
            
class GraspSelector(LeafSystem):
    """Samples antipodal grasp candidates on the segmented cloud and outputs the
    cheapest collision-free one.

    Input port (abstract): "point_cloud_W" -- (score, item name, PointCloud).
    Output port (abstract): "grasp_selection" -- always a 3-tuple
        (cost, gripper pose, item name); cost is np.inf when no grasp is available.
    """

    # Detections scoring below this are not grasped.
    _MIN_DETECTION_SCORE = 0.75
    # Number of random grasp candidates sampled per evaluation.
    _NUM_GRASP_SAMPLES = 25

    def __init__(self):
        LeafSystem.__init__(self)

        point_cloud = AbstractValue.Make(PointCloud(0))
        self.DeclareAbstractInputPort("point_cloud_W",
            AbstractValue.Make((0, ITEM_NAMES[0], point_cloud)))

        port = self.DeclareAbstractOutputPort(
            "grasp_selection",
            lambda: AbstractValue.Make((np.inf, RigidTransform(), ITEM_NAMES[0])),
            self.SelectGrasp)
        # Sampling is random, so every evaluation should recompute.
        port.disable_caching_by_default()

        # Private copy of the scene used only for collision queries.
        self._internal_model = make_internal_model()
        self._internal_model_context = self._internal_model.CreateDefaultContext()
        self._rng = np.random.default_rng()
        self.X_WHome = X_WHome

    # Taken from manipulation/clutter.py - slightly modified
    def GraspCandidateCost(self, diagram, context, cloud, wsg_body_index=None, plant_system_name="plant",
                           scene_graph_system_name="scene_graph", adjust_X_G=False):
        """
        Score the gripper pose currently stored in `context`.

        diagram/context: model and root context holding the gripper free-body pose
        cloud: world-frame point cloud with normals
        wsg_body_index: gripper body index; defaults to the body named "body"
        adjust_X_G: if True, re-center the gripper on the points between the
            fingers and write the adjusted pose back into `context`

        Returns np.inf if the gripper collides with the model or the cloud,
        otherwise a negative cost that rewards normals aligned with gripper x.
        """
        plant = diagram.GetSubsystemByName(plant_system_name)
        plant_context = plant.GetMyMutableContextFromRoot(context)
        scene_graph = diagram.GetSubsystemByName(scene_graph_system_name)
        scene_graph_context = scene_graph.GetMyMutableContextFromRoot(context)
        if wsg_body_index is not None:
            wsg = plant.get_body(wsg_body_index)
        else:
            wsg = plant.GetBodyByName("body")

        X_G = plant.GetFreeBodyPose(plant_context, wsg)

        # Transform cloud into gripper frame
        X_GW = X_G.inverse()
        p_GC = X_GW @ cloud.xyzs()

        # Crop to a region inside of the finger box.
        crop_min = [-.05, 0.1, -0.00625]
        crop_max = [.05, 0.1125, 0.00625]
        indices = np.all((crop_min[0] <= p_GC[0, :], p_GC[0, :] <= crop_max[0],
                          crop_min[1] <= p_GC[1, :], p_GC[1, :] <= crop_max[1],
                          crop_min[2] <= p_GC[2, :], p_GC[2, :] <= crop_max[2]),
                         axis=0)

        if adjust_X_G and np.sum(indices) > 0:
            p_GC_x = p_GC[0, indices]
            p_Gcenter_x = (p_GC_x.min() + p_GC_x.max()) / 2.0
            X_G.set_translation(X_G @ np.array([p_Gcenter_x, 0, 0]))
            plant.SetFreeBodyPose(plant_context, wsg, X_G)
            X_GW = X_G.inverse()

        query_object = scene_graph.get_query_output_port().Eval(scene_graph_context)

        # Check collisions between the gripper and the rest of the model
        if query_object.HasCollisions():
            return np.inf

        # Check collisions between the gripper and the point cloud. `margin` must
        # be smaller than the margin used in the point cloud preprocessing.
        margin = 0.0
        for i in range(cloud.size()):
            distances = query_object.ComputeSignedDistanceToPoint(cloud.xyz(i),
                                                                  threshold=margin)
            if distances:
                return np.inf

        n_GC = X_GW.rotation().multiply(cloud.normals()[:, indices])

        # Reward sum |dot product of normals with gripper x|^2
        cost = -np.sum(n_GC[0, :]**2)
        return cost

    def GenerateAntipodalGraspCandidate(self, diagram, context, cloud, rng, wsg_body_index=None,
                                    plant_system_name="plant", scene_graph_system_name="scene_graph"):
        """
        Picks a random point in the cloud, and aligns the robot finger with the
        x/y projection of the normal of that pixel. Perturbs z a little to find better grasps.

        Returns (cost, X_G) for the first collision-free candidate, or
        (np.inf, None) if the cloud is empty, the sampled normal has no
        horizontal component, or every perturbation collides.
        """
        plant = diagram.GetSubsystemByName(plant_system_name)
        plant_context = plant.GetMyMutableContextFromRoot(context)
        if wsg_body_index is not None:
            wsg = plant.get_body(wsg_body_index)
        else:
            wsg = plant.GetBodyByName("body")
            wsg_body_index = wsg.index()

        if cloud.size() < 1:
            return np.inf, None

        # Generator.integers() excludes the upper bound, so pass size() itself;
        # size() - 1 never sampled the last point and raised for a 1-point cloud.
        index = rng.integers(0, cloud.size())

        # Use S for sample point/frame.
        p_WS = cloud.xyz(index)
        n_WS = cloud.normal(index)

        assert np.isclose(np.linalg.norm(n_WS),
                          1.0), f"Normal has magnitude: {np.linalg.norm(n_WS)}"

        # Modification: always keep gripper y aligned with world -z
        Gx = np.array([n_WS[0], n_WS[1], 0]) # Project onto XY plane
        Gx_norm = np.linalg.norm(Gx)
        if Gx_norm < 1e-9:
            # A vertical normal has no XY direction to align with; normalizing
            # it would produce NaNs, so reject this sample.
            return np.inf, None
        Gx = Gx / Gx_norm
        Gy = np.array([0.0, 0.0, -1.0]) # World downward
        Gz = np.cross(Gx, Gy)
        R_WG = RotationMatrix(np.vstack((Gx, Gy, Gz)).T)
        p_GS_G = [0.054 - 0.01, 0.1, 0] # Position of sample end wrt gripper
        p_WG = p_WS - R_WG @ p_GS_G

        # Try vertical perturbations
        min_z = -0.02
        max_z = 0.02
        for z in np.linspace(min_z, max_z, num=5):
            p_WG_2 = p_WG + np.array([0, 0, z])
            X_G = RigidTransform(R_WG, p_WG_2)
            plant.SetFreeBodyPose(plant_context, wsg, X_G)
            # Forward the caller's body index and system names so the cost is
            # evaluated on the same body/subsystems used above.
            cost = self.GraspCandidateCost(
                diagram, context, cloud, wsg_body_index=wsg_body_index,
                plant_system_name=plant_system_name,
                scene_graph_system_name=scene_graph_system_name,
                adjust_X_G=True)
            # Read the pose back: GraspCandidateCost may have re-centered it.
            X_G = plant.GetFreeBodyPose(plant_context, wsg)
            if np.isfinite(cost):
                return cost, X_G

        return np.inf, None

    def SelectGrasp(self, context, output):
        """Output callback: publish (cost, X_G, item name) for the best sampled
        grasp, or (np.inf, X_WHome, item name) when none is available."""
        score, item_selection, down_sampled_pcd = self.GetInputPort("point_cloud_W").Eval(context)
        if score < self._MIN_DETECTION_SCORE:
            output.set_value((np.inf, self.X_WHome, item_selection))
            return

        costs = []
        X_Gs = []
        for _ in range(self._NUM_GRASP_SAMPLES):
            cost, X_G = self.GenerateAntipodalGraspCandidate(
                self._internal_model, self._internal_model_context,
                down_sampled_pcd, self._rng)
            if np.isfinite(cost):
                costs.append(cost)
                X_Gs.append(X_G)

        if len(costs) == 0:
            # Didn't find a viable grasp candidate. Keep the 3-tuple shape the
            # planner unpacks (this branch previously emitted only two values).
            output.set_value((np.inf, self.X_WHome, item_selection))
        else:
            best = np.argmin(costs)
            output.set_value((costs[best], X_Gs[best], item_selection))
            

class PlannerState(Enum):
    """Modes of the Planner state machine (see Planner.Update)."""
    DEBUG = -1        # Update() returns immediately; no transitions happen
    WAIT = 0          # pause for 2 s, then try selecting a grasp
    SELECT_GRASP = 1  # read the grasp selection and plan a pick, or go back to WAIT
    PICK = 2          # following the picking trajectory
    PLACE = 3         # following the placing trajectory to a bin
    HOME = 4          # returning to the home pose
    CLOSE = 5         # gripper commanded closed; timed wait before placing
    OPEN = 6          # gripper commanded open; waits on the measured wsg state
            
            
class Planner(LeafSystem):
    """State machine that sequences pick -> close -> place -> open -> home.

    Input ports: body_poses (abstract), grasp_selection (abstract 3-tuple of
        cost, pose, item name), wsg_state (2), iiwa_position (7).
    Output ports: wsg_position (1), V_WG (6).

    NOTE: the mode, trajectory and gripper command live in Python attributes
    that are mutated inside the periodic update, not in the Drake State.
    """

    # Seconds to hold still after commanding the gripper closed.
    _CLOSE_DWELL_TIME = 2.
    # Seconds to wait before (re)trying grasp selection.
    _WAIT_TIME = 2.

    def __init__(self, plant):
        LeafSystem.__init__(self)

        self._gripper_body_index = plant.GetBodyByName("body").index()
        self._body_poses_index = self.DeclareAbstractInputPort(
            "body_poses", AbstractValue.Make([RigidTransform()])).get_index()
        self._grasp_index = self.DeclareAbstractInputPort("grasp_selection",
            AbstractValue.Make((np.inf, RigidTransform(), ITEM_NAMES[0]))).get_index()
        self._wsg_state_index = self.DeclareVectorInputPort("wsg_state", 2).get_index()
        self.DeclareVectorInputPort("iiwa_position", 7)

        self.X_WHome = X_WHome
        # All three bin drop poses share one gripper orientation and differ
        # only in their y position.
        R_WBin = RotationMatrix([
            [0, 0, -1],
            [1, 0, 0],
            [0, -1, 0]
        ])
        self.X_WRecycle = RigidTransform(R_WBin, [0.4, -0.25, 0.65])
        self.X_WTrash = RigidTransform(R_WBin, [0.4, -.1, 0.65])
        self.X_WOrganic = RigidTransform(R_WBin, [0.4, 0.05, 0.65])

        self.eps = 1e-4
        self.init = True
        self.prev_time = None
        self.mode = PlannerState.WAIT
        self.wsg_des = np.array([0.107])
        self.traj_WG = PiecewisePose.MakeLinear([0, np.inf], [RigidTransform(), RigidTransform()])
        self.garbage_type = GarbageType.RECYCLE

        # wsg_des is already a length-1 vector; pass it through unwrapped.
        self.DeclareVectorOutputPort("wsg_position", 1,
            lambda context, output: output.SetFromVector(self.wsg_des))
        self.DeclareVectorOutputPort("V_WG", 6, self.SendTrajV)

        self.DeclarePeriodicUnrestrictedUpdateEvent(0.1, 0.0, self.Update)

    def _current_gripper_pose(self, context):
        """Return the gripper body's world pose from the body_poses input."""
        body_poses = self.get_input_port(self._body_poses_index).Eval(context)
        return body_poses[int(self._gripper_body_index)]

    def Update(self, context, state):
        """Periodic (0.1 s) update that advances the state machine.

        The checks are sequential `if`s, so several transitions can happen
        within a single call.
        """
        if self.mode == PlannerState.DEBUG:
            return

        if self.init:
            self.prev_time = context.get_time()
            self.init = False
            return

        if self.mode == PlannerState.WAIT:
            if context.get_time() - self.prev_time > self._WAIT_TIME:
                self.mode = PlannerState.SELECT_GRASP

        if self.mode == PlannerState.SELECT_GRASP:
            print("Selecting grasp")
            [cost, X_WD, item_selection] = self.get_input_port(self._grasp_index).Eval(context)
            print("Cost: ", cost, "\n", "TF: ", X_WD)
            if not np.isfinite(cost):
                self.mode = PlannerState.WAIT
                self.prev_time = context.get_time()
            else:
                # Only look up the bin for a grasp we will actually execute;
                # rejected selections may carry a placeholder item name.
                self.garbage_type = get_garbage_type[item_selection]
                self.traj_WG = self.MakePickingTraj(context, X_WD)
                self.mode = PlannerState.PICK

        if self.mode == PlannerState.PICK:
            if self.TrajDone(context):
                self.mode = PlannerState.CLOSE
                self.wsg_des = np.array([0.0])
                # Start the close timer now. Without this, prev_time still held
                # the time of the last WAIT, so the dwell below had already
                # elapsed and the arm moved off before the gripper could close.
                self.prev_time = context.get_time()
            # Handle error cases

        # Need a better way to determine when it's closed.
        if self.mode == PlannerState.CLOSE:
            if context.get_time() - self.prev_time > self._CLOSE_DWELL_TIME:
                self.traj_WG = self.MakePlacingTraj(context, self.garbage_type)
                self.mode = PlannerState.PLACE

        if self.mode == PlannerState.PLACE:
            if self.TrajDone(context):
                self.mode = PlannerState.OPEN
                self.wsg_des = np.array([0.107])
            # Handle error cases

        if self.mode == PlannerState.OPEN:
            wsg_state = self.get_input_port(self._wsg_state_index).Eval(context)
            if wsg_state[0] > 0.09:
                self.traj_WG = self.MakeHomeTraj(context)
                self.mode = PlannerState.HOME

        if self.mode == PlannerState.HOME:
            if self.TrajDone(context):
                self.mode = PlannerState.SELECT_GRASP

    def TrajDone(self, context):
        """
        Checks if we have completed our desired trajectory
            - Time elapsed
            - We're at the desired end pose
        If time is up and we're no longer moving, but we haven't
            gone to our desired location, then make a micro-adjustment.

        Returns True only when the gripper is within `self.eps` of the final pose.
        """
        time = context.get_time()
        X_WG = self._current_gripper_pose(context)
        X_WEnd = self.traj_WG.GetPose(np.inf)
        if X_WG.IsNearlyEqualTo(X_WEnd, self.eps):
            return True
        if self.traj_WG.end_time() < time:
            # Out of time but not there yet: plan a short corrective segment.
            self.traj_WG = PiecewisePose.MakeLinear([time, time+0.25], [X_WG, X_WEnd])

        return False

    def MakePickingTraj(self, context, X_WD):
        """
        Creates a basic picking trajectory: current pose -> pre-pick pose
        0.1 m above X_WD (after 2 s) -> X_WD (after 4 s).
        """
        time = context.get_time()
        X_WG = self._current_gripper_pose(context)
        X_WPrepick = RigidTransform(RotationMatrix(), [0, 0, 0.1]) @ X_WD
        return PiecewisePose.MakeLinear([time, time+2, time+4], [X_WG, X_WPrepick, X_WD])

    def MakePlacingTraj(self, context, garbage_type):
        """
        Creates a placing trajectory to either trash, recycle, or organic bin,
        passing through the home pose.

        Raises ValueError for a garbage type with no associated bin.
        """
        bin_poses = {
            GarbageType.TRASH: self.X_WTrash,
            GarbageType.RECYCLE: self.X_WRecycle,
            GarbageType.ORGANIC: self.X_WOrganic,
        }
        if garbage_type not in bin_poses:
            # Previously this fell through and failed on an unbound local.
            raise ValueError(f"Unknown garbage type: {garbage_type}")

        time = context.get_time()
        X_WG = self._current_gripper_pose(context)
        return PiecewisePose.MakeLinear(
            [time, time+2, time+6], [X_WG, self.X_WHome, bin_poses[garbage_type]])

    def MakeHomeTraj(self, context):
        """
        Creates a basic trajectory for going to home position
        """
        time = context.get_time()
        X_WG = self._current_gripper_pose(context)
        return PiecewisePose.MakeLinear([time, time+4], [X_WG, self.X_WHome])

    def SendTrajV(self, context, output):
        """
        Callback for evaluating V_WG. Makes derivative of trajectory and publishes
        the V_WG corresponding to current time.
        """
        time = context.get_time()
        V_WG = self.traj_WG.GetVelocity(time)
        output.SetFromVector(V_WG)

        
class IIWA(LeafSystem):
    """Builds and owns the full simulation diagram and its root context.

    The diagram wires together the manipulation station (loaded from
    MODEL_YAML), a Meshcat visualizer, Vision, GraspSelector, Planner,
    PseudoInverseController and an Integrator that turns joint velocities
    into iiwa position commands.

    NOTE(review): this class subclasses LeafSystem but declares no ports and is
    never added to a diagram in the visible code; it appears to be used as a
    plain container -- confirm whether the base class is needed.
    """

    def __init__(self):
        """Build the diagram, create its context, set the initial joint
        command to q0, randomize the trash poses and publish once."""
        LeafSystem.__init__(self)
            
        # Setup diagram builder components
        builder = DiagramBuilder()
        self.station = MakeManipulationStation(
            filename=FindResource(MODEL_YAML),
            package_xmls=["./package.xml"])
        builder.AddSystem(self.station)
        self.plant = self.station.GetSubsystemByName("plant")
        self.visualizer = MeshcatVisualizer.AddToBuilder(
            builder, self.station.GetOutputPort("query_object"), meshcat)  
#         self.collision = MeshcatVisualizer.AddToBuilder(
#             builder, self.station.GetOutputPort("query_object"), meshcat,
#             MeshcatVisualizerParams(role=Role.kProximity, prefix="collision"))

        # Add vision system
        # The camera body indices are the first body of each camera model
        # instance; Vision uses them to look up camera poses in body_poses.
        vision = builder.AddNamedSystem("vision", 
            Vision(self.station,
                   camera_body_indices=[
                      self.plant.GetBodyIndices(
                          self.plant.GetModelInstanceByName("camera0"))[0],
                      self.plant.GetBodyIndices(
                          self.plant.GetModelInstanceByName("camera1"))[0]
                  ]))
        # Ports 0-3 are connected by index; this relies on Vision declaring
        # them in the order depth0, depth1, rgb0, rgb1.
        builder.Connect(self.station.GetOutputPort("camera0_depth_image"),
                        vision.get_input_port(0))
        builder.Connect(self.station.GetOutputPort("camera1_depth_image"),
                        vision.get_input_port(1))
        builder.Connect(self.station.GetOutputPort("camera0_rgb_image"),
                        vision.get_input_port(2))
        builder.Connect(self.station.GetOutputPort("camera1_rgb_image"),
                        vision.get_input_port(3))
        builder.Connect(self.station.GetOutputPort("body_poses"),
                        vision.GetInputPort("body_poses"))
        
        # Add grasp selector
        grasp_selector = builder.AddNamedSystem("grasp_selector", GraspSelector())
        builder.Connect(vision.GetOutputPort("point_cloud_W"),
            grasp_selector.GetInputPort("point_cloud_W"))
        
        # Planner
        planner = builder.AddNamedSystem("planner", Planner(self.plant))
        builder.Connect(self.station.GetOutputPort("body_poses"),
                    planner.GetInputPort("body_poses"))
        builder.Connect(grasp_selector.GetOutputPort("grasp_selection"),
                        planner.GetInputPort("grasp_selection"))
        builder.Connect(self.station.GetOutputPort("wsg_state_measured"),
                        planner.GetInputPort("wsg_state"))
        builder.Connect(self.station.GetOutputPort("iiwa_position_measured"),
                        planner.GetInputPort("iiwa_position"))
        
        # Add pseudo-inverse controller
        self.controller = builder.AddSystem(PseudoInverseController(self.plant))
        self.controller.set_name("PseudoInverseController")
        builder.Connect(planner.GetOutputPort("V_WG"), self.controller.GetInputPort("V_WG"))
        # Integrate controller velocity commands to get joint angles
        self.integrator = builder.AddSystem(Integrator(7))
        self.integrator.set_name("integrator")
        builder.Connect(self.controller.get_output_port(),
                        self.integrator.get_input_port())
        builder.Connect(self.integrator.get_output_port(),
                        self.station.GetInputPort("iiwa_position"))
        builder.Connect(self.station.GetOutputPort("iiwa_position_measured"),
                        self.controller.GetInputPort("iiwa_position"))
        
        # Gripper control
        builder.Connect(planner.GetOutputPort("wsg_position"),
                    self.station.GetInputPort("wsg_position"))
        
        # Finalize
        self.diagram = builder.Build()
        self.context = self.diagram.CreateDefaultContext()
        
        # Set current position
        # The integrator output is the iiwa position command, so seeding it
        # with q0 makes the first command equal to q0.
        self.integrator.set_integral_value(
            self.integrator.GetMyMutableContextFromRoot(self.context), q0)
            
        # Randomize poses of trash
        self._trash_model = make_trash_model()
        self.RandomizeTrash()
        
        # Publish context
        self.diagram.Publish(self.context)
        
    def RandomizeTrash(self):
        """Sample random poses for the ITEM_NAMES bodies until they are
        collision-free in the separate trash model, then copy those poses onto
        the matching bodies of the main plant.

        Gives up (leaving the main plant untouched) after more than 40 tries.
        """
        
        trash_context = self._trash_model.CreateDefaultContext()
        trash_plant = self._trash_model.GetSubsystemByName("plant")
        trash_plant_context = trash_plant.GetMyMutableContextFromRoot(trash_context)
        trash_scene_graph = self._trash_model.GetSubsystemByName("scene_graph")
        trash_scene_graph_context = trash_scene_graph.GetMyMutableContextFromRoot(trash_context)
        # NOTE(review): the query object is evaluated once, before any pose is
        # set; this assumes it reflects later SetFreeBodyPose calls -- confirm.
        query_object = trash_scene_graph.get_query_output_port().Eval(trash_scene_graph_context)
        
        iterate = True
        counter = 0
        body_tfs = {}  # body name -> sampled world pose
        while iterate:
            for body_index in trash_plant.GetFloatingBaseBodies():
                body = trash_plant.get_body(body_index)
                if body.name() in ITEM_NAMES:
                    # Random orientation; x in [-0.375, 0.375),
                    # y in [-0.68, -0.52), fixed z = 0.44.
                    tf = RigidTransform(
                            UniformlyRandomRotationMatrix(generator),
                            [0.75*np.random.rand() - 0.375, 0.16*np.random.rand() - 0.08 -.6, .44])
                    trash_plant.SetFreeBodyPose(trash_plant_context, body, tf)
                    body_tfs[body.name()] = tf
                    
            iterate = query_object.HasCollisions()
            counter += 1
            if counter > 40:
                print("Large amount of consecutive failures, stopping...")
                break
        
        plant_context = self.plant.GetMyMutableContextFromRoot(self.context)
        # Only apply the sampled poses if the final sample is collision-free.
        if not query_object.HasCollisions():
            print(f"Objects randomized successfully after {counter} tries")
            for body_index in self.plant.GetFloatingBaseBodies():
                body = self.plant.get_body(body_index)
                if body.name() in ITEM_NAMES:
                    self.plant.SetFreeBodyPose(plant_context, body, body_tfs[body.name()])
        
    def Simulate(self, t):
        """Advance the simulation to time `t` at a target realtime rate of 1.0,
        recording the run in Meshcat and publishing the recording afterwards.

        A new Simulator is created on the stored context for each call.
        """
        
        # Simulator
        simulator = Simulator(self.diagram, self.context)
        simulator.set_target_realtime_rate(1.0)
        self.visualizer.StartRecording()
        simulator.AdvanceTo(t)
        self.visualizer.PublishRecording()
        
    def GetIms(self):
        """Return [camera0, camera1] RGB image data evaluated from the current context."""
        
        station_context = self.diagram.GetMutableSubsystemContext(self.station, self.context)
        im0 = self.station.GetOutputPort("camera0_rgb_image").Eval(station_context).data
        im1 = self.station.GetOutputPort("camera1_rgb_image").Eval(station_context).data
        
        return [im0, im1]
    
    def GetLabelIms(self):
        """Return [camera0, camera1] label image data evaluated from the current context."""
        # Still need to do .squeeze() on each image output to view in matplotlib
        station_context = self.diagram.GetMutableSubsystemContext(self.station, self.context)
        im0 = self.station.GetOutputPort("camera0_label_image").Eval(station_context).data
        im1 = self.station.GetOutputPort("camera1_label_image").Eval(station_context).data
        
        return [im0, im1]
    
 
# Reset the Meshcat scene before constructing a new IIWA. The constructor
# builds the diagram, randomizes the trash poses and publishes once.
meshcat.Delete()
iiwa = IIWA()

Objects randomized successfully after 1 tries


In [15]:
# Advance the simulation to t = 25 (the value is passed to Simulator.AdvanceTo).
# The planner prints its grasp decisions while this runs.
iiwa.Simulate(25)

Selecting grasp
{1: 1.960726261138916, 2: 1.8075592517852783, 3: 1.9633259773254395, 4: 1.9030123949050903}
Decided to pick up  Orange
Cost:  -20.021505017616207 
 TF:  RigidTransform(
  R=RotationMatrix([
    [-0.6884053149133325, 0.0, 0.725326217917894],
    [-0.725326217917894, -2.220446049250313e-16, -0.6884053149133322],
    [0.0, -1.0000000000000002, -2.220446049250313e-16],
  ]),
  p=[-0.048446556476337416, -0.6231942325865776, 0.5024986624717712],
)
Selecting grasp
{1: 1.959364891052246, 2: 1.8260812759399414, 3: 0.12109408527612686, 4: 1.9210864305496216}
Decided to pick up  bottle
Cost:  -17.872753851489573 
 TF:  RigidTransform(
  R=RotationMatrix([
    [-0.7230504175540644, 0.0, 0.6907952617634932],
    [-0.6907952617634932, 0.0, -0.7230504175540644],
    [0.0, -1.0, 0.0],
  ]),
  p=[-0.3007609178526606, -0.6018996455035304, 0.515854285955429],
)
