In [1]:
import numpy as np

from manipulation import running_as_notebook

from pydrake.all import (
    AddMultibodyPlantSceneGraph,
    Box,
    Parser,
    DiagramBuilder,
    MeshcatVisualizer,
    ModelVisualizer,
    MeshcatVisualizerParams,
    MultibodyPlant,
    Role,
    BodyIndex,
    RigidTransform,
    InverseKinematics,
    RotationMatrix,
    SceneGraph,
    Simulator,
    StartMeshcat,
    RollPitchYaw,
    InverseDynamicsController,
    Solve,
)

from manipulation.meshcat_utils import AddMeshcatTriad
from centers import centers
from MIDI.midi_to_hit import midi_to_hit

from IPython.display import clear_output

# Start the visualizer.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7002


In [2]:
# robot_directives = """
# directives:

# - add_model:
#     name: iiwa_right
#     file: package://drake/manipulation/models/iiwa_description/iiwa7/iiwa7_no_collision.sdf
#     default_joint_positions:
#         iiwa_joint_1: [-1]
#         iiwa_joint_2: [0.1]
#         iiwa_joint_3: [0]
#         iiwa_joint_4: [-1.2]
#         iiwa_joint_5: [0]
#         iiwa_joint_6: [ 1.6]
#         iiwa_joint_7: [10]
# - add_weld:
#     parent: world
#     child: iiwa_right::iiwa_link_0
#     X_PC:
#         translation: [0.4, 0.6, 0]
#         rotation: !Rpy { deg: [0, 0, -0]}
# - add_model:
#     name: drum_stick_right
#     file: package://robotics_final_project/drum_sticks.sdf
# - add_weld:
#     parent: iiwa_right::iiwa_link_7
#     child: drum_stick_right::stick
#     X_PC:
#         translation: [0, 0, 0.18]
#         rotation: !Rpy { deg: [180, 0, 0]}
# - add_model:
#     name: drum_kit
#     file: package://robotics_final_project/drum_kit.sdf

# """
robot_directives = """
directives:

- add_model:
    name: iiwa_right
    file: package://drake/manipulation/models/iiwa_description/iiwa7/iiwa7_no_collision.sdf
    default_joint_positions:
        iiwa_joint_1: [-1]
        iiwa_joint_2: [0.1]
        iiwa_joint_3: [0]
        iiwa_joint_4: [-1.2]
        iiwa_joint_5: [0]
        iiwa_joint_6: [ 1.6]
        iiwa_joint_7: [10]
- add_weld:
    parent: world
    child: iiwa_right::iiwa_link_0
    X_PC:
        translation: [0.5, 1.3, 0]
        rotation: !Rpy { deg: [0, 0, -0]}
- add_model:
    name: drum_stick_right
    file: package://robotics_final_project/drum_sticks.sdf
- add_weld:
    parent: iiwa_right::iiwa_link_7
    child: drum_stick_right::stick
    X_PC:
        translation: [0, 0, 0.18]
        rotation: !Rpy { deg: [180, 0, 0]}

- add_model:
    name: iiwa_left
    file: package://drake/manipulation/models/iiwa_description/iiwa7/iiwa7_no_collision.sdf
    default_joint_positions:
        iiwa_joint_1: [-1]
        iiwa_joint_2: [0.1]
        iiwa_joint_3: [0]
        iiwa_joint_4: [-1.2]
        iiwa_joint_5: [0]
        iiwa_joint_6: [ 1.6]
        iiwa_joint_7: [10]
- add_weld:
    parent: world
    child: iiwa_left::iiwa_link_0
    X_PC:
        translation: [0.45, 0.3, 0]
        rotation: !Rpy { deg: [0, 0, -0]}
- add_model:
    name: drum_stick_left
    file: package://robotics_final_project/drum_sticks.sdf
- add_weld:
    parent: iiwa_left::iiwa_link_7
    child: drum_stick_left::stick
    X_PC:
        translation: [0, 0, 0.18]
        rotation: !Rpy { deg: [180, 0, 0]}

- add_model:
    name: drum_kit
    file: package://robotics_final_project/drum_kit.sdf

"""

In [3]:
stiffness=20000
dissipation=2

meshcat.Delete()

builder = DiagramBuilder()
plant, scene_graph = AddMultibodyPlantSceneGraph(
        builder, time_step=0.00001
    )

parser = Parser(plant)
parser.package_map().Add("robotics_final_project", "./")
parser.AddModelsFromString(robot_directives, ".dmd.yaml")

# plant.mutable_gravity_field().set_gravity_vector([0, 0, 0])

plant.Finalize()
visualizer = MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)
collision_visualizer = MeshcatVisualizer.AddToBuilder(
    builder,
    scene_graph,
    meshcat,
    MeshcatVisualizerParams(
        prefix="collision", role=Role.kProximity, visible_by_default=False
    ),
)

