In [1]:
!pip install pretty_midi

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m45.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25ldone
[?25h  Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592292 sha256=fdd1ef82b8347d261abd5cd858500e5dd9a98ec64317622fbab64a35cf7b1ea7
  Stored in directory: /root/.cache/pip/wheels/cd/a5/30/7b8b7f58709f5150f67f98fde4b891ebf0be9ef07a8af49f25
Successfully built pretty_midi
Installing collected packages: mido, pretty_m

# Imports

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pretty_midi
import os
import pathlib
import glob
import tensorflow as tf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Preprocessing and training

In [3]:
# Hyperparameters
SEQUENCE_LENGTH = 256  # Length of musical sequence
LATENT_DIM = 100      # Size of random noise vector
BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 0.0002
BETA1 = 0.5          # Adam optimizer parameter

# Dataset class for MIDI files
class MidiDataset(Dataset):
    def __init__(self, midi_folder, sequence_length=SEQUENCE_LENGTH):
        self.sequence_length = sequence_length
        self.data = []
        # Load and process MIDI files
        for filename in list(glob.glob(str(data_dir/'*/*.mid*'))):
            if filename.endswith('.midi'):
                try:
                    midi_data = pretty_midi.PrettyMIDI(filename)
                    
                    # Extract piano roll (matrix representation of notes)
                    piano_roll = midi_data.get_piano_roll(fs=16)  # 16 samples per beat
                    
                    # Normalize and prepare sequences
                    piano_roll = (piano_roll > 0).astype(np.float32)
                    
                    # Create sequences of specified length
                    for i in range(0, piano_roll.shape[1] - sequence_length, sequence_length):
                        sequence = piano_roll[:, i:i + sequence_length]
                        self.data.append(sequence)
                        
                except Exception as e:
                    print(f"Error loading {filename}: {e}")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return torch.FloatTensor(self.data[idx])

# Generator Network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.main = nn.Sequential(
            # Input is latent vector Z
            nn.Linear(LATENT_DIM, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            
            nn.Linear(1024, 128 * SEQUENCE_LENGTH),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.main(x).view(-1, 128, SEQUENCE_LENGTH)

# Discriminator Network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.main = nn.Sequential(
            nn.Linear(128 * SEQUENCE_LENGTH, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        x = x.view(-1, 128 * SEQUENCE_LENGTH)
        return self.main(x)

# Training function
def train_gan(generator, discriminator, dataloader, num_epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)
    
    criterion = nn.BCELoss()
    g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
    
    for epoch in range(num_epochs):
        for i, real_data in enumerate(tqdm(dataloader)):
            batch_size = real_data.size(0)
            real_data = real_data.to(device)
            
            # Train Discriminator
            d_optimizer.zero_grad()
            label_real = torch.ones(batch_size, 1).to(device)
            label_fake = torch.zeros(batch_size, 1).to(device)
            
            output_real = discriminator(real_data)
            d_loss_real = criterion(output_real, label_real)
            
            noise = torch.randn(batch_size, LATENT_DIM).to(device)
            fake_data = generator(noise)
            output_fake = discriminator(fake_data.detach())
            d_loss_fake = criterion(output_fake, label_fake)
            
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()
            
            # Train Generator
            g_optimizer.zero_grad()
            output_fake = discriminator(fake_data)
            g_loss = criterion(output_fake, label_real)
            
            g_loss.backward()
            g_optimizer.step()
            
        print(f"Epoch [{epoch}/{num_epochs}] d_loss: {d_loss.item():.4f} g_loss: {g_loss.item():.4f}")
        
        # Save sample generation every 10 epochs
        if (epoch + 1) % 10 == 0:
            generate_sample(generator, epoch + 1, device)

# Function to generate and save a sample MIDI file
def generate_sample(generator, epoch, device):
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(1, LATENT_DIM).to(device)
        fake_data = generator(noise).cpu().numpy()[0]
        
        # Convert to MIDI
        pm = pretty_midi.PrettyMIDI()
        piano_program = pretty_midi.Instrument(program=0)  # Piano
        
        # Convert piano roll back to notes
        threshold = 0.5
        piano_roll = (fake_data > threshold).astype(int)
        
        for pitch in range(128):
            for time in range(SEQUENCE_LENGTH):
                if piano_roll[pitch][time] == 1:
                    note = pretty_midi.Note(
                        velocity=100,
                        pitch=pitch,
                        start=time * 0.0625,  # Convert based on fs=16
                        end=(time + 1) * 0.0625
                    )
                    piano_program.notes.append(note)
        
        pm.instruments.append(piano_program)
        pm.write(f'generated_music_epoch_{epoch}.mid')
    generator.train()

# Collecting data and running

In [4]:
data_dir = pathlib.Path('data/maestro-v2.0.0')
if not data_dir.exists():
    tf.keras.utils.get_file(
        'maestro-v2.0.0-midi.zip',
        origin='https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip',
        extract=True, cache_dir='.', cache_subdir='data',
    )

# Create dataset and dataloader
dataset = MidiDataset(str(data_dir/'*/*.mid*'))
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize networks
generator = Generator()
discriminator = Discriminator()

# Train the GAN
train_gan(generator, discriminator, dataloader, NUM_EPOCHS)

Downloading data from https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip
[1m59243107/59243107[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step


100%|██████████| 1397/1397 [00:46<00:00, 29.90it/s]


Epoch [0/100] d_loss: 0.0798 g_loss: 7.7849


100%|██████████| 1397/1397 [00:46<00:00, 30.13it/s]


Epoch [1/100] d_loss: 0.0501 g_loss: 13.7797


100%|██████████| 1397/1397 [00:49<00:00, 28.14it/s]


Epoch [2/100] d_loss: 0.1283 g_loss: 8.1964


100%|██████████| 1397/1397 [00:48<00:00, 28.81it/s]


Epoch [3/100] d_loss: 0.0239 g_loss: 18.1097


100%|██████████| 1397/1397 [00:49<00:00, 28.36it/s]


Epoch [4/100] d_loss: 0.1710 g_loss: 13.5794


100%|██████████| 1397/1397 [00:49<00:00, 28.47it/s]


Epoch [5/100] d_loss: 0.0480 g_loss: 27.1573


100%|██████████| 1397/1397 [00:49<00:00, 28.42it/s]


Epoch [6/100] d_loss: 0.2325 g_loss: 22.9024


100%|██████████| 1397/1397 [00:48<00:00, 28.68it/s]


Epoch [7/100] d_loss: 0.0813 g_loss: 58.9825


100%|██████████| 1397/1397 [00:48<00:00, 28.55it/s]


Epoch [8/100] d_loss: 0.0049 g_loss: 74.3765


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [9/100] d_loss: 0.0289 g_loss: 58.3728


100%|██████████| 1397/1397 [00:48<00:00, 28.57it/s]


Epoch [10/100] d_loss: 0.0009 g_loss: 91.5247


100%|██████████| 1397/1397 [00:48<00:00, 28.65it/s]


Epoch [11/100] d_loss: 0.0388 g_loss: 65.5102


100%|██████████| 1397/1397 [00:48<00:00, 28.60it/s]


Epoch [12/100] d_loss: 0.0202 g_loss: 60.6555


100%|██████████| 1397/1397 [00:48<00:00, 28.57it/s]


Epoch [13/100] d_loss: 0.0271 g_loss: 44.3926


100%|██████████| 1397/1397 [00:48<00:00, 28.69it/s]


Epoch [14/100] d_loss: 0.0083 g_loss: 65.3224


100%|██████████| 1397/1397 [00:48<00:00, 28.70it/s]


Epoch [15/100] d_loss: 0.0178 g_loss: 94.9837


100%|██████████| 1397/1397 [00:48<00:00, 28.62it/s]


Epoch [16/100] d_loss: 0.0106 g_loss: 50.8859


100%|██████████| 1397/1397 [00:48<00:00, 28.58it/s]


Epoch [17/100] d_loss: 0.0664 g_loss: 84.4076


100%|██████████| 1397/1397 [00:48<00:00, 28.62it/s]


Epoch [18/100] d_loss: 0.1322 g_loss: 64.3686


100%|██████████| 1397/1397 [00:49<00:00, 28.51it/s]


Epoch [19/100] d_loss: 0.0921 g_loss: 30.1752


100%|██████████| 1397/1397 [00:49<00:00, 28.51it/s]


Epoch [20/100] d_loss: 0.0256 g_loss: 27.0523


100%|██████████| 1397/1397 [00:48<00:00, 28.55it/s]


Epoch [21/100] d_loss: 0.0025 g_loss: 74.6936


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [22/100] d_loss: 0.1843 g_loss: 35.0355


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [23/100] d_loss: 0.1987 g_loss: 67.8568


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [24/100] d_loss: 0.0039 g_loss: 77.0464


100%|██████████| 1397/1397 [00:48<00:00, 28.56it/s]


Epoch [25/100] d_loss: 0.2952 g_loss: 46.1849


100%|██████████| 1397/1397 [00:48<00:00, 28.58it/s]


Epoch [26/100] d_loss: 0.1995 g_loss: 28.3697


100%|██████████| 1397/1397 [00:48<00:00, 28.59it/s]


Epoch [27/100] d_loss: 0.0555 g_loss: 57.1616


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [28/100] d_loss: 0.0643 g_loss: 59.2465


100%|██████████| 1397/1397 [00:48<00:00, 28.62it/s]


Epoch [29/100] d_loss: 0.3034 g_loss: 15.2474


100%|██████████| 1397/1397 [00:48<00:00, 28.60it/s]


Epoch [30/100] d_loss: 0.0852 g_loss: 17.1358


100%|██████████| 1397/1397 [00:48<00:00, 28.57it/s]


Epoch [31/100] d_loss: 0.0309 g_loss: 33.5146


100%|██████████| 1397/1397 [00:48<00:00, 28.58it/s]


Epoch [32/100] d_loss: 0.0751 g_loss: 26.2118


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [33/100] d_loss: 0.0234 g_loss: 28.0725


100%|██████████| 1397/1397 [00:48<00:00, 28.53it/s]


Epoch [34/100] d_loss: 0.2024 g_loss: 20.2686


100%|██████████| 1397/1397 [00:48<00:00, 28.61it/s]


Epoch [35/100] d_loss: 0.1541 g_loss: 65.6872


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [36/100] d_loss: 0.0156 g_loss: 28.6102


100%|██████████| 1397/1397 [00:48<00:00, 28.53it/s]


Epoch [37/100] d_loss: 0.0731 g_loss: 14.2314


100%|██████████| 1397/1397 [00:48<00:00, 28.64it/s]


Epoch [38/100] d_loss: 0.1619 g_loss: 41.9086


100%|██████████| 1397/1397 [00:48<00:00, 28.55it/s]


Epoch [39/100] d_loss: 0.0665 g_loss: 17.2894


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [40/100] d_loss: 0.0777 g_loss: 89.9258


100%|██████████| 1397/1397 [00:48<00:00, 28.57it/s]


Epoch [41/100] d_loss: 0.0499 g_loss: 17.0056


100%|██████████| 1397/1397 [00:48<00:00, 28.62it/s]


Epoch [42/100] d_loss: 0.0012 g_loss: 85.1442


100%|██████████| 1397/1397 [00:48<00:00, 28.71it/s]


Epoch [43/100] d_loss: 0.0470 g_loss: 39.2909


100%|██████████| 1397/1397 [00:48<00:00, 28.59it/s]


Epoch [44/100] d_loss: 0.2457 g_loss: 32.1452


100%|██████████| 1397/1397 [00:48<00:00, 28.57it/s]


Epoch [45/100] d_loss: 0.0296 g_loss: 21.5396


100%|██████████| 1397/1397 [00:48<00:00, 28.58it/s]


Epoch [46/100] d_loss: 0.0425 g_loss: 49.9404


100%|██████████| 1397/1397 [00:48<00:00, 28.55it/s]


Epoch [47/100] d_loss: 0.0186 g_loss: 14.6932


100%|██████████| 1397/1397 [00:48<00:00, 28.56it/s]


Epoch [48/100] d_loss: 0.1494 g_loss: 21.7847


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [49/100] d_loss: 0.1509 g_loss: 15.2410


100%|██████████| 1397/1397 [00:48<00:00, 28.53it/s]


Epoch [50/100] d_loss: 0.0102 g_loss: 24.3565


100%|██████████| 1397/1397 [00:48<00:00, 28.52it/s]


Epoch [51/100] d_loss: 0.2039 g_loss: 15.3169


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [52/100] d_loss: 0.0095 g_loss: 22.6739


100%|██████████| 1397/1397 [00:48<00:00, 28.52it/s]


Epoch [53/100] d_loss: 0.3315 g_loss: 25.3971


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [54/100] d_loss: 0.2433 g_loss: 15.7812


100%|██████████| 1397/1397 [00:49<00:00, 28.51it/s]


Epoch [55/100] d_loss: 0.1734 g_loss: 20.4945


100%|██████████| 1397/1397 [00:48<00:00, 28.52it/s]


Epoch [56/100] d_loss: 0.0469 g_loss: 20.3039


100%|██████████| 1397/1397 [00:48<00:00, 28.52it/s]


Epoch [57/100] d_loss: 0.0697 g_loss: 42.4557


100%|██████████| 1397/1397 [00:49<00:00, 28.48it/s]


Epoch [58/100] d_loss: 0.0489 g_loss: 30.4797


100%|██████████| 1397/1397 [00:49<00:00, 28.47it/s]


Epoch [59/100] d_loss: 0.0824 g_loss: 13.8515


100%|██████████| 1397/1397 [00:49<00:00, 28.48it/s]


Epoch [60/100] d_loss: 0.1327 g_loss: 11.6446


100%|██████████| 1397/1397 [00:49<00:00, 28.46it/s]


Epoch [61/100] d_loss: 0.2557 g_loss: 13.7346


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [62/100] d_loss: 0.1450 g_loss: 13.4840


100%|██████████| 1397/1397 [00:49<00:00, 28.47it/s]


Epoch [63/100] d_loss: 0.0734 g_loss: 24.5390


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [64/100] d_loss: 0.0648 g_loss: 24.2301


100%|██████████| 1397/1397 [00:49<00:00, 28.46it/s]


Epoch [65/100] d_loss: 0.1915 g_loss: 20.6949


100%|██████████| 1397/1397 [00:49<00:00, 28.47it/s]


Epoch [66/100] d_loss: 0.0444 g_loss: 15.8573


100%|██████████| 1397/1397 [00:49<00:00, 28.45it/s]


Epoch [67/100] d_loss: 0.6578 g_loss: 15.1267


100%|██████████| 1397/1397 [00:49<00:00, 28.47it/s]


Epoch [68/100] d_loss: 0.1164 g_loss: 11.3196


100%|██████████| 1397/1397 [00:49<00:00, 28.46it/s]


Epoch [69/100] d_loss: 0.0545 g_loss: 12.4061


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [70/100] d_loss: 0.0259 g_loss: 16.6936


100%|██████████| 1397/1397 [00:48<00:00, 28.52it/s]


Epoch [71/100] d_loss: 0.0704 g_loss: 10.8239


100%|██████████| 1397/1397 [00:48<00:00, 28.51it/s]


Epoch [72/100] d_loss: 0.0480 g_loss: 8.7429


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [73/100] d_loss: 0.2262 g_loss: 7.3823


100%|██████████| 1397/1397 [00:49<00:00, 28.49it/s]


Epoch [74/100] d_loss: 0.0242 g_loss: 8.3162


100%|██████████| 1397/1397 [00:49<00:00, 28.49it/s]


Epoch [75/100] d_loss: 0.2755 g_loss: 10.0320


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [76/100] d_loss: 0.0495 g_loss: 10.9161


100%|██████████| 1397/1397 [00:49<00:00, 28.48it/s]


Epoch [77/100] d_loss: 0.0268 g_loss: 11.5400


100%|██████████| 1397/1397 [00:49<00:00, 28.49it/s]


Epoch [78/100] d_loss: 0.2058 g_loss: 10.3213


100%|██████████| 1397/1397 [00:49<00:00, 28.48it/s]


Epoch [79/100] d_loss: 0.1067 g_loss: 15.1139


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [80/100] d_loss: 0.0501 g_loss: 7.8054


100%|██████████| 1397/1397 [00:49<00:00, 28.49it/s]


Epoch [81/100] d_loss: 0.0661 g_loss: 9.5666


100%|██████████| 1397/1397 [00:49<00:00, 28.49it/s]


Epoch [82/100] d_loss: 0.0939 g_loss: 5.8791


100%|██████████| 1397/1397 [00:49<00:00, 28.51it/s]


Epoch [83/100] d_loss: 0.1225 g_loss: 9.2781


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [84/100] d_loss: 0.1134 g_loss: 10.8298


100%|██████████| 1397/1397 [00:48<00:00, 28.52it/s]


Epoch [85/100] d_loss: 0.1633 g_loss: 8.3685


100%|██████████| 1397/1397 [00:48<00:00, 28.52it/s]


Epoch [86/100] d_loss: 0.0204 g_loss: 11.7708


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [87/100] d_loss: 0.1729 g_loss: 8.5973


100%|██████████| 1397/1397 [00:49<00:00, 28.51it/s]


Epoch [88/100] d_loss: 0.0613 g_loss: 12.0412


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [89/100] d_loss: 0.1376 g_loss: 7.6325


100%|██████████| 1397/1397 [00:49<00:00, 28.50it/s]


Epoch [90/100] d_loss: 0.1164 g_loss: 8.6157


100%|██████████| 1397/1397 [00:48<00:00, 28.55it/s]


Epoch [91/100] d_loss: 0.2648 g_loss: 8.4078


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [92/100] d_loss: 0.0617 g_loss: 11.5771


100%|██████████| 1397/1397 [00:48<00:00, 28.52it/s]


Epoch [93/100] d_loss: 0.0379 g_loss: 7.5983


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]


Epoch [94/100] d_loss: 0.2214 g_loss: 10.2002


100%|██████████| 1397/1397 [00:48<00:00, 28.53it/s]


Epoch [95/100] d_loss: 0.0097 g_loss: 11.2177


100%|██████████| 1397/1397 [00:48<00:00, 28.53it/s]


Epoch [96/100] d_loss: 0.1494 g_loss: 10.5225


100%|██████████| 1397/1397 [00:48<00:00, 28.52it/s]


Epoch [97/100] d_loss: 0.3477 g_loss: 9.2556


100%|██████████| 1397/1397 [00:48<00:00, 28.55it/s]


Epoch [98/100] d_loss: 0.0851 g_loss: 8.0889


100%|██████████| 1397/1397 [00:48<00:00, 28.54it/s]

Epoch [99/100] d_loss: 0.0109 g_loss: 10.1826





# Save model

In [5]:
torch.save(generator.state_dict(), 'generator_model.pth')

# Generate new samples

In [6]:
def generate_new_music(generator=None, num_samples=5, output_dir="generated_music"):
    """
    Generate new music samples using a trained generator
    """
    # Load the trained generator if not provided
    if generator is None:
        generator = Generator()
        generator.load_state_dict(torch.load('generator_model.pth'))
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    generator.eval()
    
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    with torch.no_grad():
        for i in range(num_samples):
            # Generate random noise
            noise = torch.randn(1, 100).to(device)  # 100 is the LATENT_DIM
            
            # Generate music data
            fake_data = generator(noise).cpu().numpy()[0]
            
            # Convert to MIDI
            pm = pretty_midi.PrettyMIDI()
            piano_program = pretty_midi.Instrument(program=0)  # Piano
            
            # Convert piano roll to notes
            threshold = 0.5
            piano_roll = (fake_data > threshold).astype(int)
            
            # Create notes from piano roll
            for pitch in range(128):
                current_note_start = None
                for time in range(piano_roll.shape[1]):
                    if piano_roll[pitch][time] == 1 and current_note_start is None:
                        current_note_start = time
                    elif (piano_roll[pitch][time] == 0 or time == piano_roll.shape[1] - 1) and current_note_start is not None:
                        note = pretty_midi.Note(
                            velocity=100,
                            pitch=pitch,
                            start=current_note_start * 0.0625,
                            end=time * 0.0625
                        )
                        piano_program.notes.append(note)
                        current_note_start = None
            
            pm.instruments.append(piano_program)
            pm.write(f"{output_dir}/generated_music_{i+1}.mid")
    
    print(f"Generated {num_samples} music samples in {output_dir}/")

In [7]:
generate_new_music(num_samples=2)

  generator.load_state_dict(torch.load('generator_model.pth'))


Generated 2 music samples in generated_music/
