In [87]:
import tensorflow as tf
import tensorflow_datasets as tfds
import pretty_midi as pm
import numpy as np
import io
import matplotlib.pyplot as plt
import sys

In [110]:
# tfds works in both Eager and Graph modes.
# Raise NumPy's summarization threshold so arrays print in full instead of
# being elided with "..." (NOTE: this can produce very large cell outputs).
np.set_printoptions(threshold=sys.maxsize)
# Show whether eager execution is active; the notebook later calls `.numpy()`
# on tensors, which needs eager mode.
tf.executing_eagerly()

True

In [89]:
# Load the full GMD with MIDI only (no audio) as a tf.data.Dataset.
# Only the TRAIN split is requested. `try_gcs=True` lets TFDS look for an
# already-prepared copy in its public GCS bucket before preparing the
# dataset locally.
dataset = tfds.load(
    name="groove/full-midionly",
    split=tfds.Split.TRAIN,
    try_gcs=True)



In [90]:
# Input pipeline: shuffle with a 1024-element buffer, group into batches of
# 64 examples, and let tf.data pick the prefetch depth automatically.
# NOTE: this rebinds `dataset`, so re-running the cell stacks the stages again.
dataset = (
    dataset
    .shuffle(1024)
    .batch(64)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [91]:

# Pull a single batch and keep the raw MIDI bytes and the primary style label
# of that batch around for the cells below.
for features in dataset.take(1):
    midi = features["midi"]
    genre = features["style"]["primary"]

In [92]:
# Show the first element of the `midi` batch: a scalar string tensor holding
# the raw bytes of one MIDI file (the output is long).
print(midi[0])

tf.Tensor(b'MThd\x00\x00\x00\x06\x00\x00\x00\x01\x01\xe0MTrk\x00\x00\x01G\x00\xffQ\x03\x0b\xbc\xce\x00\xffX\x04\x04\x02\x18\x08\x00\xffY\x02\x00\x00\x00\xc9\x00\x00\xb9\x04Z\x00\xc9\x00\x15\xb9\x04Z\x01\x99*\x18:\xb9\x04X\x06\x99*\x004\xb9\x04V;\x04T\x03\x04S\x0f\x04G\x0e\x04:\x04\x046\x06\x99\x1ac\t\xb9\x04/\x0e\x04(\n\x04#\x1d\x04 \x01\x99\x1a\x00\x1d\xb9\x04\x1c\x07\x04\x1a\x1e\x04\x18\r\x04\x16\x0f\x04\x1d\x0e\x04%\x0f\x04-\x05\x040\x0f\x04B\x04\x99,\x16\n\xb9\x04U\x04\x04Z\x08\x99$?\x04*+%,\x00\x17$\x00\x03*\x00\x81\x16\xb9\x04Z\x12\x99\x16F)\xb9\x04X\x17\x99\x16\x00#\xb9\x04V;\x04S:\x04Q\x08\x99%S\x15(k\x1e\xb9\x04N\x0c\x99%\x00\x15(\x00\x19\xb9\x04L;\x04I:\x04G;\x04D:\x04B\x1f\x990"\x1c\xb9\x04@\x0f\x99+A\x130\x00\x18\xb9\x04=\x13\x99$C\x00+\x00(\xb9\x04;\x17\x99$\x00#\xb9\x048;\x046 \x044\x0f\x04.\x0e\x04\'\x03\x99(p\t\x1af\x01\xb9\x04!\x0f\x04\x1d\x0e\x04\x18\x05\x04\x16\x13\x99(\x00\t\x1a\x00\x01\xb9\x04\x14\x13\x04\x12\x1d\x04\x0f\x0c\x04\r\x0e\x04\x11\x0f\x04\x15\x0f\x04\x1

In [102]:
def convert_to_sequence(midi_tensor):
    """Convert a serialized MIDI file into a binary piano roll at tick resolution.

    Args:
        midi_tensor: Either an object exposing ``.numpy()`` that returns the raw
            MIDI file bytes (e.g. an eager scalar string tensor), or the raw
            ``bytes`` themselves.

    Returns:
        A float64 ``np.ndarray`` of shape ``(total_ticks, 128)``. Entry
        ``[t, p]`` is 1 while a note with MIDI pitch ``p`` is sounding at tick
        ``t`` and 0 otherwise. Notes from every instrument are merged into the
        same roll. ``total_ticks`` is the tick of the file's end time plus one.
    """
    num_pitches = 128  # MIDI pitch numbers span 0..127

    # Accept an eager tensor as before, and also plain bytes for convenience.
    if hasattr(midi_tensor, "numpy"):
        midi_bytes = midi_tensor.numpy()
    else:
        midi_bytes = midi_tensor
    midi_data = pm.PrettyMIDI(io.BytesIO(midi_bytes))

    # Use the library's own seconds->tick mapping rather than dividing by the
    # duration of tick 1: it follows the file's tempo map and snaps to the
    # nearest tick, so floating-point quotients such as 0.3 / 0.1 == 2.999...
    # no longer truncate to the previous tick.
    end_tick = int(midi_data.time_to_tick(midi_data.get_end_time()))
    note_sequence = np.zeros((end_tick + 1, num_pitches))

    for instrument in midi_data.instruments:
        for note in instrument.notes:
            tick_start = int(midi_data.time_to_tick(note.start))
            # Guarantee at least one active tick so that very short or
            # zero-length hits are not silently dropped by an empty slice.
            tick_end = max(int(midi_data.time_to_tick(note.end)), tick_start + 1)
            note_sequence[tick_start:tick_end, note.pitch] = 1

    return note_sequence

In [111]:
# Convert the first MIDI file of the batch into a roll and check its
# dimensions; the second axis is the 128 pitch columns allocated by
# convert_to_sequence.
roll = convert_to_sequence(midi[0])

print(roll.shape)

(1928, 128)
