Bugfix in event decoder (?) #69

Open: wants to merge 1 commit into main


@sapieneptus sapieneptus commented Aug 1, 2022

Background

While running this decoder on a sequence of tokens (MIDI file -> NoteSequence -> event_tokens), I realised that after each note is processed and we return to processing shifts, cur_time is reset (since start_time is 0). This then causes note_sequences#decode_note_event to raise an exception:

if time < state.current_time:
    raise ValueError('event time < current time, %f < %f' % (
        time, state.current_time))

I believe cur_time should be offset from state.current_time, not from start_time; this allows the decoder to pick up where it left off along the timeline.
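To illustrate the failure mode described above, here is a toy sketch. It is not the MT3 implementation: `ToyState`, `toy_decode_note`, `toy_decode`, the `("shift", steps)` / `("note", pitch)` token format and the 100 steps-per-second resolution are all assumptions made for the example. It assumes each shift is relative to the previous note, which is the situation this description is about.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

STEPS_PER_SECOND = 100  # assumed resolution for this toy example

# Hypothetical token format: ("shift", steps) or ("note", pitch).
Token = Tuple[str, int]


@dataclass
class ToyState:
    current_time: float = 0.0
    notes: List[Tuple[float, int]] = field(default_factory=list)


def toy_decode_note(state: ToyState, time: float, pitch: int) -> None:
    # Same kind of check as the one quoted above: time must not move backwards.
    if time < state.current_time:
        raise ValueError('event time < current time, %f < %f' % (
            time, state.current_time))
    state.current_time = time
    state.notes.append((time, pitch))


def toy_decode(tokens: List[Token], start_time: float,
               offset_from_state: bool) -> ToyState:
    state = ToyState(current_time=start_time)
    cur_steps = 0
    cur_time = start_time
    for kind, value in tokens:
        if kind == "shift":
            cur_steps += value
            # The only difference between the two variants is the base time.
            base = state.current_time if offset_from_state else start_time
            cur_time = base + cur_steps / STEPS_PER_SECOND
        else:
            cur_steps = 0  # shifts start accumulating again after each note
            toy_decode_note(state, cur_time, value)
    return state


# Two notes, each preceded by a shift relative to the previous note.
tokens = [("shift", 50), ("note", 60), ("shift", 25), ("note", 62)]

fixed = toy_decode(tokens, start_time=0.0, offset_from_state=True)
print(fixed.notes)

try:
    toy_decode(tokens, start_time=0.0, offset_from_state=False)
except ValueError as err:
    print("failed:", err)
```

With the base taken from `start_time`, the second shift places the second note before the first, and the time check raises. With the base taken from the state, the second note lands after the first.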


Code example:

# `ns` (the input NoteSequence) and `audio_times` (audio frame times) are defined elsewhere.
subsequences = note_seq.split_note_sequence(ns, 1)
event_batches = []
for subseq in subsequences:
    subseq = note_seq.apply_sustain_control_changes(subseq)
    midi_times, midi_events = midi.note_sequence_to_events(subseq)
    del subseq.control_changes[:]

    events, _, _, _, _ = midi.encode_midi_events(audio_times, midi_times, midi_events)
    event_batches.append(events)

reconstructed = midi.event_batches_to_note_sequence(event_batches, codec=utils.CODEC)

midi.note_sequence_to_midi_file(reconstructed, 'moo.mid')

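# midi.py

from typing import Sequence, Tuple

import note_seq
from mt3 import event_codec, note_sequences, run_length_encoding

import utils  # assumption: the author's own module providing CODEC (not shown here)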

def midi_file_to_note_sequence(midi_path) -> note_seq.NoteSequence:
    """
    Load a MIDI file as a NoteSequence
    """
    print(f"Converting midi file to note sequence: {midi_path}")
    ns = note_seq.midi_file_to_note_sequence(midi_path)
    return ns

def note_sequence_to_events(ns: note_seq.NoteSequence) -> Tuple[Sequence[float], Sequence[note_sequences.NoteEventData]]:
    return note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns)

def event_batches_to_note_sequence(event_batches, codec: event_codec.Codec=utils.CODEC) -> note_seq.NoteSequence:
    print("converting event batches to note sequence")
    decoding_state = note_sequences.NoteDecodingState()
    total_invalid_ids = 0
    total_dropped_events = 0

    for events in event_batches:
        invalid_ids, dropped_events = run_length_encoding.decode_events(
            state=decoding_state,
            tokens=events,
            start_time=decoding_state.current_time,
            max_time=None,
            codec=codec,
            decode_event_fn=note_sequences.decode_note_event
        )
        total_invalid_ids += invalid_ids
        total_dropped_events += dropped_events
        
    ns = note_sequences.flush_note_decoding_state(decoding_state)
    
    print(f'Dropped {total_dropped_events} events')
    print(f'Invalid ids: {total_invalid_ids}')
    return ns

def note_sequence_to_midi_file(ns: note_seq.NoteSequence, midi_path: str):
    """
    Write a NoteSequence to a MIDI file
    """
    print(f"Converting events to midi file: {midi_path}")

    return note_seq.midi_io.note_sequence_to_midi_file(ns, midi_path)

def encode_midi_events(
    audio_frame_times: Sequence[float],
    midi_event_times: Sequence[float],
    midi_event_values: Sequence[note_sequences.NoteEventData]
) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:

    events, event_start_indices, event_end_indices, state_events, state_event_indices = run_length_encoding.encode_and_index_events(
        state=note_sequences.NoteEncodingState(),
        event_times=midi_event_times,
        event_values=midi_event_values,
        encode_event_fn=note_sequences.note_event_data_to_events,
        codec=utils.CODEC,
        frame_times=audio_frame_times,
        encoding_state_to_events_fn=note_sequences.note_encoding_state_to_events
    )
    return events, event_start_indices, event_end_indices, state_events, state_event_indices

Use state.current_time after finishing decoding an event

google-cla bot commented Aug 1, 2022

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@sapieneptus
Author

I've signed the CLA, can that build be rerun?

@cghawthorne

Can you provide a little more information on how you're using the decoder? I'm guessing it's for some kind of custom setup.

The way we're using it (as illustrated in metrics_utils.event_predictions_to_ns), the ground truth for the current time offset comes from the start_time passed into decode_events. That's then used to set state.current_time. We do this because we're decoding independently-inferred chunks of the full sequence.

@sapieneptus
Author

sapieneptus commented Aug 2, 2022

> We do this because we're decoding independently-inferred chunks of the full sequence.

Right, I had suspected as much. I was just trying to understand how the MT3 + note-seq libraries work and wrote some code to split a MIDI file into subsequences, just to see if I could then reconstruct the original MIDI file. So my input would be a sequence of events corresponding to the entire MIDI file (several minutes' worth of events).

I can see how this function would work as-is for a small slice containing only a single note event + some shift events, but it would fail if it encounters multiple note events (unless there are increasingly more shifts between subsequent note events).

So perhaps it's by design - but I believe this change is still an improvement, as it should not change functionality for the 'small slice' use-case and will prevent errors in a longer slice use-case.

I have updated the description with my relevant code.

@cghawthorne

I think it still doesn't work for our case because state.current_time at the end of one chunk isn't necessarily the right start time of the subsequent chunk. For example, what if there are several chunks in a row with no note events? The start time of the chunk needs to come from some external source.
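A toy illustration of that concern, not MT3 code: `decode_chunks`, the token format, the fixed 1-second chunk length and the 100 steps-per-second resolution are all assumptions for the example. Shifts here are measured from the start of each chunk, and the chunk start is taken either from an external source (the chunk index) or from the time of the last decoded note.

```python
from typing import List, Tuple

STEPS_PER_SECOND = 100  # assumed resolution for this toy example
CHUNK_SECONDS = 1.0     # assumed fixed chunk length

# Hypothetical token format: ("shift", steps) or ("note", pitch).
Chunk = List[Tuple[str, int]]


def decode_chunks(chunks: List[Chunk],
                  use_external_start: bool) -> List[Tuple[float, int]]:
    notes = []
    current_time = 0.0  # time of the last decoded note
    for index, chunk in enumerate(chunks):
        if use_external_start:
            # Start time supplied from outside the decoder state.
            start_time = index * CHUNK_SECONDS
        else:
            # Start time carried over from the previous chunk's last note.
            start_time = current_time
        cur_steps = 0
        for kind, value in chunk:
            if kind == "shift":
                cur_steps += value
            else:
                current_time = start_time + cur_steps / STEPS_PER_SECOND
                notes.append((current_time, value))
    return notes


# Chunk 0 has a note at 0.5 s, chunks 1 and 2 are empty,
# and chunk 3 has a note 0.25 s after its own start.
chunks = [
    [("shift", 50), ("note", 60)],
    [],
    [],
    [("shift", 25), ("note", 62)],
]

print(decode_chunks(chunks, use_external_start=True))
print(decode_chunks(chunks, use_external_start=False))
```

With the external start time, the second note is placed in the fourth chunk. When the start is carried over from the decoder state, the two empty chunks contribute no time and the note is placed far too early.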
