# 256 GRU






In [None]:
# Install essential packages
!pip install pretty_midi
!pip install --upgrade tensorflow

# Import necessary modules
import os
import glob
import numpy as np
import pretty_midi
import tensorflow as tf
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, GRU, Bidirectional
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import random
import time
import pickle

# Suppress warnings and configure TensorFlow logging
import warnings
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')

# Verify GPU availability for TensorFlow
gpu_count = len(tf.config.list_physical_devices('GPU'))
print("Available GPUs:", gpu_count)
if gpu_count == 0:
    print("No GPU detected. Please select a GPU runtime.")

# Activate mixed precision training if supported.
# This is best-effort: any failure is reported and training continues with
# the default precision policy.
try:
    from tensorflow.keras import mixed_precision
    mixed_precision.set_global_policy('mixed_float16')
    print("Mixed precision training is active.")
except Exception as e:
    print("Mixed precision training could not be enabled.")
    print(f"Error: {e}")

# Set seeds for reproducibility.
# The data loader shuffles its samples with the stdlib `random` module, so
# that generator has to be seeded alongside NumPy and TensorFlow.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

# Step 1: Acquire the MAESTRO dataset
dataset_directory = 'maestro_dataset'
if not os.path.exists(dataset_directory):
    print("\nInitiating download of the MAESTRO dataset...")
    !wget -q --show-progress https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip

    # Step 2: Extract the dataset
    print("Extracting the MAESTRO dataset...")
    !unzip -q maestro-v3.0.0-midi.zip -d maestro_dataset

# Step 3: Preprocess the MIDI files
midi_file_paths = glob.glob(os.path.join(dataset_directory, '**', '*.midi'), recursive=True)
print(f"\nTotal MIDI files located: {len(midi_file_paths)}")

# Optionally limit the number of MIDI files for quicker training.
# After many experiments with different MIDI files, the model does not showcase much improvement when more MIDI files are used.
midi_file_paths = midi_file_paths[:1200]
print(f"Selected {len(midi_file_paths)} MIDI files for model training.")

# Function to transform MIDI files into note event sequences
def extract_note_events(midi_path):
    """Load a MIDI file and list its non-drum notes ordered by onset time.

    Args:
        midi_path: Path of the MIDI file to read.

    Returns:
        A ``(midi_object, note_events)`` pair. ``note_events`` is a list of
        dicts with the keys ``'note'``, ``'onset'``, ``'offset'`` and
        ``'intensity'``. If anything goes wrong while loading or reading the
        file, the error is printed and ``(None, None)`` is returned.
    """
    try:
        parsed = pretty_midi.PrettyMIDI(midi_path)
        # Percussion tracks are left out; only pitched instruments contribute.
        pitched_tracks = (inst for inst in parsed.instruments if not inst.is_drum)
        events = [
            {
                'note': n.pitch,
                'onset': n.start,
                'offset': n.end,
                'intensity': n.velocity
            }
            for inst in pitched_tracks
            for n in inst.notes
        ]
        # Stable sort: notes sharing an onset keep their track order.
        ordered = sorted(events, key=lambda ev: ev['onset'])
        return parsed, ordered
    except Exception as error:
        # Best-effort batch processing: report the file and let the caller skip it.
        print(f"Failed to process {midi_path}: {error}")
        return None, None

# Function to discretize note events based on musical beats
def discretize_notes(note_events, midi_object, beat_interval=0.5):
    """Slice note events into consecutive beat-to-beat windows.

    Args:
        note_events: Dicts carrying at least ``'note'``, ``'onset'`` and
            ``'offset'``.
        midi_object: Object providing ``get_beats()`` and ``get_end_time()``.
        beat_interval: Spacing of the fallback grid, used only when fewer
            than two beats are reported.

    Returns:
        One list per window holding the ``'note'`` value of every event that
        overlaps the window (onset before the window end and offset after the
        window start). A note spanning several windows appears in each.
    """
    grid = midi_object.get_beats()
    if len(grid) < 2:
        # Too few beats to form a window: fall back to an evenly spaced grid.
        grid = np.arange(0, midi_object.get_end_time(), beat_interval)

    frames = []
    for window_start, window_end in zip(grid[:-1], grid[1:]):
        sounding = []
        for event in note_events:
            if event['onset'] < window_end and event['offset'] > window_start:
                sounding.append(event['note'])
        frames.append(sounding)
    return frames

