In [7]:
import os
import glob
import numpy as np
import music21
from music21 import converter, instrument, note, chord
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [8]:
class Generator(nn.Module):
    """LSTM followed by a linear layer applied to the last time step.

    Args:
        input_dim: Number of features per time step fed to the LSTM.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of values produced by the final linear layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the sequence through the LSTM and project its last output.

        Args:
            x: Tensor laid out as (batch, seq_len, input_dim); the LSTM is
                built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, output_dim). No activation is applied.
        """
        # Build the initial states from `x` so they share its device and dtype.
        # A bare torch.zeros(...) is always CPU/float32 and breaks the forward
        # pass as soon as the model runs on a GPU or in another precision.
        state_shape = (self.lstm.num_layers, x.size(0), self.lstm.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        # Only the output of the final time step is projected.
        return self.fc(out[:, -1, :])

class Discriminator(nn.Module):
    """LSTM followed by a linear layer applied to the last time step.

    Args:
        input_dim: Number of features per time step fed to the LSTM.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of values produced by the final linear layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Score a batch of sequences.

        Args:
            x: Tensor laid out as (batch, seq_len, input_dim); the LSTM is
                built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, output_dim) holding the raw linear output;
            this class applies no sigmoid or other activation.
        """
        # Initial states are derived from `x` so that device and dtype always
        # match the input; torch.zeros(...) would pin them to CPU/float32.
        state_shape = (self.lstm.num_layers, x.size(0), self.lstm.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        # Only the output of the final time step is projected.
        return self.fc(out[:, -1, :])

def load_midi_files(folder_path):
    """Collect pitch, duration and offset of every single note in a folder.

    Every ``*.mid`` file directly inside ``folder_path`` is parsed with
    ``converter.parse`` and walked recursively. Only ``note.Note`` elements are
    kept; chords and rests are skipped.

    Args:
        folder_path: Directory that contains the MIDI files.

    Returns:
        Tuple ``(notes, durations, offsets)`` of three equally long lists:
        pitch names as strings, quarter-length durations, and each note's
        ``offset`` attribute.
    """
    notes, durations, offsets = [], [], []
    # Sorted so the result does not depend on the directory listing order.
    for midi_file in sorted(glob.glob(os.path.join(folder_path, "*.mid"))):
        # converter.parse returns a music21 stream that supports recurse();
        # the low-level music21.midi.MidiFile reader does not.
        score = converter.parse(midi_file)
        for element in score.recurse():
            if isinstance(element, note.Note):
                notes.append(str(element.pitch))
                durations.append(element.duration.quarterLength)
                # Record the note's own offset. Summing the offsets list, as
                # before, only ever produced zeros.
                # NOTE(review): music21 reports offsets relative to the
                # containing stream - confirm that is what callers expect.
                offsets.append(element.offset)
    return notes, durations, offsets

In [6]:
class MidiDataLoader:
    """Finds MIDI files in one folder and turns them into flat event lists."""

    def __init__(self, folder_path):
        # Directory searched by load_midi_files().
        self.folder_path = folder_path

    def load_midi_files(self):
        """Return the paths matching ``*.mid`` directly inside ``folder_path``."""
        pattern = os.path.join(self.folder_path, "*.mid")
        return glob.glob(pattern)

    def extract_notes_duration_offset(self, midi_file):
        """Parse one MIDI file into parallel note/duration/offset lists.

        Single notes are stored as their pitch string; chords are stored as
        their normal-order values joined with ``.``. Any other element type is
        ignored.

        Args:
            midi_file: Path handed to ``converter.parse``.

        Returns:
            Tuple ``(notes, durations, offsets)`` of equally long lists.
        """
        score = converter.parse(midi_file)
        by_instrument = instrument.partitionByInstrument(score)

        # A truthy partition means instrument parts exist: walk the first one.
        # Otherwise fall back to the notes of the flattened score.
        if by_instrument:
            elements = by_instrument.parts[0].recurse()
        else:
            elements = score.flat.notes

        notes, durations, offsets = [], [], []
        for element in elements:
            if isinstance(element, note.Note):
                label = str(element.pitch)
            elif isinstance(element, chord.Chord):
                label = '.'.join(str(n) for n in element.normalOrder)
            else:
                continue
            notes.append(label)
            durations.append(element.duration.quarterLength)
            offsets.append(element.offset)

        return notes, durations, offsets

    def create_dataset(self, midi_files):
        """Concatenate the extracted lists of several MIDI files.

        Args:
            midi_files: Iterable of MIDI file paths.

        Returns:
            Tuple ``(notes, durations, offsets)`` covering all files in the
            order they were given.
        """
        notes_list, durations_list, offsets_list = [], [], []

        for path in midi_files:
            notes, durations, offsets = self.extract_notes_duration_offset(path)
            notes_list += notes
            durations_list += durations
            offsets_list += offsets

        return notes_list, durations_list, offsets_list

In [5]:
# Network sizes. The names match the constructor parameters of Generator and
# Discriminator; where they are consumed is not shown in the visible cells.
input_dim, hidden_dim, output_dim = 100, 128, 1

# Training settings.
lr = 1e-3
batch_size = 64
num_epochs = 100



KeyError: 4982311808

In [None]:
def _min_max_scale(values):
    """Scale a tensor to [0, 1]; empty stays empty, constant input becomes zeros."""
    if values.numel() == 0:
        return values
    low = values.min()
    span = values.max() - low
    if span == 0:
        # All entries are equal: dividing by the zero range would yield NaN.
        return torch.zeros_like(values)
    return (values - low) / span


def preprocess_data(notes, durations, offsets):
    """Convert raw note data into tensors and normalise the timing values.

    Args:
        notes: Iterable of note names; each one is passed to ``note_to_int``.
        durations: Iterable of numeric durations (anything ``float()`` accepts,
            including ``fractions.Fraction``).
        offsets: Iterable of numeric offsets (same requirement as durations).

    Returns:
        Tuple ``(notes, durations, offsets)``: a ``torch.long`` tensor of note
        indices and two ``torch.float`` tensors min-max scaled to [0, 1]. A
        constant duration/offset sequence is mapped to zeros and an empty one
        is returned empty instead of raising.
    """
    notes = torch.tensor([note_to_int(n) for n in notes], dtype=torch.long)
    # float() makes non-float numerics such as Fraction safe to hand to torch.
    durations = torch.tensor([float(d) for d in durations], dtype=torch.float)
    offsets = torch.tensor([float(o) for o in offsets], dtype=torch.float)

    # Normalize durations and offsets
    durations = _min_max_scale(durations)
    offsets = _min_max_scale(offsets)

    return notes, durations, offsets

def note_to_int(note):
    """Map a note name to its pitch class (C = 0 ... B = 11).

    Only the text before the first ``:`` is used. A trailing octave number is
    ignored, and accidentals may be written as ``#`` (sharp) or ``-`` / ``b``
    (flat), so ``'C'``, ``'C#'``, ``'C4'``, ``'F#5'``, ``'E-3'`` and ``'Bb'``
    are all accepted.

    Args:
        note: Note name string.

    Returns:
        Integer pitch class in the range 0-11.

    Raises:
        KeyError: If the text does not start with a letter A-G or contains
            unknown accidental characters. This includes dotted chord strings
            such as ``'0.4.7'``, which are not single notes.
    """
    natural_pitch_class = {'C': 0, 'D': 2, 'E': 4, 'F': 5, 'G': 7, 'A': 9, 'B': 11}
    name = note.split(':')[0]
    # Drop the octave so names like 'C4' resolve the same way as 'C'.
    stem = name.rstrip('0123456789')
    if not stem or stem[0] not in natural_pitch_class:
        raise KeyError(name)
    accidentals = stem[1:]
    if accidentals.strip('#-b'):
        raise KeyError(name)
    shift = accidentals.count('#') - accidentals.count('-') - accidentals.count('b')
    # Modulo wraps names such as 'B#' or 'C-' back into 0-11.
    return (natural_pitch_class[stem[0]] + shift) % 12

In [None]:
# Hold out 20% of the samples for validation; the fixed random_state makes the
# split repeatable. `notes`, `durations` and `offsets` must already exist in
# the kernel - no visible cell assigns them at top level.
(
    notes_train, notes_val,
    durations_train, durations_val,
    offsets_train, offsets_val,
) = train_test_split(
    notes,
    durations,
    offsets,
    test_size=0.2,
    random_state=42,
)

In [None]:
class MidiDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(note, duration, offset)`` triples by index.

    The base class is referenced through ``torch.utils.data`` because no
    ``data`` alias is imported in the visible cells; the bare name raised
    ``NameError`` on a fresh kernel.

    Args:
        notes: Indexable sequence of notes.
        durations: Indexable sequence of durations, same length as ``notes``.
        offsets: Indexable sequence of offsets, same length as ``notes``.

    Raises:
        ValueError: If the three sequences differ in length.
    """

    def __init__(self, notes, durations, offsets):
        # __len__ reports len(notes), so a shorter companion sequence would
        # otherwise only surface later as an IndexError during iteration.
        if not (len(notes) == len(durations) == len(offsets)):
            raise ValueError(
                "notes, durations and offsets must have the same length, got "
                f"{len(notes)}, {len(durations)} and {len(offsets)}"
            )
        self.notes = notes
        self.durations = durations
        self.offsets = offsets

    def __len__(self):
        """Return the number of samples."""
        return len(self.notes)

    def __getitem__(self, idx):
        """Return the ``(note, duration, offset)`` triple stored at ``idx``."""
        return self.notes[idx], self.durations[idx], self.offsets[idx]

# Wrap the training and validation splits so they can be batched.
train_dataset = MidiDataset(notes_train, durations_train, offsets_train)
val_dataset = MidiDataset(notes_val, durations_val, offsets_val)

# DataLoader is referenced through torch.utils.data because no `data` alias is
# imported in the visible cells; the bare name raised NameError on a fresh
# kernel. Only the training loader shuffles.
# NOTE(review): the batch size here is a literal 32 while the configuration
# cell sets batch_size = 64 - confirm which value is intended.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)