In [1]:
from enum import Enum
import numpy as np
import os
import pathlib
from pydrake.all import (StartMeshcat, MeshcatVisualizer, DiagramBuilder, Parser, ConstantVectorSource,
                        Simulator, LeafSystem, RigidTransform, RotationMatrix, UniformlyRandomRotationMatrix, 
                        RandomGenerator, AddMultibodyPlantSceneGraph, Role)
from manipulation import ConfigureParser
from manipulation.scenarios import MakeManipulationStation

In [2]:
# Start Meshcat once; this handle is reused by the visualizer and the cells below.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7001


In [3]:
# Two random sources: a NumPy RandomState for Python-side sampling, and a Drake
# RandomGenerator (C++ side) seeded with an integer drawn from the former.
rs = np.random.RandomState()
generator = RandomGenerator(rs.randint(1000))

# Model directives are resolved relative to the current working directory and
# passed around as file:// URIs.
path = os.getcwd()
TRASH_YAML, INTERNAL_YAML, MODEL_YAML = (
    pathlib.Path(f"{path}/models/{model_stem}.dmd.yaml").as_uri()
    for model_stem in ("trash_model", "internal_model", "recycling")
)
# NOTE(review): not referenced in the cells shown here; presumably the Mask R-CNN
# weights file used elsewhere in the notebook.
MODEL_PATH = 'recycling_maskrcnn_model.pt'

# Constant joint positions fed to the station's "iiwa_position" input (7 values).
q0 = [-1.57, -0.1, 0, -1.4, 0, 1.6, 0]

# Home pose in the world frame: a -90 degree rotation about x, plus a translation.
X_WHome = RigidTransform(
    RotationMatrix([[1, 0, 0],
                    [0, 0, 1],
                    [0, -1, 0]]),
    [0, -0.5, 0.65],
)

# Body names of the trash items that get randomized and labelled.
ITEM_NAMES = ["bottle", "Banana", "coffee"]

class GarbageType(Enum):
    """Disposal categories that an item can be assigned to."""
    TRASH = 0
    RECYCLE = 1
    ORGANIC = 2
    
# Lookup table from an item's body name (see ITEM_NAMES) to its disposal category.
get_garbage_type = dict(
    bottle=GarbageType.RECYCLE,
    Banana=GarbageType.ORGANIC,
    coffee=GarbageType.TRASH,
)

# Contains iiwa, bins, table, floor
def make_internal_model():
    """Build the manipulation station described by ``INTERNAL_YAML``.

    Returns:
        The object returned by ``MakeManipulationStation`` for the internal
        model directives, with ``./package.xml`` registered as a package.
    """
    return MakeManipulationStation(
        filename=INTERNAL_YAML,
        package_xmls=["./package.xml"],
    )

# Contains table & trash
def make_trash_model():
    """Build a standalone plant + scene-graph diagram from ``TRASH_YAML``.

    Returns:
        The built diagram. Its plant is discrete (``time_step=0.001``) and is
        finalized after the models have been loaded.
    """
    diagram_builder = DiagramBuilder()
    trash_plant, _ = AddMultibodyPlantSceneGraph(diagram_builder, time_step=0.001)

    # Register the local package and the course defaults before loading models.
    model_parser = Parser(trash_plant)
    model_parser.package_map().AddPackageXml('./package.xml')
    ConfigureParser(model_parser)
    model_parser.AddModelsFromUrl(TRASH_YAML)

    trash_plant.Finalize()
    return diagram_builder.Build()