# Aggregate all discretized note sequences (one entry per successfully parsed file)
all_discretized_sequences = []
sequence_length_records = []

print("\nCommencing MIDI file processing...")
start_timer = time.time()
for path in midi_file_paths:
    midi_obj, notes = extract_note_events(path)
    # `notes` is None when parsing failed and empty when the file has no pitched notes.
    if notes:
        discretized_sequence = discretize_notes(notes, midi_obj)
        sequence_length_records.append(len(discretized_sequence))
        all_discretized_sequences.append(discretized_sequence)
    else:
        print(f"Skipping {path} due to processing issues.")
processing_duration = time.time() - start_timer
print(f"Processing completed in {processing_duration:.2f} seconds.")

print(f"\nTotal sequences generated: {len(all_discretized_sequences)}")
# Fail fast with an actionable message: without any sequence the statistics
# below would produce nan / an opaque NumPy ValueError, and the rest of the
# pipeline cannot run anyway.
if not all_discretized_sequences:
    raise RuntimeError(
        "No MIDI file could be converted into a note sequence. "
        "Check that the MAESTRO dataset was downloaded and extracted."
    )
print(f"Average sequence length: {np.mean(sequence_length_records):.2f} steps")
print(f"Sequence lengths range from {np.min(sequence_length_records)} to {np.max(sequence_length_records)} steps")

# Build the pitch vocabulary: every pitch that occurs anywhere in the data
# gets a dense integer index, assigned in ascending pitch order.
all_pitch_values = [
    pitch
    for seq in all_discretized_sequences
    for timestep in seq
    for pitch in timestep
]
unique_pitches = sorted(set(all_pitch_values))  # Typically 88 for piano

index_to_pitch = dict(enumerate(unique_pitches))
pitch_to_index = {pitch: idx for idx, pitch in index_to_pitch.items()}
total_pitches = len(index_to_pitch)  # Expected to be 88

print(f"\nUnique pitch count: {total_pitches}")
print(f"Pitch indexing configured for {total_pitches} pitches.")

# Store both lookup tables on disk so they can be reloaded later
with open('pitch_to_index.pkl', 'wb') as file:
    pickle.dump(pitch_to_index, file)
with open('index_to_pitch.pkl', 'wb') as file:
    pickle.dump(index_to_pitch, file)

# Define training hyperparameters
past_steps = 64  # time steps per input window (loader `sequence_len` and model input length)
batch_size = 64  # windows per batch produced by the data loaders
num_epochs = 40  # upper bound on epochs passed to fit()
early_stop_patience = 10  # patience handed to the EarlyStopping callback

# Partition the dataset to prevent data leakage.
# The split operates on whole sequences (one list entry per MIDI file), so
# windows cut from one file never land in two different subsets.
# 20% of the files are held out for testing; 10% of the remaining 80% become
# the validation set, i.e. roughly 72% / 8% / 20% train / validation / test.
training_set, testing_set = train_test_split(all_discretized_sequences, test_size=0.2, random_state=42)
training_set, validation_set = train_test_split(training_set, test_size=0.1, random_state=42)

print(f"Number of training sequences: {len(training_set)}")
print(f"Number of validation sequences: {len(validation_set)}")
print(f"Number of testing sequences: {len(testing_set)}")

