# Import packages

# In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from music21 import converter, instrument, note, chord
import glob
# import pickle
import numpy as np

# In [12]:
class Generator(nn.Module):
    """LSTM generator that maps an input sequence to one output vector.

    The sequence is passed through a (possibly stacked) LSTM; the LSTM output
    at the final time step is then projected by a linear layer.

    Args:
        input_size: Number of features per time step of the input.
        hidden_size: Size of the LSTM hidden state.
        output_size: Number of features in the returned vector.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the generator on a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # Create the zero initial states directly on x's device and in x's
        # dtype. Allocating default float32 zeros and then calling .to(device)
        # costs an extra copy and breaks when the model runs in another
        # precision (e.g. after .double() or .half()).
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        out, _ = self.lstm(x, (h0, c0))
        # Only the last time step feeds the projection layer.
        return self.fc(out[:, -1, :])

class Discriminator(nn.Module):
    """LSTM discriminator that scores an input sequence with sigmoid outputs.

    The sequence is passed through a (possibly stacked) LSTM; the LSTM output
    at the final time step is projected by a linear layer and squashed with a
    sigmoid, so every returned value lies in ``(0, 1)``.

    Args:
        input_size: Number of features per time step of the input.
        hidden_size: Size of the LSTM hidden state.
        output_size: Number of sigmoid outputs per sample.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)`` with values in ``(0, 1)``.
        """
        # Create the zero initial states directly on x's device and in x's
        # dtype. Allocating default float32 zeros and then calling .to(device)
        # costs an extra copy and breaks when the model runs in another
        # precision (e.g. after .double() or .half()).
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        out, _ = self.lstm(x, (h0, c0))
        # Only the last time step is classified.
        return self.sigmoid(self.fc(out[:, -1, :]))

class MIDIDataset(Dataset):
    """Dataset of note and chord events read from the ``*.mid`` files in a folder.

    All files are parsed eagerly when the dataset is constructed. Item ``i`` is
    the triple ``(notes[i], durations[i], offsets[i])``.

    Args:
        folder_path: Directory that is searched (non-recursively) for
            ``*.mid`` files.
    """

    def __init__(self, folder_path):
        self.folder_path = folder_path
        self.notes, self.durations, self.offsets = self.get_notes()

    def __len__(self):
        """Return the number of extracted note/chord events."""
        return len(self.notes)

    def __getitem__(self, idx):
        """Return the ``(note, duration, offset)`` triple at position ``idx``."""
        return self.notes[idx], self.durations[idx], self.offsets[idx]

    def get_notes(self):
        """Extract every note and chord from the MIDI files in ``folder_path``.

        A single note is labelled with ``str(element.pitch)``; a chord is
        labelled with its ``normalOrder`` values joined by ``"."``. Elements
        that are neither a note nor a chord are skipped. A file that fails to
        parse is reported on stdout and does not abort the scan.

        Returns:
            Three parallel lists ``(notes, durations, offsets)``: the string
            labels, the durations in quarter lengths, and the offsets.
        """
        notes = []
        durations = []
        offsets = []

        for file in glob.glob(self.folder_path + "/*.mid"):
            try:
                midi = converter.parse(file)

                try:  # file has instrument parts
                    s2 = instrument.partitionByInstrument(midi)
                    notes_to_parse = s2.parts[0].recurse()
                except Exception:  # file has notes in a flat structure
                    # `except Exception` rather than a bare `except`, so that
                    # Ctrl-C / SystemExit are not silently swallowed here.
                    notes_to_parse = midi.flat.notes

                for element in notes_to_parse:
                    if isinstance(element, note.Note):
                        label = str(element.pitch)
                    elif isinstance(element, chord.Chord):
                        label = '.'.join(str(n) for n in element.normalOrder)
                    else:
                        continue

                    notes.append(label)
                    # music21 may report quarter lengths and offsets as
                    # fractions.Fraction (e.g. for tuplets); store plain
                    # floats so a DataLoader can batch them.
                    durations.append(float(element.duration.quarterLength))
                    offsets.append(float(element.offset))
            except Exception as e:
                # Include the exception type: str(e) alone can be an opaque
                # value (such as a bare number) that says nothing about the
                # failure.
                print(f"Error parsing MIDI file {file}: {type(e).__name__}: {e}")

        return notes, durations, offsets


# In [13]:
# Train a small LSTM GAN on the note/chord events found in the "Music" folder
# and save both networks' weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
input_size = 128  # Each note/chord event is encoded as a 128-dimensional vector
hidden_size = 64
output_size = 1  # The discriminator emits one real/fake probability per sample
num_layers = 1
num_epochs = 100
batch_size = 64
learning_rate = 0.001


def encode_event(token):
    """Encode one dataset label as a multi-hot vector of length ``input_size``.

    NOTE(review): this encoding is an assumption made so the loop can run; the
    original code passed the raw strings to ``torch.tensor``, which raises.
    A pitch name such as ``"C4"`` sets the index of its MIDI number. A chord
    label such as ``"0.4.7"`` sets the indices of its dot-separated integers
    (the dataset stores chords that way, so no octave information exists).
    """
    vec = torch.zeros(input_size)
    if not token:
        return vec
    if token.replace(".", "").isdigit():
        for pitch_class in token.split("."):
            vec[int(pitch_class)] = 1.0
    else:
        vec[note.Note(token).pitch.midi] = 1.0
    return vec


def collate_events(batch):
    """Batch ``(note, duration, offset)`` triples, keeping the labels as strings.

    Durations and offsets are cast to ``float`` first because the default
    collate function cannot batch ``fractions.Fraction`` values.
    """
    labels, durations, offsets = zip(*batch)
    return (
        list(labels),
        torch.tensor([float(d) for d in durations]),
        torch.tensor([float(o) for o in offsets]),
    )


# Create DataLoader for the dataset
dataset = MIDIDataset("Music")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        collate_fn=collate_events)

# Encode every distinct label once instead of re-encoding it in every batch.
note_vectors = {token: encode_event(token) for token in set(dataset.notes)}

# Initialize models, optimizer, and loss function.
# The generator must emit `input_size` features so that its samples have the
# same shape as the real events the discriminator consumes.
generator = Generator(input_size, hidden_size, input_size, num_layers).to(device)
discriminator = Discriminator(input_size, hidden_size, output_size, num_layers).to(device)
criterion = nn.BCELoss()
gen_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    print(epoch)
    # Durations and offsets are loaded but not used by the models yet.
    for i, (notes, durations, offsets) in enumerate(dataloader):
        # Each event becomes a length-1 sequence: (batch, 1, input_size).
        real_notes = torch.stack([note_vectors[t] for t in notes]).unsqueeze(1).to(device)

        # Size the labels from the actual batch; the last batch of an epoch
        # may hold fewer than `batch_size` samples.
        current_batch = real_notes.size(0)
        real_labels = torch.ones(current_batch, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, device=device)

        # Train Discriminator
        disc_optimizer.zero_grad()
        real_outputs = discriminator(real_notes)
        real_loss = criterion(real_outputs, real_labels)

        fake_inputs = torch.randn(current_batch, 1, input_size, device=device)  # Noise
        # Sigmoid keeps generated values in the same [0, 1] range as the
        # multi-hot real events; unsqueeze restores the sequence dimension.
        fake_notes = torch.sigmoid(generator(fake_inputs)).unsqueeze(1)
        fake_outputs = discriminator(fake_notes.detach())  # Detach generator gradients
        fake_loss = criterion(fake_outputs, fake_labels)

        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        disc_optimizer.step()

        # Train Generator: it is rewarded when the discriminator labels its
        # samples as real.
        gen_optimizer.zero_grad()
        fake_outputs = discriminator(fake_notes)
        gen_loss = criterion(fake_outputs, real_labels)
        gen_loss.backward()
        gen_optimizer.step()

        if (i+1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "
                  f"Gen Loss: {gen_loss.item():.4f}, Disc Loss: {disc_loss.item():.4f}")

# Save models
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

# Output:
# Error parsing MIDI file Music/mendel_op19_2.mid: 4960591584
# Error parsing MIDI file Music/chpn_op33_2.mid: 4940227920
# Error parsing MIDI file Music/chpn_op35_2.mid: 5813870944
# Error parsing MIDI file Music/brahms_opus1_3.mid: 5798130688
# Error parsing MIDI file Music/schub_d760_3.mid: 5897637696