class IIWA(LeafSystem):
    """Recycling-station scene with a stationary iiwa and randomly posed trash.

    Construction builds a diagram around the manipulation station loaded from
    ``MODEL_YAML``, attaches a Meshcat visualizer, feeds the arm and gripper
    constant position commands, randomizes the trash poses once, and publishes
    the resulting context to Meshcat.

    Attributes:
        station: The station returned by ``MakeManipulationStation``.
        plant: The station's "plant" subsystem.
        visualizer: The Meshcat visualizer connected to the station's
            "query_object" output.
        diagram: The built diagram containing the station, the visualizer and
            the two constant sources.
        context: Root context of ``diagram``. ``RandomizeTrash`` writes the
            trash poses into it and the image getters read from it.

    NOTE(review): the class derives from ``LeafSystem`` but declares no ports
    or state in the code shown here -- confirm whether the base class is needed.
    """

    # Prefixes of the station's camera output ports read by the image getters.
    _CAMERA_NAMES = ("camera0", "camera1")

    def __init__(self):
        LeafSystem.__init__(self)

        # Setup diagram builder components
        builder = DiagramBuilder()
        self.station = MakeManipulationStation(
            filename=MODEL_YAML,
            package_xmls=["./package.xml"])
        builder.AddSystem(self.station)
        self.plant = self.station.GetSubsystemByName("plant")
        self.visualizer = MeshcatVisualizer.AddToBuilder(
            builder, self.station.GetOutputPort("query_object"), meshcat)

        # Stationary iiwa: hold the arm at q0.
        iiwa_source = builder.AddSystem(ConstantVectorSource(q0))
        builder.Connect(iiwa_source.get_output_port(),
                        self.station.GetInputPort("iiwa_position"))

        # Stationary gripper: constant position command.
        # NOTE(review): 0.107 is presumably the fully-open width -- confirm.
        grip_source = builder.AddSystem(ConstantVectorSource([0.107]))
        builder.Connect(grip_source.get_output_port(),
                        self.station.GetInputPort("wsg_position"))

        # Finalize
        self.diagram = builder.Build()
        self.context = self.diagram.CreateDefaultContext()

        # Randomize poses of trash
        self._trash_model = make_trash_model()
        self.RandomizeTrash()

        # Publish self.context, i.e. the context RandomizeTrash() just wrote
        # into. A Simulator constructed without a context advances (and
        # displays) its own default context, so the randomized poses would
        # never reach Meshcat.
        self.diagram.ForcedPublish(self.context)
        meshcat.Flush()

    def RandomizeTrash(self, max_tries=41):
        """Sample collision-free poses for the trash items and apply them.

        Candidate poses are drawn in a fresh context of the standalone trash
        model. Every floating body whose name is in ``ITEM_NAMES`` gets a
        uniformly random rotation and a position with x uniform in
        [-0.375, 0.375), y uniform in [-0.68, -0.52) and z fixed at 0.44.
        Sampling repeats until the trash model reports no collisions or
        ``max_tries`` attempts have been made. On success the poses are copied
        onto the same-named bodies of ``self.plant`` in ``self.context``; on
        failure ``self.context`` is left untouched.

        Args:
            max_tries: Maximum number of sampling attempts. The default of 41
                matches the previously hard-coded limit.

        Returns:
            True if collision-free poses were found and applied, else False.
        """
        trash_context = self._trash_model.CreateDefaultContext()
        trash_plant = self._trash_model.GetSubsystemByName("plant")
        trash_plant_context = trash_plant.GetMyMutableContextFromRoot(trash_context)
        trash_scene_graph = self._trash_model.GetSubsystemByName("scene_graph")
        trash_scene_graph_context = trash_scene_graph.GetMyMutableContextFromRoot(trash_context)
        # NOTE(review): the query object is evaluated once, before any pose is
        # set. The varying retry counts in the logged output suggest that
        # HasCollisions() does see the updated poses, but confirm it is not a
        # stale snapshot.
        query_object = trash_scene_graph.get_query_output_port().Eval(trash_scene_graph_context)

        body_tfs = {}
        in_collision = True
        counter = 0
        while in_collision and counter < max_tries:
            for body_index in trash_plant.GetFloatingBaseBodies():
                body = trash_plant.get_body(body_index)
                if body.name() in ITEM_NAMES:
                    tf = RigidTransform(
                            UniformlyRandomRotationMatrix(generator),
                            [0.75*np.random.rand() - 0.375, 0.16*np.random.rand() - 0.08 -.6, .44])
                    trash_plant.SetFreeBodyPose(trash_plant_context, body, tf)
                    body_tfs[body.name()] = tf

            in_collision = query_object.HasCollisions()
            counter += 1

        # Report failure only when the final attempt still collides; a success
        # on the very last allowed attempt is a success.
        if in_collision:
            print("Large amount of consecutive failures, stopping...")
            return False

        print(f"Objects randomized successfully after {counter} tries")
        plant_context = self.plant.GetMyMutableContextFromRoot(self.context)
        for body_index in self.plant.GetFloatingBaseBodies():
            body = self.plant.get_body(body_index)
            if body.name() in ITEM_NAMES:
                # Direct indexing on purpose: a KeyError here means the station
                # has an item body that the trash model lacks.
                self.plant.SetFreeBodyPose(plant_context, body, body_tfs[body.name()])
        return True

    def _EvalCameraPorts(self, image_kind):
        """Evaluate ``<camera>_<image_kind>`` on the station for every camera.

        Args:
            image_kind: Port-name suffix, e.g. "rgb_image" or "label_image".

        Returns:
            A list with the ``.data`` of each evaluated image, in the order of
            ``_CAMERA_NAMES``.
        """
        station_context = self.diagram.GetMutableSubsystemContext(self.station, self.context)
        return [
            self.station.GetOutputPort(f"{camera}_{image_kind}").Eval(station_context).data
            for camera in self._CAMERA_NAMES
        ]

    def GetIms(self):
        """Return ``[im0, im1]``: the RGB image data of camera0 and camera1."""
        return self._EvalCameraPorts("rgb_image")

    def GetLabelIms(self):
        """Return ``[im0, im1]``: the label image data of camera0 and camera1.

        Still need to do .squeeze() on each image output to view in matplotlib.
        """
        return self._EvalCameraPorts("label_image")
    
 
