In [1]:
import numpy as np

from manipulation import running_as_notebook

from pydrake.all import (
    AddMultibodyPlantSceneGraph,
    Box,
    Parser,
    DiagramBuilder,
    MeshcatVisualizer,
    ModelVisualizer,
    MeshcatVisualizerParams,
    MultibodyPlant,
    Role,
    BodyIndex,
    RigidTransform,
    InverseKinematics,
    RotationMatrix,
    SceneGraph,
    Simulator,
    StartMeshcat,
    RollPitchYaw,
    InverseDynamicsController,
    Solve,
)

from manipulation.meshcat_utils import AddMeshcatTriad
from centers import centers

from IPython.display import clear_output

# Start the visualizer.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7001


In [2]:
robot_directives = """
directives:
- add_model:
    name: iiwa_left
    file: package://drake/manipulation/models/iiwa_description/iiwa7/iiwa7_no_collision.sdf
    default_joint_positions:
        iiwa_joint_1: [-1]
        iiwa_joint_2: [0.1]
        iiwa_joint_3: [0]
        iiwa_joint_4: [-1.2]
        iiwa_joint_5: [0]
        iiwa_joint_6: [ 1.6]
        iiwa_joint_7: [10]
- add_weld:
    parent: world
    child: iiwa_left::iiwa_link_0
    X_PC:
    
        translation: [0.6, 0, 0]
        rotation: !Rpy { deg: [0, 0, -50]}
- add_model:
    name: iiwa_right
    file: package://drake/manipulation/models/iiwa_description/iiwa7/iiwa7_no_collision.sdf
    default_joint_positions:
        iiwa_joint_1: [-1]
        iiwa_joint_2: [0.1]
        iiwa_joint_3: [0]
        iiwa_joint_4: [-1.2]
        iiwa_joint_5: [0]
        iiwa_joint_6: [ 1.6]
        iiwa_joint_7: [10]
- add_weld:
    parent: world
    child: iiwa_right::iiwa_link_0
    X_PC:
        translation: [0.4, 0.4, 0]
        rotation: !Rpy { deg: [0, 0, -0]}
- add_model:
    name: drum_stick_left
    file: package://robotics_final_project/drum_sticks.sdf
- add_weld:
    parent: iiwa_left::iiwa_link_7
    child: drum_stick_left::stick
    X_PC:
        translation: [0, 0.1, .1]
        rotation: !Rpy { deg: [-90, 0, 0]}
- add_model:
    name: drum_stick_right
    file: package://robotics_final_project/drum_sticks.sdf
- add_weld:
    parent: iiwa_right::iiwa_link_7
    child: drum_stick_right::stick
    X_PC:
        translation: [0, 0, 0]
        rotation: !Rpy { deg: [90, 0, 0]}
- add_model:
    name: drum_kit
    file: package://robotics_final_project/drum_kit.sdf

"""

In [3]:
stiffness=20000
dissipation=2

meshcat.Delete()

builder = DiagramBuilder()
plant, scene_graph = AddMultibodyPlantSceneGraph(
        builder, time_step=0.00001
    )

parser = Parser(plant)
parser.package_map().Add("robotics_final_project", "./")
parser.AddModelsFromString(robot_directives, ".dmd.yaml")


plant.Finalize()
visualizer = MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

kp = [100] * plant.num_positions()
ki = [1] * plant.num_positions()
kd = [20] * plant.num_positions()
iiwa_controller = builder.AddSystem(
    InverseDynamicsController(plant, kp, ki, kd, False)
)
iiwa_controller.set_name("iiwa_controller")

builder.Connect(
    plant.get_state_output_port(),
    iiwa_controller.get_input_port_estimated_state(),
)
builder.Connect(
    iiwa_controller.get_output_port_control(), plant.get_actuation_input_port()
)

diagram = builder.Build()
context = diagram.CreateDefaultContext()
plant_context = plant.GetMyMutableContextFromRoot(context)

