In [6]:
# Install required packages
%pip install pretty_midi tensorflow numpy matplotlib mido

import pretty_midi
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from mido import MidiFile

# Load MIDI and convert to note array
def midi_to_notes(midi_file):
    """Load a MIDI file into a chronologically ordered note array.

    Args:
        midi_file: Value handed straight to ``pretty_midi.PrettyMIDI``
            (typically a path to a ``.mid`` file).

    Returns:
        A float32 array of shape ``(n_notes, 4)`` whose columns are
        ``[pitch, start, end, velocity]``, sorted by start time (ties broken
        by pitch). A file without notes yields an array of shape ``(0, 4)``.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_file)

    notes = [
        [note.pitch, note.start, note.end, note.velocity]
        for instrument in midi_data.instruments
        for note in instrument.notes
    ]

    # Notes are collected instrument by instrument; sort them so that a
    # multi-track file forms one timeline instead of jumping back to t=0
    # at every instrument boundary.
    notes.sort(key=lambda row: (row[1], row[0]))

    # reshape keeps the result 2-D even when the file contains no notes.
    notes_array = np.array(notes, dtype=np.float32).reshape(-1, 4)

    # Check for NaN values and replace them
    if np.isnan(notes_array).any():
        print("Warning: NaN values detected in MIDI data! Replacing with zeros.")
        notes_array = np.nan_to_num(notes_array)

    return notes_array

# Parse the input MIDI file into a note array.
# Replace the placeholder path with the file to learn from.
midi_file_path = "your_midi_file.mid"
notes_array = midi_to_notes(midi_file=midi_file_path)

# Normalize data
def normalize_notes(notes):
    """Min-max scale every column of ``notes`` to the range [0, 1].

    Args:
        notes: 2-D array with one row per note and one column per feature.

    Returns:
        A tuple ``(scaled, min_values, max_values)`` where ``scaled`` has the
        same shape as ``notes`` and the other two hold the per-column minimum
        and maximum needed to undo the scaling.
    """
    lows = np.min(notes, axis=0)
    highs = np.max(notes, axis=0)

    span = highs - lows
    # A constant column has zero span; dividing by 1 instead maps it to 0
    # rather than producing NaN/inf.
    constant_columns = span == 0
    span[constant_columns] = 1

    scaled = (notes - lows) / span
    return scaled, lows, highs

notes_array, min_values, max_values = normalize_notes(notes_array)

# Prepare training data (sequence-based)
sequence_length = 50  # Number of previous notes used to predict the next one

# Each training pair needs `sequence_length` input notes plus one target note.
# Fail early with a clear message instead of letting an empty training set
# surface later as an obscure shape error.
num_windows = len(notes_array) - sequence_length
if num_windows <= 0:
    raise ValueError(
        f"Need more than {sequence_length} notes to build training sequences, "
        f"got {len(notes_array)}."
    )

# X[i] is the window of notes i .. i+sequence_length-1; y[i] is the note
# that immediately follows that window.
X = np.array([notes_array[i:i + sequence_length] for i in range(num_windows)])
y = np.array(notes_array[sequence_length:])

# Check for NaN in training data
if np.isnan(X).any() or np.isnan(y).any():
    print("Warning: NaN found in training data! Fixing...")
    X = np.nan_to_num(X)
    y = np.nan_to_num(y)

# Build the LSTM model
num_features = 4  # pitch, start, end, velocity

# Declare the input shape with an explicit Input object rather than the
# `input_shape=` layer argument, as recommended for Sequential models.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(sequence_length, num_features)),
    tf.keras.layers.LSTM(128, return_sequences=True),
    tf.keras.layers.LSTM(64),
    # Linear output: one regression value per note feature
    tf.keras.layers.Dense(num_features, activation='linear'),
])
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the model
model.fit(X, y, epochs=50, batch_size=32)

# Generate new notes
def generate_notes(model, seed_sequence, num_notes=100):
    """Generate notes one at a time, feeding each prediction back as input.

    Args:
        model: Object with a ``predict`` method that takes a batch of shape
            ``(1, seq_len, n_features)`` and returns ``(1, n_features)``.
        seed_sequence: Array-like of shape ``(seq_len, n_features)`` used as
            the initial input window. It is copied, never modified.
        num_notes: Number of notes to generate.

    Returns:
        Array of shape ``(num_notes, n_features)``. When ``num_notes`` is zero
        or negative the result has shape ``(0, n_features)`` so that callers
        can still index columns.
    """
    # np.array copies, so the caller's seed is left untouched, and it also
    # lets plain nested lists be used as a seed.
    window = np.array(seed_sequence)
    n_features = window.shape[-1]
    generated = []

    for _ in range(num_notes):
        prediction = model.predict(window[np.newaxis, :, :])[0]

        if np.isnan(prediction).any():
            print("Warning: NaN detected in model output! Replacing with zeros.")
            prediction = np.nan_to_num(prediction)

        generated.append(prediction)
        # Slide the window: drop the oldest note, append the newest one.
        window = np.vstack([window[1:], prediction])

    if not generated:
        return np.empty((0, n_features), dtype=window.dtype)
    return np.array(generated)

# Seed the generator with the final training window and sample new notes
seed_sequence = X[-1]
generated_notes = generate_notes(model=model, seed_sequence=seed_sequence)

# Denormalize the notes
def denormalize_notes(notes, min_values, max_values):
    """Map min-max scaled notes back to their original value range.

    Args:
        notes: Scaled note array, one row per note.
        min_values: Per-column minimum used when scaling.
        max_values: Per-column maximum used when scaling.

    Returns:
        Array of the same shape as ``notes`` in the original units.
    """
    span = max_values - min_values
    rescaled = notes * span
    return rescaled + min_values

generated_notes = denormalize_notes(generated_notes, min_values, max_values)

# Keep pitches inside the MIDI range 0-127 and push every end time
# 0.05 later so that notes are slightly longer than predicted.
generated_notes[:, 0] = generated_notes[:, 0].clip(0, 127)
generated_notes[:, 2] = generated_notes[:, 2] + 0.05

# Convert predictions back to MIDI format
def notes_to_midi(notes, output_file):
    """Write note rows to a single-instrument MIDI file.

    Args:
        notes: Iterable of rows ``[pitch, start, end, velocity]``.
        output_file: Destination handed to ``PrettyMIDI.write``.

    Rows containing NaN or infinite values are skipped with a printed
    message. Pitch and velocity are clipped to 0-127, negative start times
    are moved to 0, and a note whose end is not after its start is given a
    duration of 0.1.
    """
    midi_data = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(0)  # Piano

    for note in notes:
        # Non-finite values cannot be converted to MIDI integers or times;
        # drop the row up front instead of relying on an exception.
        if not np.all(np.isfinite(note[:4])):
            print("Skipping invalid note:", note)
            continue

        # Clip here as well so the function is safe on unclipped model output.
        pitch = int(np.clip(np.round(note[0]), 0, 127))
        velocity = int(np.clip(note[3], 0, 127))

        # Predictions can undershoot zero; a MIDI event cannot start before
        # the beginning of the file.
        start = max(float(note[1]), 0.0)
        end = float(note[2])

        # Ensure valid timing
        if start >= end:
            end = start + 0.1  # Adjust to avoid zero-duration notes

        instrument.notes.append(pretty_midi.Note(velocity, pitch, start, end))

    midi_data.instruments.append(instrument)
    midi_data.write(output_file)

    print(f"MIDI saved: {output_file}")

output_path = "output.mid"
notes_to_midi(generated_notes, output_path)

# Plot MIDI notes for debugging
plt.scatter(generated_notes[:, 1], generated_notes[:, 0], c=generated_notes[:, 3], cmap='coolwarm')
plt.xlabel("Start Time")
plt.ylabel("Pitch")
plt.title("Generated MIDI Notes")
plt.colorbar(label="Velocity")
plt.show()

# Print the non-meta messages of the written file as a sanity check.
# MidiFile.play() produces no audio; it only sleeps between messages so they
# arrive in real time, which would block the cell for the length of the piece.
# Iterating the file directly yields the same messages without the delay.
midi = MidiFile(output_path)
for msg in midi:
    if not msg.is_meta:
        print(msg)

print(f"🎵 MIDI generation complete! Saved as '{output_path}'.")


Note: you may need to restart the kernel to use updated packages.
Epoch 1/50



[notice] A new release of pip available: 22.2.2 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - loss: 0.5481
Epoch 2/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step - loss: 0.2808
Epoch 3/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step - loss: 0.0777
Epoch 4/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 57ms/step - loss: 0.0529
Epoch 5/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 54ms/step - loss: 0.0807
Epoch 6/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 59ms/step - loss: 0.0323
Epoch 7/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step - loss: 0.0225
Epoch 8/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 59ms/step - loss: 0.0325
Epoch 9/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 56ms/step - loss: 0.0405
Epoch 10/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 60ms/step - loss: 0.0409
Epoch 11/50
[1m1/1[0m [32m━━━