# 256 LSTM

In [None]:
# Step 1: Install Necessary Packages
# Install pretty_midi
!pip install pretty_midi
# Upgrade TensorFlow to the latest version compatible with your environment
!pip install --upgrade tensorflow

# Step 2: Import Essential Libraries
import os
import glob
import numpy as np
import pretty_midi
import tensorflow as tf
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, LSTM, Bidirectional, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import random
import time
import pickle

# Step 3: Suppress Unnecessary Warnings and Set TensorFlow Logging Level
# NOTE(review): this hides *every* Python warning for the rest of the session,
# including deprecation notices from TensorFlow/Keras.
import warnings
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')

# Step 4: Check for Available GPUs to Leverage for Training
gpu_devices = tf.config.list_physical_devices('GPU')
print("Number of GPUs detected:", len(gpu_devices))
if not gpu_devices:
    print("Warning: No GPU detected. Training might be slow.")

# Step 5: Enable Mixed Precision Training for Improved Performance if Supported
# Any failure here is reported and training simply continues with the default policy.
try:
    from tensorflow.keras import mixed_precision
    mixed_precision.set_global_policy('mixed_float16')
    print("Mixed precision training is enabled.")
except Exception as e:
    print("Mixed precision training is not supported on this device or encountered an error.")
    print(f"Error: {e}")

# Step 6: Set Random Seeds to Ensure Reproducibility
# The data generator shuffles its samples with Python's `random` module, so that
# generator must be seeded as well as NumPy and TensorFlow.
random.seed(123)
np.random.seed(123)
tf.random.set_seed(123)

# Step 7: Download and Extract the MAESTRO Dataset
dataset_directory = 'maestro_dataset_lstm'
if not os.path.exists(dataset_directory):
    print("\nDownloading the MAESTRO dataset...")
    !wget -q --show-progress https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip

    print("Extracting the MAESTRO dataset...")
    !unzip -q maestro-v3.0.0-midi.zip -d maestro_dataset_lstm

# Step 8: Collect All MIDI Files
all_midi_files = glob.glob(os.path.join(dataset_directory, '**', '*.midi'), recursive=True)
print(f"\nTotal MIDI files discovered: {len(all_midi_files)}")

# Optionally limit the number of MIDI files to manage training duration
all_midi_files = all_midi_files[:1200]  # Change this number based on how many MIDI files you want the model to train
print(f"Utilizing {len(all_midi_files)} MIDI files for the training process.")

# Step 9: Function to Transform MIDI Files into Note Event Sequences

def extract_note_events(midi_path):
    """
    Parses a MIDI file and returns its non-drum notes ordered by onset time.

    Args:
        midi_path (str): Path of the MIDI file to parse.

    Returns:
        tuple: ``(midi_object, note_events)`` where ``note_events`` is a list of dicts with
        the keys 'note', 'onset', 'offset' and 'intensity'. If anything goes wrong while
        parsing, the error is printed and ``(None, None)`` is returned instead.
    """
    try:
        parsed_midi = pretty_midi.PrettyMIDI(midi_path)
        # Drum tracks are skipped entirely.
        pitched_tracks = [track for track in parsed_midi.instruments if not track.is_drum]
        events = [
            {
                'note': played.pitch,
                'onset': played.start,
                'offset': played.end,
                'intensity': played.velocity
            }
            for track in pitched_tracks
            for played in track.notes
        ]
        # Stable sort: notes sharing an onset keep their track order.
        events = sorted(events, key=lambda item: item['onset'])
    except Exception as error:
        print(f"Failed to process {midi_path}: {error}")
        return None, None
    return parsed_midi, events