kp = [100] * plant.num_positions()
ki = [1] * plant.num_positions()
kd = [20] * plant.num_positions()
iiwa_controller = builder.AddSystem(
    InverseDynamicsController(plant, kp, ki, kd, False)
)
iiwa_controller.set_name("iiwa_controller")

builder.Connect(
    plant.get_state_output_port(),
    iiwa_controller.get_input_port_estimated_state(),
)
builder.Connect(
    iiwa_controller.get_output_port_control(), plant.get_actuation_input_port()
)

diagram = builder.Build()
context = diagram.CreateDefaultContext()
plant_context = plant.GetMyMutableContextFromRoot(context)

simulator = Simulator(diagram, context)
simulator.set_target_realtime_rate(1.0)

# MIDI File interpretation
* A MIDI file contains multiple tracks, which correspond to different instruments. For the sake of this project, we will focus only on the first track [0].
* MetaMessage: We will ignore most MetaMessages except for:
    * 'set_tempo': the 'tempo' key on this message contains the tempo in microseconds per quarter note. We can convert between BPM and this tempo using the 'mido' function tempo2bpm().
    Note: Ideally, our robot will try to keep up with the midi file tempo, but we might need to slow it down to match the robot's movements.
    * 'end_of_track': Message indicating the end of the track.
* Message: One of 'note_on' 'note_off'. Since the drums don't usually have a duration for each hit, we can ignore all 'note_off' messages.
    * 'note': Indicates what hit (snare, rim, tom, hihat, etc...) is required. Ranges from 0-127. A piano C1 is equivalent to note 36 in MIDI. We will use the following mapping: 
    ![midi_notes_map](./MIDI/midi_notes_map.jpg) 
    * 'velocity': Indicates the strength of the hit. Ranges from 0-127
    * 'time': Indicates the "ticks" since the last message was received. Knowing the 'ticks_per_beat' parameter, it's possible to obtain the absolute value in milliseconds, when the message is triggered. We can use function tick2second() to convert to seconds, a more useful measurement for our simulation.


In this book we will read each MIDI message and convert it to a format that's readable for our trajectory planning and control.

In [4]:
import numpy as np
from MIDI.midi_to_name import midi_to_name, hit_mapping
from mido import MidiFile, MetaMessage, tick2second


hits = midi_to_hit(filename = 'MIDI/midi_drums.mid')
print(hits)