# Step 4: Define a Data Generator Class
class MusicSequenceLoader(Sequence):
    """Batch generator producing multi-hot (input, target) windows for the GRU.

    Every sample is a window of ``sequence_len`` consecutive time steps taken
    from one sequence of the dataset; its target is the same window shifted
    one step into the future.
    """

    def __init__(self, dataset, sequence_len, batch_sz, pitch_count, pitch_map, randomize=True):
        """
        Initializes the data generator for GRU training.

        Args:
            dataset (list): List of sequences, each containing active pitches per time step.
            sequence_len (int): Number of previous time steps to use as input.
            batch_sz (int): Size of each data batch.
            pitch_count (int): Total number of unique pitches.
            pitch_map (dict): Mapping from pitch to index.
            randomize (bool): Whether to shuffle data after each epoch.
        """
        # Keras 3's PyDataset base (exposed as `Sequence`) warns when a
        # subclass skips the base constructor.
        super().__init__()
        self.dataset = dataset
        self.sequence_length = sequence_len
        self.batch_size = batch_sz
        self.num_pitches = pitch_count
        self.pitch_mapping = pitch_map
        self.is_random = randomize
        self.sample_indices = self._prepare_samples()
        self.on_epoch_end()

    def _prepare_samples(self):
        """Return every valid window as a ``(sequence_position, start_step)`` pair.

        Only indices are kept; the window contents are sliced on demand in
        ``__getitem__``. Materialising two 64-step list copies per window
        would cost far more memory than the dataset itself.
        """
        # A window needs `sequence_length` input steps plus one extra step
        # for the shifted target; shorter sequences yield an empty range.
        window = self.sequence_length + 1
        return [
            (seq_pos, start)
            for seq_pos, sequence in enumerate(self.dataset)
            for start in range(len(sequence) - window + 1)
        ]

    def __len__(self):
        """Number of batches per epoch (the last batch may be smaller)."""
        return int(np.ceil(len(self.sample_indices) / self.batch_size))

    def __getitem__(self, idx):
        """Build batch ``idx``.

        Returns:
            tuple: ``(X, Y)`` float32 arrays of shape
            ``(batch, sequence_length, num_pitches)``; ``Y`` is ``X`` shifted
            one time step ahead.
        """
        batch = self.sample_indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        steps = self.sequence_length
        # Encode sequence_length + 1 steps once; inputs and targets are the
        # two overlapping slices of that single piano roll.
        roll = np.zeros((len(batch), steps + 1, self.num_pitches), dtype=np.float32)

        for row, (seq_pos, start) in enumerate(batch):
            window = self.dataset[seq_pos][start:start + steps + 1]
            for t, pitches in enumerate(window):
                for pitch in pitches:
                    pitch_idx = self.pitch_mapping.get(pitch)
                    # Pitches missing from the mapping or outside the index range are ignored.
                    if pitch_idx is not None and 0 <= pitch_idx < self.num_pitches:
                        roll[row, t, pitch_idx] = 1.0

        # Copies keep X and Y independent, contiguous arrays.
        return roll[:, :-1].copy(), roll[:, 1:].copy()

    def on_epoch_end(self):
        """Reshuffle the sample order when randomisation is enabled."""
        if self.is_random:
            random.shuffle(self.sample_indices)

# Step 5: Instantiate Data Generators for Training, Validation, and Testing
# Only the training loader reshuffles between epochs; validation and test
# batches keep a fixed order.
training_loader = MusicSequenceLoader(
    dataset=training_set,
    sequence_len=past_steps,
    batch_sz=batch_size,
    pitch_count=total_pitches,
    pitch_map=pitch_to_index,
)
validation_loader = MusicSequenceLoader(
    dataset=validation_set,
    sequence_len=past_steps,
    batch_sz=batch_size,
    pitch_count=total_pitches,
    pitch_map=pitch_to_index,
    randomize=False,
)
testing_loader = MusicSequenceLoader(
    dataset=testing_set,
    sequence_len=past_steps,
    batch_sz=batch_size,
    pitch_count=total_pitches,
    pitch_map=pitch_to_index,
    randomize=False,
)

# Step 6: Define Dice Loss Function
# The following implementation was adapted from Stack Overflow (2024)
# Reference: Correct Implementation of Dice Loss in Tensorflow / Keras.
# Available at: https://stackoverflow.com/questions/72195156/correct-implementation-of-dice-loss-in-tensorflow-keras