# Function to Discretize Notes into Fixed Time Steps Aligned with Beats
def discretize_notes(note_events, midi_obj, interval=0.25):
    """
    Converts note events into a sequence of time steps whose boundaries are beat times.

    Step ``i`` spans ``beats[i]`` to ``beats[i + 1]``, where ``beats`` comes from
    ``midi_obj.get_beats()``. A note belongs to a step when it overlaps it, i.e. its
    onset is before the step end and its offset is after the step start.

    Args:
        note_events (list): Note event dictionaries with 'note', 'onset' and 'offset' keys.
        midi_obj (PrettyMIDI): Parsed MIDI object providing get_beats() and get_end_time().
        interval (float): Spacing of the fallback grid. It is only used when fewer than
            two beats are reported: the grid is then ``np.arange(0, end_time, interval)``,
            so the value is in the same units as the event times. When beats are
            available this argument has no effect and each step is one beat long.

    Returns:
        discretized_sequence (list): List of active pitches per time step, each in the
        order the notes appear in ``note_events``.
    """
    beats = midi_obj.get_beats()
    if len(beats) < 2:
        # Use fixed intervals if beat detection is unreliable
        beats = np.arange(0, midi_obj.get_end_time(), interval)

    step_count = max(len(beats) - 1, 0)
    if step_count == 0 or not note_events:
        # Nothing to place: one empty step per beat interval (possibly none at all).
        return [[] for _ in range(step_count)]

    # Pull the fields out once so each step is a vectorised comparison rather than a
    # Python-level scan over every note dictionary.
    onsets = np.array([event['onset'] for event in note_events], dtype=float)
    offsets = np.array([event['offset'] for event in note_events], dtype=float)
    pitches = np.array([event['note'] for event in note_events])

    discretized_sequence = []
    for start, end in zip(beats[:-1], beats[1:]):
        overlapping = (onsets < end) & (offsets > start)
        # tolist() converts NumPy scalars back to plain Python values.
        discretized_sequence.append(pitches[overlapping].tolist())
    return discretized_sequence

# Step 10: Process All MIDI Files to Extract Sequences
all_note_sequences = []
sequence_durations = []

print("\nProcessing MIDI files to extract note sequences...")
start_processing_time = time.time()
for midi_file in all_midi_files:
    midi_obj, note_events = extract_note_events(midi_file)
    # A falsy result covers both a failed parse (None) and a file without any notes.
    if note_events:
        sequence = discretize_notes(note_events, midi_obj, interval=0.25)
        sequence_durations.append(len(sequence))
        all_note_sequences.append(sequence)
    else:
        print(f"Skipping {midi_file} due to processing errors.")
processing_time_elapsed = time.time() - start_processing_time
print(f"Completed processing in {processing_time_elapsed:.2f} seconds.")

print(f"\nTotal note sequences extracted: {len(all_note_sequences)}")
if not all_note_sequences:
    # Without this guard np.min/np.max fail below with an opaque "zero-size array"
    # error; stop here with a message that points at the likely cause.
    raise ValueError(
        "No note sequences were extracted. Check that the dataset was downloaded and "
        "that the dataset directory contains readable .midi files."
    )
print(f"Average sequence length: {np.mean(sequence_durations):.2f} steps")
print(f"Sequence lengths range from {np.min(sequence_durations)} to {np.max(sequence_durations)} steps")

# Step 11: Create Pitch Mappings
# Flatten every time step of every sequence into one list of pitches.
collected_pitches = []
for seq in all_note_sequences:
    for timestep in seq:
        collected_pitches.extend(timestep)
unique_pitches_sorted = sorted(set(collected_pitches))  # Typically 88 for standard piano

# Indices follow the sorted pitch order; the two dictionaries are exact inverses.
index_to_pitch_map = dict(enumerate(unique_pitches_sorted))
pitch_to_index_map = {pitch: idx for idx, pitch in index_to_pitch_map.items()}
total_pitches = len(unique_pitches_sorted)

print(f"\nTotal unique pitches identified: {total_pitches}")

# Persist both lookup tables so they can be reloaded later.
for mapping_filename, mapping in (
    ('pitch_to_index_lstm.pkl', pitch_to_index_map),
    ('index_to_pitch_lstm.pkl', index_to_pitch_map),
):
    with open(mapping_filename, 'wb') as mapping_file:
        pickle.dump(mapping, mapping_file)

# Step 12: Define Hyperparameters
input_window = 64  # Previous time steps to consider
# The data generators below are constructed with `batch_size`, which was never defined.
# NOTE(review): 64 is an assumed default - tune it to the available GPU memory.
batch_size = 64  # Samples per training batch
num_epochs = 40
early_stop_patience = 5

# Step 13: Split Data into Training, Validation, and Testing Sets
# First hold out 20% of the sequences for testing, then carve 10% of the remainder
# off for validation. Both splits use a fixed random_state.
remaining_data, test_data = train_test_split(all_note_sequences, test_size=0.2, random_state=42)
train_data, validation_data = train_test_split(remaining_data, test_size=0.1, random_state=42)

for subset_label, subset in (
    ("training", train_data),
    ("validation", validation_data),
    ("testing", test_data),
):
    print(f"Number of {subset_label} sequences: {len(subset)}")

