# Efficient Storage

In [1]:
import os
import random
import pretty_midi

def get_random_midi_file(root_dir):
    midi_files = []
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:

            if filename.endswith('.mid') or filename.endswith('.midi'):
                midi_files.append(os.path.join(dirpath, filename))
    
    if not midi_files:
        raise FileNotFoundError(f"No MIDI files found in {root_dir}")
    
    return random.choice(midi_files)

# Usage
root_directory = './lmd_full'
random_midi_file = get_random_midi_file(root_directory)
print(f"Random MIDI file: {random_midi_file}")

filename = random_midi_file

def get_random_pretty_midi_file():
    filename = get_random_midi_file(root_directory)
    return pretty_midi.PrettyMIDI(filename)

Random MIDI file: ./lmd_full/6/60f4f7f37aa4dae34d541673cfc956ff.mid


In [2]:
import torch
def process_midi_data(midi_data):
    song = []
    time_per_quarter_note = midi_data.tick_to_time(midi_data.resolution)
    for i, instrument in enumerate(midi_data.instruments):
        for note in instrument.notes:
            track = i
            start = note.start
            duration = note.end - note.start
            pitch = note.pitch

            notedata = [track, start, duration, pitch]
            song.append(notedata)
        # sort by start time
        song.sort(key=lambda x: x[1])
        # prepend the time_per_quarter_note
    song.insert(0, [time_per_quarter_note, 0, 0, 0])
    return torch.tensor(song, dtype=torch.float32)

song_lengths = []
songs = []
for i in range(100):
    try:
        midi_data = get_random_pretty_midi_file()
    except Exception as e:
        print(f"Error reading file: {e}")
        continue

    song = process_midi_data(midi_data)
    songs.append(song)

def get_memory_used(tensor):
    return tensor.element_size() * tensor.nelement()

total_memory = 0
for song in songs:
    total_memory += get_memory_used(song)

print(f"Total memory used: {total_memory} bytes")
# in MB
print(f"Total memory used: {total_memory / 1024 / 1024} MB")



Error reading file: data byte must be in range 0..127
Error reading file: data byte must be in range 0..127
Total memory used: 5545296 bytes
Total memory used: 5.2884063720703125 MB


# Loading Data (old)

In [3]:

        

def process_midi(filename, track_index=0):
    # Load the MIDI file
    midi_data = pretty_midi.PrettyMIDI(filename)

    # Print the TPQN (Ticks Per Quarter Note)
    print("TPQN:", midi_data.resolution)

    # Print the time per tick
    time_per_tick = midi_data.tick_to_time(1)
    print("Time per tick:", time_per_tick)

    # Get the specified track of the MIDI file
    track = midi_data.instruments[track_index]

    # Get all the "note on" messages in the track
    note_on_messages = [note for note in track.notes if note.velocity > 0]

    # Assuming 4/4 time signature
    ticks_per_measure = 4 * midi_data.resolution

    # Create a list to store the annotated notes
    annotated_notes = []

    # Iterate over each note on message
    for i, note in enumerate(note_on_messages):
        # Get the measure index of the note
        tick = midi_data.time_to_tick(note.start)
        measure_index = tick // ticks_per_measure

        # Get the ticks since the last measure for the note
        ticks_since_last_measure = tick % ticks_per_measure

        # Get the ticks since the last note
        if i > 0:
            ticks_since_last_note = tick - midi_data.time_to_tick(note_on_messages[i-1].start)
        else:
            ticks_since_last_note = tick

        # Create a dictionary to store the annotated note information
        annotated_note = {
            'note': note,
            'measure_index': measure_index,
            'ticks_since_last_measure': ticks_since_last_measure,
            'ticks_since_last_note': ticks_since_last_note
        }

        # Append the annotated note to the list
        annotated_notes.append(annotated_note)

    # Print the first 10 annotated notes
    # for i, annotated_note in enumerate(annotated_notes[:10]):
    #     print("Note:", annotated_note['note'])
    #     print("Measure Index:", annotated_note['measure_index'])
    #     print("Ticks since last measure:", annotated_note['ticks_since_last_measure'])
    #     print("Ticks since last note:", annotated_note['ticks_since_last_note'])
    #     print("---")

    return annotated_notes

# Usage
filename = random_midi_file
track_index = 0

annotated_notes = process_midi(filename, track_index)
print(len(annotated_notes))

TPQN: 96
Time per tick: 0.005208333333333333
12


In [4]:
import torch
import pretty_midi

# class for a song