simulator = Simulator(diagram, context)
simulator.set_target_realtime_rate(1.0)

# MIDI File interpretation
* A MIDI file contains multiple tracks, which correspond to different instruments. For the sake of this project, we will focus only on the first track [0].
* MetaMessage: We will ignore most MetaMessages except for:
    * 'set_tempo': the 'tempo' key on this message contains the tempo in microseconds per quarter note. We can convert between BPM and this tempo using the 'mido' function tempo2bpm().
    Note: Ideally, our robot will try to keep up with the midi file tempo, but we might need to slow it down to match the robot's movements.
    * 'end_of_track': Message indicating the end of the track.
* Message: One of 'note_on' 'note_off'. Since the drums don't usually have a duration for each hit, we can ignore all 'note_off' messages.
    * 'note': Indicates what hit (snare, rim, tom, hihat, etc...) is required. Ranges from 0-127. A piano C1 is equivalent to note 36 in MIDI. We will use the following mapping: 
    ![midi_notes_map](./MIDI/midi_notes_map.jpg) 
    * 'velocity': Indicates the strength of the hit. Ranges from 0-127
    * 'time': Indicates the "ticks" since the last message was received. Knowing the 'ticks_per_beat' parameter, it's possible to obtain the absolute value in milliseconds, when the message is triggered. We can use function tick2second() to convert to seconds, a more useful measurement for our simulation.


In this book we will read each MIDI message and convert it to a format that's readable for our trajectory planning and control.

In [4]:
import numpy as np
from MIDI.midi_to_name import midi_to_name, hit_mapping
from mido import MidiFile, MetaMessage, tick2second

mid = MidiFile('MIDI/midi4.mid')
ticks_per_beat = mid.ticks_per_beat
track = mid.tracks[0]
tempo = 0
hits = []
last_hit_time = 0
total_time = 0
for msg in track:
    if type(msg) == MetaMessage:
        if msg.type == 'set_tempo':
            tempo = msg.tempo

    total_time += tick2second(msg.time, ticks_per_beat, tempo)
    last_hit_time += tick2second(msg.time, ticks_per_beat, tempo)
    if msg.type == 'note_on' and msg.note in midi_to_name.keys():
        hit = midi_to_name[msg.note]
        if hit in hit_mapping.keys():
            simplified_hit = hit_mapping[hit]
            hits.append({"hit": simplified_hit, "last_hit_time": last_hit_time, "total_time": total_time, "strength": msg.velocity})
            last_hit_time = 0

print(hits)