def dice_loss(y_true, y_pred, smooth=1):
    """
    Computes the Dice Loss for multi-label classification.

    A Dice coefficient is computed for every vector along the last axis
    (the pitch axis) and the coefficients are averaged over all remaining
    axes.

    Args:
        y_true (tensor): Ground truth binary labels.
        y_pred (tensor): Predicted probabilities, same shape as ``y_true``.
        smooth (float): Smoothing factor to prevent division by zero.

    Returns:
        tensor: Scalar float32 Dice loss value, ``1 - mean(dice coefficient)``.
    """
    # Compute in float32 whatever the input dtypes are: under the
    # mixed_float16 policy predictions may arrive as float16, which would
    # either clash with float32 labels or lose precision in the sums.
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # Reducing over the last axis gives the same per-step sums as flattening
    # to (-1, num_pitches) first, without relying on a module-level constant.
    intersection = tf.reduce_sum(y_true * y_pred, axis=-1)
    sum_labels = tf.reduce_sum(y_true, axis=-1)
    sum_preds = tf.reduce_sum(y_pred, axis=-1)
    dice_coeff = (2. * intersection + smooth) / (sum_labels + sum_preds + smooth)
    return 1 - tf.reduce_mean(dice_coeff)

# Step 7: Build the GRU Model Architecture
print("\nAssembling the GRU-based music generation model...")

# Define the input layer with the specified sequence length and number of pitches
input_layer = Input(shape=(past_steps, total_pitches), name='music_input')

# The data loader's target at step t is the input at step t + 1. A
# Bidirectional wrapper would let the backward pass read that very frame
# before emitting the output for step t, so the network could simply copy its
# input and the reported metrics would be meaningless. The recurrent stack is
# therefore strictly causal (forward-only).

# First GRU layer (input dropout 0.3)
first_gru_layer = GRU(
    units=256, return_sequences=True, dropout=0.3, name='gru_layer_1'
)(input_layer)

# Second GRU layer
second_gru_layer = GRU(
    units=256, return_sequences=True, dropout=0.3, name='gru_layer_2'
)(first_gru_layer)

# Third GRU layer
third_gru_layer = GRU(
    units=256, return_sequences=True, dropout=0.3, name='gru_layer_3'
)(second_gru_layer)

# Output Dense Layer with sigmoid activation for multi-label prediction.
# It is pinned to float32 so the probabilities stay numerically stable when
# the global mixed_float16 policy is active.
output_layer = Dense(
    total_pitches, activation='sigmoid', dtype='float32', name='music_output'
)(third_gru_layer)

# Define the complete model
music_model = Model(inputs=input_layer, outputs=output_layer)

# Step 8: Compile the Model with Dice Loss and Built-in Metrics
# Balance factor intended for handling class imbalance. It is assigned here
# but not referenced by the Dice loss or by any other visible code.
balance_factor = 2.0  # Currently not used in Dice Loss

# Compile the model with Dice Loss and built-in metrics.
# The metric names given here ('binary_accuracy', 'precision', 'recall',
# 'auc') are the keys later read from the training history when plotting.
music_model.compile(
    loss=dice_loss,
    optimizer=Adam(learning_rate=0.001),
    metrics=[
        'binary_accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='auc')
    ]
)

# Display the model's architecture summary
music_model.summary()

# Step 9: Configure Callbacks for Model Training
checkpoint_directory = './model_checkpoints'
os.makedirs(checkpoint_directory, exist_ok=True)

# Checkpoint to save the best model based on validation loss.
# The full model (not only the weights) is written, and only when val_loss improves.
# NOTE(review): the model is compiled with the custom `dice_loss`; reloading
# this file presumably requires passing it via `custom_objects` — confirm.
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(checkpoint_directory, 'best_music_model.keras'),
    save_weights_only=False,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1
)

# Early stopping to halt training when validation loss stops improving.
# With restore_best_weights=True the best weights seen so far are put back
# into the model when training is stopped.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=early_stop_patience,
    verbose=1,
    restore_best_weights=True
)