class Song:
    def __init__(self, filename):
        midi_data = pretty_midi.PrettyMIDI(filename)
        self.time_per_measure = midi_data.tick_to_time(midi_data.resolution * 4)
        self.tracks = []
        for track in midi_data.instruments:
            track_data = []
            for note in track.notes:
                start_tick = note.start
                pitch = note.pitch
                track_data.append([start_tick, pitch])
            # add the track data to the list of tracks in a torch tensor
            self.tracks.append(torch.tensor(track_data, dtype=torch.float32))

In [5]:
def process_midi_file(filename):
    try:
        song = Song(filename)
        return song, None
    except Exception as e:
        return None, e

def collect_midi_files(root_directory):
    midi_files = []
    for dirpath, _, filenames in os.walk(root_directory):
        for filename in filenames:
            if filename.endswith('.mid') or filename.endswith('.midi'):
                midi_files.append(os.path.join(dirpath, filename))
    return midi_files

if __name__ == "__main__":
    root_directory = './lmd_full'
    midi_files = collect_midi_files(root_directory)

    songs = []
    errors_processing = 0

    with Pool(cpu_count()) as pool:
        results = list(tqdm(pool.imap(process_midi_file, midi_files), total=len(midi_files)))

    for song, error in results:
        if song:
            songs.append(song)
        if error:
            errors_processing += 1

    print(f"Errors processing: {errors_processing}")


NameError: name 'Pool' is not defined

In [None]:
import pretty_midi
import os
root_directory = './lmd_full'
midi_files = []
for dirpath, _, filenames in os.walk(root_directory):
    for filename in filenames:

        if filename.endswith('.mid') or filename.endswith('.midi'):
            midi_files.append(
                os.path.join(dirpath, filename)
            )

songs_to_count = 100
print(f"Total number of MIDI files: {len(midi_files)}")
errors_processing = 0
total_number_of_notes = 0
for i, midi_data in enumerate(midi_files):
    try:
        # get the total number of notes
        notes_per_song = 0
        midi_file = pretty_midi.PrettyMIDI(midi_data)
        for track in midi_file.instruments:
            notes_per_song += len(track.notes)
        # print(f"notes per song: {notes_per_song}")
        total_number_of_notes += notes_per_song
    except Exception as e:
        print(e)
        errors_processing += 1
    if i > songs_to_count:
        break

    
print(f"Total number of notes: {total_number_of_notes}")
notes_per_file = total_number_of_notes / songs_to_count
print(f"notes per file: {notes_per_file}")

Total number of MIDI files: 178561
MThd not found. Probably not a MIDI file




Total number of notes: 380165
notes per file: 3801.65


In [None]:
from tqdm import tqdm

songs = []
errors_processing = 0
for midi_data in tqdm(midi_files):
    try:
        song = Song(midi_data)
        songs.append(song)
    except Exception as e:
        errors_processing += 1

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = {'data': self.data[idx], 'label': self.labels[idx]}
        return sample

# Generate some random data
data = torch.randn(100, 3, 32, 32)  # 100 samples, 3 channels, 32x32 images
labels = torch.randint(0, 10, (100,))  # 100 labels for 10 classes

# Create dataset and dataloader
dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# Iterate through the dataloader
for batch in dataloader:
    data_batch = batch['data']
    labels_batch = batch['label']
    print(data_batch.shape, labels_batch.shape)
    # Add your training code here


torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size([4, 3, 32, 32]) torch.Size([4])
torch.Size(

In [None]:


class SongData(Dataset):
    def __init__(self, songs):
        self.songs = songs
        self.global_track_id_to_song_id = []
        for i, song in enumerate(songs):
            for track in song.instruments:
                self.global_track_id_to_song_id.append(i)

    def __len__(self):
        return len(self.global_track_id_to_song_id)

    def __getitem__(self, instance_idx):  
        song_idx = self.global_track_id_to_song_id[instance_idx]
        song = self.songs[song_idx]
        track_idx = instance_idx - song_idx
        sample = {'data': self.data[idx], 'label': self.labels[idx]}
        return sample

# Generate some random data
data = torch.randn(100, 3, 32, 32)  # 100 samples, 3 channels, 32x32 images
labels = torch.randint(0, 10, (100,))  # 100 labels for 10 classes

# Create dataset and dataloader
dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# Iterate through the dataloader
for batch in dataloader:
    data_batch = batch['data']
    labels_batch = batch['label']
    print(data_batch.shape, labels_batch.shape)
    # Add your training code here



In [None]:
import sys
from collections import deque

def get_recursive_size(obj, seen=None):
    """Recursively finds the total memory usage of an object."""
    if seen is None:
        seen = set()
    
    obj_id = id(obj)
    
    if obj_id in seen:
        return 0
    
    seen.add(obj_id)
    
    size = sys.getsizeof(obj)
    
    if isinstance(obj, dict):
        size += sum(get_recursive_size(v, seen) for v in obj.values())
        size += sum(get_recursive_size(k, seen) for k in obj.keys())
    elif hasattr(obj, '__dict__'):
        size += get_recursive_size(obj.__dict__, seen)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        size += sum(get_recursive_size(i, seen) for i in obj)
    
    return size

# Example usage
example_object = [1, 2, {3: "a", 4: ["b", "c"]}]
object_1 = [([1.1]*500) for _ in range(1000)]
object_2 = [torch.tensor([1.1]*500) for _ in range(1000)]

print("python")
print(get_recursive_size(object_1))
print("tensors")
print(get_recursive_size(object_2))

# # get the number of instruments
# midi_data.instruments


4972757
TPQN: 384
Time per tick: 0.0015625
103823


[Instrument(program=57, is_drum=False, name="Melodie 4"),
 Instrument(program=42, is_drum=False, name="Violoncl2"),
 Instrument(program=24, is_drum=False, name="GuitarAc3"),
 Instrument(program=24, is_drum=False, name="GuitarAc5")]

# Preprocessing All Data into sequences

In [1]:
# Get a list of all the filenames
import os
import random
import pretty_midi

def get_all_filenames(root_dir):
    midi_files = []
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:

            if filename.endswith('.mid') or filename.endswith('.midi'):
                midi_files.append(os.path.join(dirpath, filename))
    
    if not midi_files:
        raise FileNotFoundError(f"No MIDI files found in {root_dir}")
    return midi_files

# Usage
root_directory = './lmd_full'
files = get_all_filenames(root_directory)

In [2]:
import torch
import pretty_midi
def process_midi_data(midi_filename):
    try:
        midi_data = pretty_midi.PrettyMIDI(midi_filename)
    except Exception as e:
        return None, e
    song = []
    time_per_quarter_note = midi_data.tick_to_time(midi_data.resolution)
    for i, instrument in enumerate(midi_data.instruments):
        for note in instrument.notes:
            track = i
            start = note.start
            duration = note.end - note.start
            pitch = note.pitch

            notedata = [track, start, duration, pitch]
            song.append(notedata)
        # sort by start time
        song.sort(key=lambda x: x[1])
        # prepend the time_per_quarter_note
    song.insert(0, [time_per_quarter_note, 0, 0, 0])
    return torch.tensor(song, dtype=torch.float32), None

# Usage
filename = random.choice(files)
song, error = process_midi_data(filename)
print(f"Filename: {filename}")

# the reuslting format:
# header: [time_per_quarter_note, 0, 0, 0]
# each note: [track, start, duration, pitch]

print(f"first note:")
print(f"track: {song[1][0]}")
print(f"start: {song[1][1]}")
print(f"duration: {song[1][2]}")
print(f"pitch: {song[1][3]}")

print(f"\nsecond note:")
print(f"track: {song[2][0]}")
print(f"start: {song[2][1]}")
print(f"duration: {song[2][2]}")
print(f"pitch: {song[2][3]}")


Filename: ./lmd_full/6/659414fa1c3016d94d997128f528a73b.mid
first note:
track: 2.0
start: 3.2987499237060547
duration: 3.619999885559082
pitch: 60.0

second note:
track: 2.0
start: 3.2987499237060547
duration: 3.640000104904175
pitch: 72.0


In [3]:
# simple example of loading all the files (single process)
errors = []
songs = []
totalsongs = len(files)
for i, file in enumerate(files):
    if i % 1000 == 0:
        print(f"{i/totalsongs:.2f}% complete ({i})")
    try:
        song, error = process_midi_data(file)
        if error is not None:
            errors.append(error)
            continue
        songs.append(song)
    except Exception as e:
        print(f"unusual error: {e}")
    # if i > 100:
    #     break

0.00% complete (0)




0.01% complete (1000)
0.01% complete (2000)
0.02% complete (3000)
0.02% complete (4000)
0.03% complete (5000)
0.03% complete (6000)
0.04% complete (7000)
0.04% complete (8000)
0.05% complete (9000)
0.06% complete (10000)
0.06% complete (11000)
0.07% complete (12000)
0.07% complete (13000)
0.08% complete (14000)
0.08% complete (15000)
0.09% complete (16000)
0.10% complete (17000)
0.10% complete (18000)
0.11% complete (19000)
0.11% complete (20000)
0.12% complete (21000)
0.12% complete (22000)
0.13% complete (23000)
0.13% complete (24000)
unusual error: index -16307 is out of bounds for axis 0 with size 14508
0.14% complete (25000)
0.15% complete (26000)
0.15% complete (27000)
0.16% complete (28000)
0.16% complete (29000)
0.17% complete (30000)
0.17% complete (31000)
0.18% complete (32000)
0.18% complete (33000)
0.19% complete (34000)
0.20% complete (35000)
0.20% complete (36000)
0.21% complete (37000)
0.21% complete (38000)
0.22% complete (39000)
0.22% complete (40000)
0.23% complete (4

In [8]:
# Number of groups to split into
num_groups = 10

# Calculate the size of each group
group_size = len(songs) // num_groups

# Save each group directly
for i in range(num_groups):
    start_idx = i * group_size
    end_idx = (i + 1) * group_size if i < num_groups - 1 else len(songs)
    torch.save(songs[start_idx:end_idx], f'./dataset/song_group_{i}.pth')

# Handle any remaining tensors if the list size isn't perfectly divisible
if len(songs) % num_groups != 0:
    torch.save(songs[num_groups * group_size:], f'./dataset/song_group_{num_groups - 1}.pth')

# Training Hacking

In [None]:
# import torch

# loading all the data
# loaded_song_groups = []
# for i in range(num_groups):
#     loaded_song_groups.append(torch.load(f'./dataset/song_group_{i}.pth'))

# # Optionally, you can concatenate all groups back into a single list if needed
# loaded_songs = [song for group in loaded_song_groups for song in group]


In [1]:
import torch

# Load the first group of tensors
first_group = torch.load('./dataset/song_group_0.pth')

In [2]:
# the reuslting format:
# header: [time_per_quarter_note, 0, 0, 0]
# each note: [track, start, duration, pitch]

first_group[4]

tensor([[ 0.7059,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  1.4118,  0.1765, 43.0000],
        [ 0.0000,  1.4118,  1.0588, 55.0000],
        ...,
        [ 0.0000, 57.9706,  2.6471, 50.0000],
        [ 0.0000, 58.0147,  2.6471, 55.0000],
        [ 0.0000, 58.0588,  2.6471, 59.0000]])

this is how the dataset building would work on a mock dataset

In [22]:
test_context_length = 3

a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
b = [11, 12, 13, 14, 15, 16, 17, 18]

dataset = [a, b]

# desired datastructure:
# index 1: song [0], notes 0-2
# index 2: song [0], notes 2-4
# index 3: song [0], notes 4-6
# index 4: song [0], notes 6-8
# index 5: song [0], notes 8-9
# index 6: song [1], notes 0-2
# ...

# stored in the format: [song_index, note_start_index]

# Create the dataset
train_idxs = []
for i, song in enumerate(dataset):
    for j in range(0, len(song), test_context_length):
        train_idxs.append([i, j])



test_context_length = 3

a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
b = [11, 12, 13, 14, 15, 16, 17, 18]

dataset = [a, b]

# desired datastructure:
# index 1: song [0], notes 0-2
# index 2: song [0], notes 2-4
# index 3: song [0], notes 4-6
# index 4: song [0], notes 6-8
# index 5: song [0], notes 8-9
# index 6: song [1], notes 0-2
# ...

# stored in the format: [song_index, note_start_index]

# Create the dataset
data = []
for i, song in enumerate(dataset):
    for j in range(len(song) - test_context_length):
        if j % test_context_length == 0:
            data.append([i, j])


def build_idxs(dataset, context_length):
    idxs = []
    for i, song in enumerate(dataset):
        for j in range(len(song) - context_length):
            if j % context_length == 0:
                idxs.append([i, j])
    return idxs

# Usage
test_context_length = 3
dataset = [a, b]
train_idxs = build_idxs(dataset, test_context_length)

song, note_start = train_idxs[1]
dataset[song][note_start:note_start + test_context_length]

[4, 5, 6]

In [23]:
# dataset is the first_group, but remove the first element of each song
# first element is the header
# each song is a tensor
dataset = []
for song in first_group:
    dataset.append(song[1:])


context_length = 128

dataset_idxs = build_idxs(dataset, context_length)

song, note_start = dataset_idxs[1]

In [None]:
import torch
from torch.utils.data import Dataset

class SongDataSet(Dataset):
    def __init__(self, songs, context_length=128):
        self.songs = songs
        self.dataset_idxs = build_idxs(songs, context_length)

    def __len__(self):
        return len(self.dataset_idxs)

    def __getitem__(self, idx):
        song, note_start = self.dataset_idxs[idx]
        notes = self.songs[song][note_start:note_start + context_length]
        # pad notes if the length is less than context_length
        if len(notes) < context_length:
            padding = torch.zeros(context_length - len(notes), 4) # I DON"T HAVE THE BRAINPOWER TO FIGURE OUT RN
            notes = torch.cat([notes, padding])
        sample = {'data': self.data[idx], 'label': self.labels[idx]}
        return sample