[{'hit': 'hi_hat', 'last_hit_time': 0.00125, 'total_time': 0.00125, 'strength': 114}, {'hit': 'hi_hat', 'last_hit_time': 0.3075, 'total_time': 0.30875, 'strength': 50}, {'hit': 'snare', 'last_hit_time': 0.2875, 'total_time': 0.5962500000000001, 'strength': 97}, {'hit': 'hi_hat', 'last_hit_time': 0.01125, 'total_time': 0.6075, 'strength': 117}, {'hit': 'hi_hat', 'last_hit_time': 0.30125, 'total_time': 0.9087500000000001, 'strength': 41}, {'hit': 'hi_hat', 'last_hit_time': 0.16, 'total_time': 1.06875, 'strength': 80}, {'hit': 'hi_hat', 'last_hit_time': 0.15, 'total_time': 1.2187500000000002, 'strength': 107}, {'hit': 'hi_hat', 'last_hit_time': 0.30375, 'total_time': 1.5225000000000002, 'strength': 62}, {'hit': 'snare', 'last_hit_time': 0.2675, 'total_time': 1.79, 'strength': 97}, {'hit': 'hi_hat', 'last_hit_time': 0.02875, 'total_time': 1.81875, 'strength': 117}, {'hit': 'hi_hat', 'last_hit_time': 0.29875, 'total_time': 2.1175, 'strength': 58}, {'hit': 'hi_hat', 'last_hit_time': 0.29875,

# MIDI notes to simulation poses

We can now run a simulation of the different poses depending on the midi file.

In [5]:
meshcat.StartRecording()
drum_stick_right = plant.GetModelInstanceByName("drum_stick_right")
drum_stick_left = plant.GetModelInstanceByName("drum_stick_left")
right_frame = plant.GetFrameByName("stick",drum_stick_right )
left_frame = plant.GetFrameByName("stick", drum_stick_left)

def visualize_frame(name, X_WF, length=0.15, radius=0.006):
        """
        visualize imaginary frame that are not attached to existing bodies

        Input:
            name: the name of the frame (str)
            X_WF: a RigidTransform to from frame F to world.

        Frames whose names already exist will be overwritten by the new frame
        """
        AddMeshcatTriad(
            meshcat, "painter/" + name, length=length, radius=radius, X_PT=X_WF
        )

for hit in [hits[0]]:
    # desired_qs =np.hstack([np.array(centers[hit["hit"]]), np.array(centers[hit["hit"]])])
    ik = InverseKinematics(plant, plant_context)
    qs = ik.q()
    prog = ik.prog()
    q_nominal = np.zeros(len(qs))
    prog.AddQuadraticErrorCost(
        np.eye(len(qs)), q_nominal, qs
    )

    # drum = plant.GetFrameByName(hit["hit"])
    # visualize_frame(hit["hit"], drum.CalcPoseInWorld (plant_context))

    rotation = RotationMatrix(RollPitchYaw(np.array(centers[hit["hit"]])[3:]))
    X_WD = RigidTransform(rotation, np.array(centers[hit["hit"]])[:3])
    # print(X_WD.translation())
    # visualize_frame('x', X_WD)
    ik.AddPositionConstraint(right_frame, np.zeros(3), plant.world_frame(), X_WD.translation() , X_WD.translation())
    # ik.AddPositionConstraint(left_frame, np.zeros(3), plant.world_frame(), np.array(centers[hit["hit"]])[:3] - 0.001, np.array(centers[hit["hit"]])[:3] + 0.001)
    # ik.AddOrientationConstraint(right_frame, RotationMatrix(np.eye(3)), plant.world_frame(), X_WD.rotation(), 0.01)
    # ik.AddOrientationConstraint(left_frame, RotationMatrix(np.eye(3)), plant.world_frame(), rt, 0.01)
    # ik.AddMinimumDistanceLowerBoundConstraint(0.01)
    
    lower_limits = plant.GetPositionLowerLimits()
    upper_limits = plant.GetPositionUpperLimits()
    lower_limits = np.where(lower_limits == float('-inf'), -np.pi, lower_limits)
    upper_limits = np.where(upper_limits == float('inf'), np.pi, upper_limits)
    max_tries = 100
    
    for count in range(max_tries):
        # Compute a random initial guess here
        
        for idx in np.ndindex(qs.shape):
            random_number = np.random.uniform(lower_limits[idx], upper_limits[idx])
            prog.SetInitialGuess( qs[idx], random_number)
        
        result = Solve(prog)

        if result.is_success():
            print("GOTCHA")
            visualize_frame('right_frame', right_frame.CalcPoseInWorld (plant_context))
            visualize_frame('left_frame', left_frame.CalcPoseInWorld (plant_context))
            q0 = result.GetSolution(qs)
            
            x0 = np.hstack((q0, 0 * q0))
            iiwa_controller.GetInputPort("desired_state").FixValue(
                iiwa_controller.GetMyMutableContextFromRoot(context), x0
            )
            break

            simulator.AdvanceTo(hit["total_time"])
simulator.AdvanceTo(5.0)
meshcat.StopRecording() 
meshcat.PublishRecording()

# q0 = np.array([-1.57, 0.1, 0, -1.2, 0, 1.6, 0, -1.57, 0.1, 0, -1.2, 0, 1.6, 0])
# plant.SetPositions(plant_context, q0)

GOTCHA
