Skip to content

Commit

Permalink
Add trim_silence to MidiFile and deprecate the one in NoteTrajectory.
Browse files Browse the repository at this point in the history
Also add more unit tests.
  • Loading branch information
kevinzakka committed Aug 30, 2023
1 parent fa4a0cd commit df08443
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 13 deletions.
4 changes: 3 additions & 1 deletion examples/play_midi_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@


def main(_) -> None:
music.load(_FILE.value, stretch=_STRETCH.value, shift=_SHIFT.value).play()
music.load(
_FILE.value, stretch=_STRETCH.value, shift=_SHIFT.value
).trim_silence().play()


if __name__ == "__main__":
Expand Down
13 changes: 13 additions & 0 deletions robopianist/music/midi_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ def transpose(self, amount: int, transpose_chords: bool = True) -> "MidiFile":
)
return MidiFile(seq=seq)

def trim_silence(self) -> "MidiFile":
seq = sequences_lib.extract_subsequence(
sequence=self.seq,
start_time=self.seq.notes[0].start_time,
end_time=self.seq.notes[-1].end_time,
)
return MidiFile(seq=seq)

def synthesize(self, sampling_rate: int = consts.SAMPLING_RATE) -> np.ndarray:
"""Synthesize the MIDI file into a waveform using FluidSynth."""
return midi_synth.fluidsynth(
Expand Down Expand Up @@ -361,6 +369,11 @@ def trim_silence(self) -> "NoteTrajectory":
This method modifies the note trajectory in place.
"""
print(
"WARNING: NoteTrajectory.trim_silence is deprecated. "
"Trim the silence at the MIDI level instead."
)

# Continue removing from the front until we find a non-empty timestep.
while len(self.notes) > 0 and len(self.notes[0]) == 0:
self.notes.pop(0)
Expand Down
5 changes: 5 additions & 0 deletions robopianist/music/midi_file_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def test_transpose_no_op(self) -> None:
transposed_midi = midi.transpose(0)
self.assertProtoEquals(transposed_midi.seq, midi.seq)

def test_trim_silence(self) -> None:
midi = music.load("TwinkleTwinkleRousseau")
midi_trimmed = midi.trim_silence()
self.assertEqual(midi_trimmed.seq.notes[0].start_time, 0.0)


class PianoNoteTest(absltest.TestCase):
def test_constructor(self) -> None:
Expand Down
12 changes: 12 additions & 0 deletions robopianist/music/music_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ def test_midis_in_library(self, midi_name: str) -> None:
"""Test that all midis in the library can be loaded."""
self.assertIsInstance(music.load(midi_name), midi_file.MidiFile)

@parameterized.parameters(*music.ALL)
def test_fingering_available_for_all_timesteps(self, midi_name: str) -> None:
"""Test that all midis in the library have fingering annotations for all
timesteps."""
midi = music.load(midi_name).trim_silence()
traj = midi_file.NoteTrajectory.from_midi(midi, dt=0.05)
for timestep in traj.notes:
for note in timestep:
# -1 indicates no fingering annotation. Valid fingering lies in [0, 9].
self.assertGreater(note.fingering, -1)
self.assertLess(note.fingering, 10)


if __name__ == "__main__":
absltest.main()
8 changes: 4 additions & 4 deletions robopianist/suite/tasks/piano_with_one_shadow_hand.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def __init__(
goal state.
n_seconds_lookahead: Number of seconds to look ahead when computing the
goal state. If specified, this will override `n_steps_lookahead`.
trim_silence: If True, remove initial and final timesteps without any notes.
trim_silence: If True, shifts the MIDI file so that the first note starts
at time 0.
wrong_press_termination: If True, terminates the episode if the hands press
the wrong keys at any timestep.
initial_buffer_time: Specifies the duration of silence in seconds to add to
Expand All @@ -78,13 +79,14 @@ def __init__(
"""
super().__init__(arena=stage.Stage(), **kwargs)

if trim_silence:
midi = midi.trim_silence()
self._midi = midi
self._n_steps_lookahead = n_steps_lookahead
if n_seconds_lookahead is not None:
self._n_steps_lookahead = int(
np.ceil(n_seconds_lookahead / self.control_timestep)
)
self._trim_silence = trim_silence
self._initial_buffer_time = initial_buffer_time
self._disable_fingering_reward = disable_fingering_reward
self._wrong_press_termination = wrong_press_termination
Expand Down Expand Up @@ -129,8 +131,6 @@ def _maybe_change_midi(self, random_state) -> None:

def _reset_trajectory(self, midi: midi_file.MidiFile) -> None:
note_traj = midi_file.NoteTrajectory.from_midi(midi, self.control_timestep)
if self._trim_silence:
note_traj.trim_silence()
note_traj.add_initial_buffer_time(self._initial_buffer_time)
self._notes = note_traj.notes
self._sustains = note_traj.sustains
Expand Down
8 changes: 4 additions & 4 deletions robopianist/suite/tasks/piano_with_shadow_hands.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def __init__(
goal state.
n_seconds_lookahead: Number of seconds to look ahead when computing the
goal state. If specified, this will override `n_steps_lookahead`.
trim_silence: If True, remove initial and final timesteps without any notes.
trim_silence: If True, shifts the MIDI file so that the first note starts
at time 0.
wrong_press_termination: If True, terminates the episode if the hands press
the wrong keys at any timestep.
initial_buffer_time: Specifies the duration of silence in seconds to add to
Expand All @@ -95,14 +96,15 @@ def __init__(
"""
super().__init__(arena=stage.Stage(), **kwargs)

if trim_silence:
midi = midi.trim_silence()
self._midi = midi
self._initial_midi = midi
self._n_steps_lookahead = n_steps_lookahead
if n_seconds_lookahead is not None:
self._n_steps_lookahead = int(
np.ceil(n_seconds_lookahead / self.control_timestep)
)
self._trim_silence = trim_silence
self._initial_buffer_time = initial_buffer_time
self._disable_fingering_reward = (
disable_fingering_reward or not self._midi.has_fingering()
Expand Down Expand Up @@ -152,8 +154,6 @@ def _reset_trajectory(self) -> None:
note_traj = midi_file.NoteTrajectory.from_midi(
self._midi, self.control_timestep
)
if self._trim_silence:
note_traj.trim_silence()
note_traj.add_initial_buffer_time(self._initial_buffer_time)
self._notes = note_traj.notes
self._sustains = note_traj.sustains
Expand Down
8 changes: 4 additions & 4 deletions robopianist/suite/tasks/self_actuated_piano.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,20 @@ def __init__(
midi: A `MidiFile` object.
n_steps_lookahead: Number of timesteps to look ahead when computing the
goal state.
trim_silence: If True, remove initial and final timesteps without any notes.
trim_silence: If True, shifts the MIDI file so that the first note starts
at time 0.
reward_type: Reward function to use for the key press reward.
augmentations: A list of `Variation` objects that will be applied to the
MIDI file at the beginning of each episode. If None, no augmentations
will be applied.
"""
super().__init__(arena=stage.Stage(), add_piano_actuators=True, **kwargs)

if trim_silence:
midi = midi.trim_silence()
self._midi = midi
self._initial_midi = midi
self._n_steps_lookahead = n_steps_lookahead
self._trim_silence = trim_silence
self._key_press_reward = reward_type.get()
self._reward_fn = composite_reward.CompositeReward(
key_press_reward=self._compute_key_press_reward,
Expand All @@ -126,8 +128,6 @@ def _reset_trajectory(self) -> None:
note_traj = midi_file.NoteTrajectory.from_midi(
self._midi, self.control_timestep
)
if self._trim_silence:
note_traj.trim_silence()
self._notes = note_traj.notes
self._sustains = note_traj.sustains

Expand Down

0 comments on commit df08443

Please sign in to comment.