# Step 14: Define a Data Generator Class
class LSTM_DataGenerator(Sequence):
    """
    Keras Sequence that serves multi-hot note windows for LSTM training.

    Every sample is a window of ``seq_len`` consecutive time steps taken from one
    sequence; its target is the same window shifted one step ahead. Inputs and targets
    therefore both have shape (batch, seq_len, num_pitches).
    """

    def __init__(self, dataset, seq_len, batch_sz, num_pitches, pitch_map, shuffle=True, **kwargs):
        """
        Initializes the data generator for LSTM training.

        Args:
            dataset (list): List of sequences, each containing active pitches per time step.
            seq_len (int): Number of previous time steps to use as input.
            batch_sz (int): Size of each data batch.
            num_pitches (int): Total number of unique pitches.
            pitch_map (dict): Mapping from pitch to index.
            shuffle (bool): Whether to shuffle data after each epoch.
            **kwargs: Forwarded unchanged to the base ``Sequence`` constructor.
        """
        # Recent Keras versions expect subclasses to run the base constructor; with no
        # extra kwargs this is also harmless on versions where the base has none.
        super().__init__(**kwargs)
        self.dataset = dataset
        self.sequence_length = seq_len
        self.batch_size = batch_sz
        self.num_pitches = num_pitches
        self.pitch_mapping = pitch_map
        self.shuffle = shuffle
        self.sample_pairs = self._prepare_samples()
        self.on_epoch_end()

    def _prepare_samples(self):
        """
        Lists every training window as a ``(sequence_index, start_step)`` pair.

        Only positions are stored; the window contents are sliced out of ``self.dataset``
        on demand, which avoids keeping two list copies per sample in memory. Sequences
        shorter than ``sequence_length + 1`` steps contribute no windows.
        """
        samples = []
        for seq_idx, seq in enumerate(self.dataset):
            window_count = len(seq) - self.sequence_length
            # range() is empty for non-positive counts, so short sequences are skipped.
            samples.extend((seq_idx, start) for start in range(window_count))
        return samples

    def __len__(self):
        """Returns the number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.sample_pairs) / self.batch_size))

    def __getitem__(self, idx):
        """
        Builds batch ``idx``.

        Returns:
            tuple: ``(X, Y)`` float32 arrays of shape
            (batch, sequence_length, num_pitches); ``Y`` is ``X`` advanced by one step.
        """
        batch_samples = self.sample_pairs[idx * self.batch_size : (idx + 1) * self.batch_size]
        # Input and target overlap in all but one step, so encode sequence_length + 1
        # steps once and take the two shifted views from that.
        span = self.sequence_length + 1
        encoded = np.zeros((len(batch_samples), span, self.num_pitches), dtype=np.float32)

        for row, (seq_idx, start) in enumerate(batch_samples):
            window = self.dataset[seq_idx][start:start + span]
            for t, pitches in enumerate(window):
                for pitch in pitches:
                    pitch_idx = self.pitch_mapping.get(pitch)
                    # Pitches missing from the mapping, or mapped out of range, are ignored.
                    if pitch_idx is not None and 0 <= pitch_idx < self.num_pitches:
                        encoded[row, t, pitch_idx] = 1.0

        # Copies keep X and Y independent arrays, as callers may modify them in place.
        X = encoded[:, :-1, :].copy()
        Y = encoded[:, 1:, :].copy()
        return X, Y

    def on_epoch_end(self):
        """Reshuffles the sample order between epochs when shuffling is enabled."""
        if self.shuffle:
            random.shuffle(self.sample_pairs)

# Step 15: Initialize Data Generators
# All three generators share the window length, batch size and pitch vocabulary.
# Only the training generator keeps the default shuffle=True.
training_generator = LSTM_DataGenerator(
    dataset=train_data,
    seq_len=input_window,
    batch_sz=batch_size,
    num_pitches=total_pitches,
    pitch_map=pitch_to_index_map,
)
validation_generator = LSTM_DataGenerator(
    dataset=validation_data,
    seq_len=input_window,
    batch_sz=batch_size,
    num_pitches=total_pitches,
    pitch_map=pitch_to_index_map,
    shuffle=False,
)
testing_generator = LSTM_DataGenerator(
    dataset=test_data,
    seq_len=input_window,
    batch_sz=batch_size,
    num_pitches=total_pitches,
    pitch_map=pitch_to_index_map,
    shuffle=False,
)

# Step 16: Define Dice Loss Function
# The following implementation was adapted from Stack Overflow (2024)
# Reference: Correct Implementation of Dice Loss in Tensorflow / Keras.
# Available at: https://stackoverflow.com/questions/72195156/correct-implementation-of-dice-loss-in-tensorflow-keras

def dice_loss(y_true, y_pred, smooth=1):
    """
    Computes the Dice Loss for multi-label classification.

    A Dice coefficient is computed over the last axis (the pitch dimension) for every
    leading position, and the loss is one minus the mean of those coefficients. The
    pitch count is taken from the tensors themselves, so the function does not depend
    on any module-level variable.

    Args:
        y_true (tensor): Ground truth binary labels.
        y_pred (tensor): Predicted probabilities, same shape as ``y_true``.
        smooth (float): Smoothing factor to prevent division by zero.

    Returns:
        tensor: Dice loss value (float32 scalar).
    """
    # The notebook enables the mixed_float16 policy; do the reductions in float32 so
    # the sums and the final mean do not lose precision in half precision.
    y_true_f = tf.cast(y_true, tf.float32)
    y_pred_f = tf.cast(y_pred, tf.float32)
    intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=-1)
    sum_labels = tf.reduce_sum(y_true_f, axis=-1)
    sum_preds = tf.reduce_sum(y_pred_f, axis=-1)
    dice_coeff = (2. * intersection + smooth) / (sum_labels + sum_preds + smooth)
    return 1 - tf.reduce_mean(dice_coeff)

# Step 17: Build the LSTM Model Architecture
# The following implementation was inspired by the TensorFlow sample (2024)
# Reference: tf.keras.layers.Bidirectional.
# Available at: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Bidirectional

print("\nConstructing the LSTM-based music generation model...")

# NOTE(review): the training targets are the inputs shifted one step ahead, and every
# layer below is bidirectional with return_sequences=True. The backward direction has
# therefore already read frame t+1 when the model predicts it at position t (for all
# positions except the last). Confirm this is intended before trusting the reported
# metrics; a unidirectional stack or last-step-only targets would avoid the leak.

# Define the input layer with the specified sequence length and number of pitches
input_tensor = Input(shape=(input_window, total_pitches), name='lstm_input_layer')

# First Bidirectional LSTM Layer with Enhanced Units and Dropout
first_bidirectional_lstm = Bidirectional(
    LSTM(units=256, return_sequences=True, dropout=0.3),
    name='bidirectional_lstm_layer_1'
)(input_tensor)

# Batch Normalization for Stabilizing Learning
first_batch_norm = BatchNormalization(name='batch_normalization_1')(first_bidirectional_lstm)

# Second Bidirectional LSTM Layer (same width as the first)
second_bidirectional_lstm = Bidirectional(
    LSTM(units=256, return_sequences=True, dropout=0.3),
    name='bidirectional_lstm_layer_2'
)(first_batch_norm)

# Batch Normalization
second_batch_norm = BatchNormalization(name='batch_normalization_2')(second_bidirectional_lstm)

# Third Bidirectional LSTM Layer to Capture Deeper Temporal Patterns
third_bidirectional_lstm = Bidirectional(
    LSTM(units=256, return_sequences=True, dropout=0.3),
    name='bidirectional_lstm_layer_3'
)(second_batch_norm)

# Batch Normalization
third_batch_norm = BatchNormalization(name='batch_normalization_3')(third_bidirectional_lstm)

# Dense Output Layer with Sigmoid Activation for Multi-Label Prediction
# dtype='float32' keeps the probabilities in full precision even when the global
# mixed_float16 policy is active, as recommended for a model's output layer.
output_layer = Dense(
    total_pitches,
    activation='sigmoid',
    dtype='float32',
    name='lstm_output_layer'
)(third_batch_norm)

# Assemble the Complete Model
lstm_model = Model(inputs=input_tensor, outputs=output_layer)
print("Model architecture successfully constructed.\n")

# Step 18: Compile the Model with Dice Loss and Built-in Metrics

# Metrics reported alongside the loss during fit/evaluate, in this order.
evaluation_metrics = [
    'binary_accuracy',
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.AUC(name='auc'),
]

# Adam with a 1e-3 starting learning rate, optimising the Dice loss.
lstm_model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss=dice_loss,
    metrics=evaluation_metrics,
)

# Print a layer-by-layer overview of the network
lstm_model.summary()

# Step 19: Set Up Callbacks for Training
checkpoint_directory = './lstm_model_checkpoints'
os.makedirs(checkpoint_directory, exist_ok=True)
best_model_path = os.path.join(checkpoint_directory, 'best_lstm_model.keras')

# Keep only the full model (not just weights) with the lowest validation loss so far.
checkpoint_callback = ModelCheckpoint(
    filepath=best_model_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)

# Stop once validation loss has not improved for `early_stop_patience` epochs and
# roll the model back to its best weights.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=early_stop_patience,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate after 2 epochs without validation-loss improvement,
# never going below 1e-6.
reduce_lr_callback = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
    verbose=1,
)

# Callbacks handed to fit(), in the order they are listed here
training_callbacks = [
    checkpoint_callback,
    early_stopping_callback,
    reduce_lr_callback,
]

# Step 20: Serialize and Save Training Sequences for Generation
def save_training_seeds(training_sequences, window_length, pitch_map=None,
                        output_path='lstm_train_seeds.pkl'):
    """
    Saves filtered training sequences to a pickle file for seed selection during generation.

    Each time step keeps only the pitches present in ``pitch_map``. The pitch values
    themselves are stored, not their mapped indices. Time steps left without any pitch
    are dropped, and sequences with fewer than ``window_length`` remaining steps are
    omitted.

    NOTE(review): dropping empty time steps removes rests from the seeds, so a seed's
    timing no longer matches the sequences the model was trained on. Confirm that the
    generation code expects this.

    Args:
        training_sequences (list): List of sequences, each a list of per-step pitch lists.
        window_length (int): The length of the input sequence window.
        pitch_map (dict | None): Pitches to keep (membership is tested against its
            keys). Defaults to the module-level ``pitch_to_index_map``.
        output_path (str): Destination pickle file.

    Saves:
        ``output_path`` ('lstm_train_seeds.pkl' by default): Pickle file containing the
        filtered training sequences.
    """
    if pitch_map is None:
        pitch_map = pitch_to_index_map

    serialized_seeds = []
    for sequence in training_sequences:
        kept_timesteps = []
        for timestep in sequence:
            known_pitches = [pitch for pitch in timestep if pitch in pitch_map]
            if known_pitches:
                kept_timesteps.append(known_pitches)
        if len(kept_timesteps) >= window_length:
            serialized_seeds.append(kept_timesteps)

    with open(output_path, 'wb') as f:
        pickle.dump(serialized_seeds, f)
    print(f"Training sequences have been serialized and saved to '{output_path}'.")

# Write the seed file from the training split before training starts
save_training_seeds(train_data, input_window)

# Step 21: Begin Model Training
print("\nCommencing the training process for the LSTM model...")
training_history = lstm_model.fit(
    training_generator,
    epochs=num_epochs,
    validation_data=validation_generator,
    callbacks=training_callbacks,
    verbose=1,
)

# Step 22: Evaluate Model Performance on the Test Set
print("\nAssessing model performance on the test dataset...")
test_metrics = lstm_model.evaluate(testing_generator, verbose=1)
# evaluate() returns the loss followed by the metrics in the order given to compile().
metric_labels = ("Loss", "Binary Accuracy", "Precision", "Recall", "AUC")
for position, metric_label in enumerate(metric_labels):
    print(f"Test {metric_label}: {test_metrics[position]}")

# Step 23: Save the Final Trained Model
lstm_model.save('final_lstm_music_model.keras')
print("The trained LSTM model has been saved as 'final_lstm_music_model.keras'.")

# Step 24: Visualize Training and Validation Metrics
def plot_training_metrics(history):
    """
    Draws one subplot per tracked metric, comparing training and validation curves.

    The metrics plotted are 'loss', 'binary_accuracy', 'precision', 'recall' and 'auc';
    for each one the matching 'val_' series is read from the same history.

    Args:
        history (History): Keras History object containing training metrics.
    """
    recorded = history.history
    tracked_metrics = ('loss', 'binary_accuracy', 'precision', 'recall', 'auc')

    plt.figure(figsize=(25, 15))

    # Five panels laid out on a 3x2 grid; subplot positions are 1-based.
    for position, metric in enumerate(tracked_metrics, start=1):
        display_name = metric.replace("_", " ").title()
        training_curve = recorded[metric]
        validation_curve = recorded[f'val_{metric}']

        plt.subplot(3, 2, position)
        plt.plot(training_curve, label=f'Training {metric}')
        plt.plot(validation_curve, label=f'Validation {metric}')
        plt.title(f'{display_name} Over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel(display_name)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Generate plots for training history
# (uses the History object returned by lstm_model.fit earlier in this cell)
plot_training_metrics(training_history)