# Reset Meshcat before building the scene (presumably removes everything that
# was published earlier -- confirm), then construct the station. The IIWA
# constructor randomizes the trash poses once.
meshcat.Delete()
iiwa = IIWA()



Objects randomized successfully after 1 tries


In [4]:
import json
import matplotlib.pyplot as plt
from PIL import Image
import os
import shutil
import warnings

from manipulation.utils import colorize_labels

# Run settings. With debug enabled, images are drawn on matplotlib axes and
# nothing is written to disk; otherwise the dataset is written under `path`.
debug = False
path = '/tmp/clutter_maskrcnn_data'
num_images = 500

if debug:
    # One figure for the RGB images and one for the colorized labels, each
    # with one axis per camera.
    plt.rcParams["figure.figsize"] = (15,30)
    fig_d, axes_d = plt.subplots(1,2)
    ax1_d, ax2_d = axes_d
    fig_l, axes_l = plt.subplots(1,2)
    ax1_l, ax2_l = axes_l
else:
    # Always start from an empty output directory.
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)
    print(f'Creating dataset in {path} with {num_images} images')

def generate_image(image_num):
    """Randomize the trash and capture one labelled sample per camera.

    Uses the notebook globals ``iiwa``, ``path``, ``debug`` and ``ITEM_NAMES``.
    Outside debug mode, scene ``image_num`` produces two file stems,
    ``2*image_num`` and ``2*image_num + 1`` (zero-padded to four digits), one
    per camera. Each stem gets a ``.json`` label-id -> item-name map, a ``.png``
    RGB image and a ``_mask.npy`` label image. In debug mode nothing is written
    and the images are drawn on the pre-created matplotlib axes instead.

    Args:
        image_num: Zero-based index of the scene to generate.
    """
    # One output stem per camera.
    stems = [os.path.join(path, f"{2*image_num + cam:04d}") for cam in (0, 1)]

    plant = iiwa.plant
    inspector = iiwa.station.GetSubsystemByName("scene_graph").model_inspector()

    # Map the perception label id of every item geometry to the item's name.
    instance_id_to_class_name = {}
    for body_index in plant.GetFloatingBaseBodies():
        body = plant.get_body(body_index)
        if body.name() not in ITEM_NAMES:
            continue
        frame_id = plant.GetBodyFrameIdOrThrow(body_index)
        for geom_id in inspector.GetGeometries(frame_id, Role.kPerception):
            label_id = inspector.GetPerceptionProperties(geom_id).GetProperty(
                "label", "id")
            instance_id_to_class_name[int(label_id)] = body.name()

    if not debug:
        # The same mapping is stored next to both camera images.
        for stem in stems:
            with open(stem + ".json", "w") as f:
                json.dump(instance_id_to_class_name, f)

    # New trash poses for this scene. Meshcat is not republished here; the
    # images below are evaluated directly from iiwa.context.
    iiwa.RandomizeTrash()

    rgb0, rgb1 = iiwa.GetIms()
    label0, label1 = iiwa.GetLabelIms()

    if debug:
        for axis, rgb in ((ax1_d, rgb0), (ax2_d, rgb1)):
            axis.imshow(rgb)
        for axis, labels in ((ax1_l, label0), (ax2_l, label1)):
            axis.imshow(colorize_labels(labels))
        return

    for stem, rgb in zip(stems, (rgb0, rgb1)):
        Image.fromarray(rgb).save(f"{stem}.png")
    for stem, labels in zip(stems, (label0, label1)):
        np.save(f"{stem}_mask", labels)

        
# Each scene yields one image per camera (two files), so num_images / 2 scenes
# are generated.
num_scenes = int(num_images/2)
for image_num in range(num_scenes):
    generate_image(image_num)


Creating dataset in /tmp/clutter_maskrcnn_data with 500 images
Objects randomized successfully after 4 tries
Objects randomized successfully after 6 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 5 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 4 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 2 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 3 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 1 tries
Objects randomized successfully after 2 tries
Objects randomized successfully after 3 tries
Objects randomized successfully after 3 tries
Objects randomize

: 