In [1]:
import os
import torch
import torch.nn as nn
from music21 import converter, instrument, note

In [None]:
# Define the GAN architecture
class Generator(nn.Module):
    """LSTM generator: maps an input sequence to a single output vector.

    A stacked LSTM reads the sequence and a linear layer projects the output
    of the final time step to ``output_dim`` features.

    Args:
        input_dim: Number of features per time step of the input.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of output features.
        num_layers: Number of stacked LSTM layers (default 2, the value that
            was previously hard-coded).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2):
        super(Generator, self).__init__()
        self.hidden_dim = hidden_dim
        # Stored so forward() builds initial states matching the LSTM depth.
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the generator on a batch of sequences.

        Args:
            x: Tensor of shape (batch, seq_len, input_dim); the LSTM is built
                with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, output_dim), computed from the last time step.
        """
        # x.new_zeros inherits x's device and dtype; bare torch.zeros always
        # allocates float32 on the CPU, which breaks once the model and its
        # inputs are moved to CUDA (or converted to another dtype).
        state_shape = (self.num_layers, x.size(0), self.hidden_dim)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        # Keep only the final time step's output before projecting.
        out = self.fc(out[:, -1, :])
        return out

class Discriminator(nn.Module):
    """LSTM discriminator: scores an input sequence.

    A stacked LSTM reads the sequence and a linear layer projects the output
    of the final time step to ``output_dim`` raw scores. No sigmoid is applied
    here (NOTE(review): presumably intended for a logits-based loss such as
    BCEWithLogitsLoss -- confirm against the training loop).

    Args:
        input_dim: Number of features per time step of the input.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of output scores.
        num_layers: Number of stacked LSTM layers (default 2, the value that
            was previously hard-coded).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2):
        super(Discriminator, self).__init__()
        self.hidden_dim = hidden_dim
        # Stored so forward() builds initial states matching the LSTM depth.
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Score a batch of sequences.

        Args:
            x: Tensor of shape (batch, seq_len, input_dim); the LSTM is built
                with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, output_dim) of un-squashed scores taken
            from the last time step.
        """
        # x.new_zeros inherits x's device and dtype; bare torch.zeros always
        # allocates float32 on the CPU, which breaks once the model and its
        # inputs are moved to CUDA (or converted to another dtype).
        state_shape = (self.num_layers, x.size(0), self.hidden_dim)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        # Keep only the final time step's output before projecting.
        out = self.fc(out[:, -1, :])
        return out

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

midi_folder = './Music'
# Sorted so the dataset order is reproducible (os.listdir order is arbitrary);
# the extension check is case-insensitive and accepts both .mid and .midi.
midi_files = sorted(
    f for f in os.listdir(midi_folder)
    if f.lower().endswith(('.mid', '.midi'))
)


def _extract_pitches(path):
    """Return the MIDI pitch number of every ``note.Note`` in the score at ``path``.

    Only single notes are collected; other elements (e.g. chords, rests) are
    skipped.
    """
    score = converter.parse(path)
    return [element.pitch.midi
            for element in score.recurse()
            if isinstance(element, note.Note)]


# Prepare the dataset: one list of pitch numbers per MIDI file.
# os.listdir returns bare file names, so they must be joined with midi_folder;
# parsing the bare name only worked when the CWD happened to be that folder.
midi_data = [_extract_pitches(os.path.join(midi_folder, midi_file))
             for midi_file in midi_files]

# Train the GAN
# NOTE(review): input_dim=5 does not obviously match midi_data (one pitch int
# per step) and hidden_dim=1 is a very small LSTM -- confirm these sizes.
generator = Generator(5, 1, 1)
discriminator = Discriminator(5, 1, 1)
generator.to(device)
# Both networks must live on the same device as the training tensors.
discriminator.to(device)