[{'hit': 'ride', 'last_hit_time': 0.03750001875, 'total_time': 0.03750001875, 'strength': 82}, {'hit': 'low_tom', 'last_hit_time': 0.018055564583333333, 'total_time': 0.05555558333333334, 'strength': 57}, {'hit': 'snare', 'last_hit_time': 0.5138891458333333, 'total_time': 0.5694447291666667, 'strength': 116}, {'hit': 'ride', 'last_hit_time': 0.150000075, 'total_time': 0.7194448041666667, 'strength': 96}, {'hit': 'ride', 'last_hit_time': 0.60416696875, 'total_time': 1.3236117729166668, 'strength': 47}, {'hit': 'low_tom', 'last_hit_time': 0.018055564583333333, 'total_time': 1.3416673375000001, 'strength': 56}, {'hit': 'snare', 'last_hit_time': 0.52083359375, 'total_time': 1.8625009312500003, 'strength': 93}, {'hit': 'ride', 'last_hit_time': 0.18472231458333335, 'total_time': 2.0472232458333335, 'strength': 64}, {'hit': 'ride', 'last_hit_time': 0.616666975, 'total_time': 2.6638902208333337, 'strength': 42}, {'hit': 'mid_tom', 'last_hit_time': 0.0055555583333333335, 'total_time': 2.6694457

In [5]:
import copy
newhits=[]
def timecalc(hit1, hit2):
    halftime = 0.5 * (hit2 +hit1)
    return halftime

for i in range(len(hits)):
    if i ==0:
        halftime = timecalc(0, hits[i]["total_time"])
    else:
        halftime = timecalc(hits[i-1]["total_time"], hits[i]["total_time"])
    l = hits[i]
    ll= copy.copy(l)
    ll["hit"] = ll["hit"] + "2"
    ll["total_time"] = halftime
    newhits.append(ll)
    # print()
    newhits.append(l)
    # print(newhits[-1]["total_time"])
newhits



[{'hit': 'ride2',
  'last_hit_time': 0.03750001875,
  'total_time': 0.018750009375,
  'strength': 82},
 {'hit': 'ride',
  'last_hit_time': 0.03750001875,
  'total_time': 0.03750001875,
  'strength': 82},
 {'hit': 'low_tom2',
  'last_hit_time': 0.018055564583333333,
  'total_time': 0.04652780104166667,
  'strength': 57},
 {'hit': 'low_tom',
  'last_hit_time': 0.018055564583333333,
  'total_time': 0.05555558333333334,
  'strength': 57},
 {'hit': 'snare2',
  'last_hit_time': 0.5138891458333333,
  'total_time': 0.31250015625,
  'strength': 116},
 {'hit': 'snare',
  'last_hit_time': 0.5138891458333333,
  'total_time': 0.5694447291666667,
  'strength': 116},
 {'hit': 'ride2',
  'last_hit_time': 0.150000075,
  'total_time': 0.6444447666666666,
  'strength': 96},
 {'hit': 'ride',
  'last_hit_time': 0.150000075,
  'total_time': 0.7194448041666667,
  'strength': 96},
 {'hit': 'ride2',
  'last_hit_time': 0.60416696875,
  'total_time': 1.0215282885416668,
  'strength': 47},
 {'hit': 'ride',
  'las

# MIDI notes to simulation poses

We can now run a simulation of the different poses depending on the midi file.

In [6]:
# left = ["hi_hat", "crash", "hi_tom", "hi_hat2", "crash2", "hi_tom2","low_tom2","low_tom"]
# right = ["snare","mid_tom", "ride","snare2","mid_tom2", "ride2"]
left = ["hi_hat", "crash", "hi_tom", "hi_hat2", "crash2", "hi_tom2","snare","snare2"]
right = ["mid_tom", "ride","mid_tom2", "ride2","low_tom2","low_tom"]
# left = ["hi_hat", "hi_hat2"]
# right = ["snare","snare2"]
def left_right_detector(hit):
    if hit["hit"] in left:
        side = "left"
    else:
        side = "right"
    return side

def angle_constraint_detector(hit,centerx, centery):
    x1 = centers(hit["hit"])[0]
    y1 = centers(hit["hit"])[1]

    vec1 = x1 - centerx
    vec2 = y1 - centery
    
    angle = np.tan(vec1,vec2)

    return angle
 

In [7]:
meshcat.StartRecording()
drum_stick_right = plant.GetModelInstanceByName("drum_stick_right")
drum_stick_left = plant.GetModelInstanceByName("drum_stick_left")
right_frame = plant.GetFrameByName("stick",drum_stick_right )
left_frame = plant.GetFrameByName("stick", drum_stick_left)
hihat = plant.GetModelInstanceByName("drum_kit")
hihat_frame = plant.GetFrameByName("hi_hat",hihat)


def visualize_frame(name, X_WF, length=0.15, radius=0.006):
        """
        visualize imaginary frame that are not attached to existing bodies

        Input:
            name: the name of the frame (str)
            X_WF: a RigidTransform to from frame F to world.

        Frames whose names already exist will be overwritten by the new frame
        """
        AddMeshcatTriad(
            meshcat, "painter/" + name, length=length, radius=radius, X_PT=X_WF
        )
count = 0
for hit in newhits[:30]:

    print(hit["hit"])
    # desired_qs =np.hstack([np.array(centers[hit["hit"]]), np.array(centers[hit["hit"]])])
    ik = InverseKinematics(plant, plant_context)
    qs = ik.q()
    prog = ik.prog()
    q_nominal = np.zeros(len(qs))
    prog.AddQuadraticErrorCost(
        np.eye(len(qs)), q_nominal, qs
    )

    # drum = plant.GetFrameByName(hit["hit"])
    # visualize_frame(hit["hit"], drum.CalcPoseInWorld (plant_context))
    if hit["hit"][-1] == "2":
        ik.AddMinimumDistanceLowerBoundConstraint(0.001)
        # translation = np.array(centers[hit["hit"]])[:3] + np.array([0,0,+.04])
        translation = np.array(centers[hit["hit"][:-1]])[:3] + np.array([0,0,0.08])
    else:
        translation = np.array(centers[hit["hit"]])[:3]

    rotation = RotationMatrix(RollPitchYaw(np.array(centers[hit["hit"]])[3:]))
    X_WD = RigidTransform(rotation, translation)
    print(X_WD.translation())
    # visualize_frame('hihat',X_WD )

    # print(X_WD.translation())
    # visualize_frame('x', X_WD)
    if left_right_detector(hit) == "right":
        ik.AddPositionConstraint(right_frame, np.zeros(3), plant.world_frame(), X_WD.translation()-np.array([0.01,0.01,0]), X_WD.translation()+np.array([0.01,0.01,0]))
        # ik.AddMinimumDistanceLowerBoundConstraint(0.01)

        # additional_rotation = RollPitchYaw(np.radians(50), np.radians(20), 0).ToRotationMatrix()
        # combined_rotation_matrix = X_WD.rotation().multiply(additional_rotation)
        # ik.AddOrientationConstraint(right_frame, RotationMatrix(np.eye(3)), plant.world_frame(), combined_rotation_matrix, 0)
        # print(plant.world_frame().rotation())
    elif left_right_detector(hit) == "left":
        ik.AddPositionConstraint(left_frame, np.zeros(3), plant.world_frame(), X_WD.translation()-np.array([0.01,0.01,0]), X_WD.translation() +np.array([0.01,0.01,0]))
    


    
    lower_limits = plant.GetPositionLowerLimits()
    upper_limits = plant.GetPositionUpperLimits()
    lower_limits = np.where(lower_limits == float('-inf'), -np.pi, lower_limits)
    upper_limits = np.where(upper_limits == float('inf'), np.pi, upper_limits)
    max_tries = 100
    
    for count in range(max_tries):
        # Compute a random initial guess here
        
        for idx in np.ndindex(qs.shape):
            random_number = np.random.uniform(lower_limits[idx], upper_limits[idx])
            prog.SetInitialGuess( qs[idx], random_number)
        
        result = Solve(prog)

        if result.is_success():
            print("GOTCHA")
            visualize_frame('right_frame', right_frame.CalcPoseInWorld (plant_context))
            visualize_frame('left_frame', left_frame.CalcPoseInWorld (plant_context))
            q0 = result.GetSolution(qs)
            if left_right_detector(hit) == "right":
                if hit["hit"][-1] == "2":
                    x0 = np.hstack((q0, 0* q0))
                else:
                    plant.SetVelocities(plant_context,drum_stick_right,np.zeros(0,),)
                    x0 = np.hstack((q0, .001*hit['strength'] * q0))
            else:
                if hit["hit"][-1] == "2":
                    x0 = np.hstack((q0, 0* q0))
                else:                
                    plant.SetVelocities(plant_context,drum_stick_left,np.zeros(0,),)
                    x0 = np.hstack((q0, .001*hit['strength'] * q0))

            iiwa_controller.GetInputPort("desired_state").FixValue(
                iiwa_controller.GetMyMutableContextFromRoot(context), x0
            )           

            simulator.AdvanceTo(hit["total_time"])

            break
    if result.is_success()==False:
        print(f"NO GOOD: Cant reach {hit['hit']}")
    count+=1

simulator.AdvanceTo(5)
meshcat.StopRecording() 
meshcat.PublishRecording()

# q0 = np.array([-1.57, 0.1, 0, -1.2, 0, 1.6, 0, -1.57, 0.1, 0, -1.2, 0, 1.6, 0])
# plant.SetPositions(plant_context, q0)

ride2
[-0.25  1.1   0.73]
GOTCHA
ride
[-0.25  1.1   0.65]
GOTCHA
low_tom2
[0.1  0.9  0.38]
GOTCHA
low_tom
[0.1 0.9 0.3]
GOTCHA
snare2
[0.1  0.1  0.57]
GOTCHA
snare
[0.1  0.1  0.49]
GOTCHA
ride2
[-0.25  1.1   0.73]
GOTCHA
ride
[-0.25  1.1   0.65]
GOTCHA
ride2
[-0.25  1.1   0.73]
GOTCHA
ride
[-0.25  1.1   0.65]
GOTCHA
low_tom2
[0.1  0.9  0.38]
GOTCHA
low_tom
[0.1 0.9 0.3]
GOTCHA
snare2
[0.1  0.1  0.57]
GOTCHA
snare
[0.1  0.1  0.49]
GOTCHA
ride2
[-0.25  1.1   0.73]
GOTCHA
ride
[-0.25  1.1   0.65]
GOTCHA
ride2
[-0.25  1.1   0.73]
GOTCHA
ride
[-0.25  1.1   0.65]
GOTCHA
mid_tom2
[-0.2   0.3   0.68]
NO GOOD: Cant reach mid_tom2
mid_tom
[-0.2  0.3  0.6]
NO GOOD: Cant reach mid_tom
snare2
[0.1  0.1  0.57]
GOTCHA
snare
[0.1  0.1  0.49]
GOTCHA
ride2
[-0.25  1.1   0.73]
GOTCHA
ride
[-0.25  1.1   0.65]
GOTCHA
ride2
[-0.25  1.1   0.73]
GOTCHA
ride
[-0.25  1.1   0.65]
GOTCHA
hi_tom2
[-0.2   0.3   0.68]
NO GOOD: Cant reach hi_tom2
hi_tom
[-0.2  0.3  0.6]
GOTCHA
snare2
[0.1  0.1  0.57]
GOTCHA
snare
[0.