# Halve the learning rate after 2 epochs without val_loss improvement, never
# going below 1e-6. This patience is shorter than the early-stopping patience,
# so several reductions can happen before training is stopped.
lr_reduction_callback = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    verbose=1,
    min_lr=1e-6
)

# Aggregate all callbacks into a list
training_callbacks = [checkpoint_callback, early_stopping_callback, lr_reduction_callback]

# Step 10: Serialize and Save Training Sequences for Generation
def serialize_training_sequences(training_sequences, window_length):
    """
    Filters the training sequences and pickles them for later music generation.

    Within each sequence, pitches absent from the module-level
    ``pitch_to_index`` mapping are removed and time steps left without any
    pitch are dropped. The pitch values themselves are stored unchanged (they
    are not converted to indices). Sequences that end up shorter than
    ``window_length`` are discarded.

    Args:
        training_sequences (list): Sequences of time steps, each time step a list of pitches.
        window_length (int): Minimum number of non-empty time steps a sequence must keep.

    Saves:
        'serialized_train_sequences.pkl': Pickle file containing the filtered sequences.
    """
    kept = []
    for sequence in training_sequences:
        # Lazily strip unknown pitches from every time step ...
        filtered_steps = (
            [pitch for pitch in step if pitch in pitch_to_index]
            for step in sequence
        )
        # ... and keep only the steps that still contain at least one pitch.
        non_empty = [step for step in filtered_steps if step]
        if len(non_empty) >= window_length:
            kept.append(non_empty)

    with open('serialized_train_sequences.pkl', 'wb') as handle:
        pickle.dump(kept, handle)
    print("Training sequences have been serialized and saved to 'serialized_train_sequences.pkl'.")

# Write the filtered training sequences to disk before training starts
serialize_training_sequences(training_set, past_steps)

# Step 11: Begin Model Training
# Runs for at most `num_epochs` epochs; the callbacks may stop earlier, lower
# the learning rate, and checkpoint the best model along the way.
print("\nCommencing model training...")
training_history = music_model.fit(
    training_loader,
    validation_data=validation_loader,
    epochs=num_epochs,
    callbacks=training_callbacks,
    verbose=1
)

# Step 12: Evaluate Model Performance on the Testing Dataset
# The positional indices below assume evaluate() returns the loss first,
# followed by the metrics in the order they were passed to compile().
print("\nAssessing model performance on the testing dataset...")
evaluation_results = music_model.evaluate(testing_loader, verbose=1)
print(f"Test Loss: {evaluation_results[0]}")
print(f"Test Binary Accuracy: {evaluation_results[1]}")
print(f"Test Precision: {evaluation_results[2]}")
print(f"Test Recall: {evaluation_results[3]}")
print(f"Test AUC: {evaluation_results[4]}")

# Step 13: Persist the Final Trained Model for Future Use
# This file is separate from the best-validation checkpoint written during training.
music_model.save('final_gru_music_model.keras')
print("The trained model has been saved as 'final_gru_music_model.keras'.")

# Step 14: Function to Visualize Training and Validation Metrics
def visualize_training_metrics(history):
    """
    Draws one training-versus-validation curve per tracked metric.

    Args:
        history (History): Object whose ``history`` dict holds, for each of
            'loss', 'binary_accuracy', 'precision', 'recall' and 'auc', the
            per-epoch training values and their ``val_``-prefixed counterparts.
    """
    tracked = ('loss', 'binary_accuracy', 'precision', 'recall', 'auc')
    records = history.history

    plt.figure(figsize=(25, 15))

    # The five panels fill a 3x2 grid from the top-left; the sixth cell stays empty.
    position = 0
    for metric in tracked:
        position += 1
        pretty = metric.replace("_", " ").title()
        plt.subplot(3, 2, position)
        plt.plot(records[metric], label=f'Training {metric}')
        plt.plot(records[f'val_{metric}'], label=f'Validation {metric}')
        plt.title(f'{pretty} Progress')
        plt.xlabel('Epoch')
        plt.ylabel(pretty)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Plot the training and validation curves recorded by fit()
visualize_training_metrics(training_history)
