In [7]:
import os
import random

def get_random_midi_file(root_dir):
    midi_files = []
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:

            if filename.endswith('.mid') or filename.endswith('.midi'):
                midi_files.append(os.path.join(dirpath, filename))
    
    if not midi_files:
        raise FileNotFoundError(f"No MIDI files found in {root_dir}")
    
    return random.choice(midi_files)

# Usage
root_directory = './lmd_full'
random_midi_file = get_random_midi_file(root_directory)
print(f"Random MIDI file: {random_midi_file}")

filename = random_midi_file

Random MIDI file: ./lmd_full/d/d6ecfdd296cb4c4273b9f716fff4b844.mid


In [9]:
import pretty_midi

midi_data = pretty_midi.PrettyMIDI(filename)

In [10]:
midi_data.instruments[0].notes

[Note(start=3.000000, end=4.200000, pitch=64, velocity=32),
 Note(start=27.189008, end=27.387871, pitch=64, velocity=86),
 Note(start=27.478781, end=27.822531, pitch=64, velocity=90),
 Note(start=27.958895, end=28.095258, pitch=64, velocity=78),
 Note(start=28.177645, end=28.282759, pitch=64, velocity=83),
 Note(start=28.376509, end=28.609463, pitch=64, velocity=91),
 Note(start=28.708895, end=28.927645, pitch=65, velocity=86),
 Note(start=29.021396, end=29.231623, pitch=64, velocity=82),
 Note(start=29.402078, end=29.862305, pitch=64, velocity=88),
 Note(start=29.893555, end=30.314010, pitch=62, velocity=77),
 Note(start=31.365148, end=31.512875, pitch=60, velocity=67),
 Note(start=31.600943, end=31.782761, pitch=60, velocity=75),
 Note(start=31.944693, end=32.061171, pitch=60, velocity=86),
 Note(start=32.120830, end=32.237307, pitch=60, velocity=82),
 Note(start=32.339580, end=32.617989, pitch=60, velocity=92),
 Note(start=32.757194, end=32.950376, pitch=64, velocity=87),
 Note(star

In [14]:
midi_data = pretty_midi.PrettyMIDI(filename)

print("TPQN:", midi_data.resolution)

# Print the time per tick
time_per_tick = midi_data.tick_to_time(1)

print("Time per tick:", time_per_tick)

# time per measure


TPQN: 384
Time per tick: 0.0015625


In [11]:
import pretty_midi

# Load the MIDI file
midi_data = pretty_midi.PrettyMIDI(filename)

# Print the TPQN (Ticks Per Quarter Note)
print("TPQN:", midi_data.resolution)

# Print the time per tick
time_per_tick = midi_data.tick_to_time(1)
print("Time per tick:", time_per_tick)

# Get the first track of the MIDI file
first_track = midi_data.instruments[2]

# Get all the "note on" messages in the first track
note_on_messages = [note for note in first_track.notes if note.velocity > 0]

# Assuming 4/4 time signature
ticks_per_measure = 4 * midi_data.resolution

# Create a list to store the annotated notes
annotated_notes = []

# Iterate over each note on message
for i, note in enumerate(note_on_messages):
    # Get the measure index of the note
    tick = midi_data.time_to_tick(note.start)
    measure_index = tick // ticks_per_measure
    
    # Get the ticks since the last measure for the note
    ticks_since_last_measure = tick % ticks_per_measure
    
    # Get the ticks since the last note
    if i > 0:
        ticks_since_last_note = tick - note_on_messages[i-1].start
    else:
        ticks_since_last_note = tick
    
    # Create a dictionary to store the annotated note information
    annotated_note = {
        'note': note,
        'measure_index': measure_index,
        'ticks_since_last_measure': ticks_since_last_measure,
        'ticks_since_last_note': ticks_since_last_note
    }
    
    # Append the annotated note to the list
    annotated_notes.append(annotated_note)

# Print the first 10 annotated notes
for i, annotated_note in enumerate(annotated_notes):
    # if i >= 20:
    #     break
    print("Note:", annotated_note['note'])
    print("Measure Index:", annotated_note['measure_index'])
    print("Ticks since last measure:", annotated_note['ticks_since_last_measure'])
    print("Ticks since last note:", annotated_note['ticks_since_last_note'])
    print("---")

TPQN: 384
Time per tick: 0.0015625
Note: Note(start=5.181250, end=5.562500, pitch=60, velocity=56)
Measure Index: 2
Ticks since last measure: 244
Ticks since last note: 0
---
Note: Note(start=4.981250, end=5.759375, pitch=55, velocity=41)
Measure Index: 2
Ticks since last measure: 116
Ticks since last note: 3182.81875
---
Note: Note(start=4.800000, end=5.975000, pitch=52, velocity=74)
Measure Index: 2
Ticks since last measure: 0
Ticks since last note: 3067.01875
---
Note: Note(start=5.790625, end=6.162500, pitch=55, velocity=49)
Measure Index: 2
Ticks since last measure: 634
Ticks since last note: 3701.2
---
Note: Note(start=5.575000, end=6.340625, pitch=60, velocity=49)
Measure Index: 2
Ticks since last measure: 496
Ticks since last note: 3562.209375
---
Note: Note(start=5.378125, end=6.540625, pitch=64, velocity=97)
Measure Index: 2
Ticks since last measure: 370
Ticks since last note: 3436.425
---
Note: Note(start=6.381250, end=6.765625, pitch=60, velocity=56)
Measure Index: 2
Ticks 

# Loading Data

In [26]:
import pretty_midi

def process_midi(filename, track_index=0):
    # Load the MIDI file
    midi_data = pretty_midi.PrettyMIDI(filename)

    # Print the TPQN (Ticks Per Quarter Note)
    print("TPQN:", midi_data.resolution)

    # Print the time per tick
    time_per_tick = midi_data.tick_to_time(1)
    print("Time per tick:", time_per_tick)

    # Get the specified track of the MIDI file
    track = midi_data.instruments[track_index]

    # Get all the "note on" messages in the track
    note_on_messages = [note for note in track.notes if note.velocity > 0]

    # Assuming 4/4 time signature
    ticks_per_measure = 4 * midi_data.resolution

    # Create a list to store the annotated notes
    annotated_notes = []

    # Iterate over each note on message
    for i, note in enumerate(note_on_messages):
        # Get the measure index of the note
        tick = midi_data.time_to_tick(note.start)
        measure_index = tick // ticks_per_measure

        # Get the ticks since the last measure for the note
        ticks_since_last_measure = tick % ticks_per_measure

        # Get the ticks since the last note
        if i > 0:
            ticks_since_last_note = tick - midi_data.time_to_tick(note_on_messages[i-1].start)
        else:
            ticks_since_last_note = tick

        # Create a dictionary to store the annotated note information
        annotated_note = {
            'note': note,
            'measure_index': measure_index,
            'ticks_since_last_measure': ticks_since_last_measure,
            'ticks_since_last_note': ticks_since_last_note
        }

        # Append the annotated note to the list
        annotated_notes.append(annotated_note)

    # Print the first 10 annotated notes
    # for i, annotated_note in enumerate(annotated_notes[:10]):
    #     print("Note:", annotated_note['note'])
    #     print("Measure Index:", annotated_note['measure_index'])
    #     print("Ticks since last measure:", annotated_note['ticks_since_last_measure'])
    #     print("Ticks since last note:", annotated_note['ticks_since_last_note'])
    #     print("---")

    return annotated_notes

# Usage
filename = random_midi_file
track_index = 0

annotated_notes = process_midi(filename, track_index)
print(len(annotated_notes))

TPQN: 384
Time per tick: 0.0015625
203


In [18]:
import torch
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = {'data': self.data[idx], 'label': self.labels[idx]}
        return sample

# Generate some random data
data = torch.randn(100, 3, 32, 32)  # 100 samples, 3 channels, 32x32 images
labels = torch.randint(0, 10, (100,))  # 100 labels for 10 classes

# Create dataset and dataloader
dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# Iterate through the dataloader
for batch in dataloader:
    data_batch = batch['data']
    labels_batch = batch['label']
    print(data_batch.shape, labels_batch.shape)
    # Add your training code here


torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size(

In [None]:
root_directory = './lmd_full'
midi_files = []
for dirpath, _, filenames in os.walk(root_dir):
    for filename in filenames:

        if filename.endswith('.mid') or filename.endswith('.midi'):
            midi_files.append(
                pretty_midi.PrettyMIDI(os.path.join(dirpath, filename))
            )


class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, instance_idx):  
        sample = {'data': self.data[idx], 'label': self.labels[idx]}
        return sample

# Generate some random data
data = torch.randn(100, 3, 32, 32)  # 100 samples, 3 channels, 32x32 images
labels = torch.randint(0, 10, (100,))  # 100 labels for 10 classes

# Create dataset and dataloader
dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# Iterate through the dataloader
for batch in dataloader:
    data_batch = batch['data']
    labels_batch = batch['label']
    print(data_batch.shape, labels_batch.shape)
    # Add your training code here



In [28]:
import sys
from collections import deque

def get_recursive_size(obj, seen=None):
    """Recursively finds the total memory usage of an object."""
    if seen is None:
        seen = set()
    
    obj_id = id(obj)
    
    if obj_id in seen:
        return 0
    
    seen.add(obj_id)
    
    size = sys.getsizeof(obj)
    
    if isinstance(obj, dict):
        size += sum(get_recursive_size(v, seen) for v in obj.values())
        size += sum(get_recursive_size(k, seen) for k in obj.keys())
    elif hasattr(obj, '__dict__'):
        size += get_recursive_size(obj.__dict__, seen)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        size += sum(get_recursive_size(i, seen) for i in obj)
    
    return size

# Example usage
example_object = [1, 2, {3: "a", 4: ["b", "c"]}]
print(get_recursive_size(pretty_midi.PrettyMIDI(random_midi_file)))

print(get_recursive_size(process_midi(random_midi_file)))

# get the number of instruments
midi_data.instruments


4972757
TPQN: 384
Time per tick: 0.0015625
103823


[Instrument(program=57, is_drum=False, name="Melodie 4"),
 Instrument(program=42, is_drum=False, name="Violoncl2"),
 Instrument(program=24, is_drum=False, name="GuitarAc3"),
 Instrument(program=24, is_drum=False, name="GuitarAc5")]