# Transformer Model






In [None]:
# Install essential packages if not already present.
# The leading "!" is an IPython shell escape, so this line only runs inside a notebook.
# NOTE(review): scikit-learn and matplotlib are imported below but not installed here —
# presumably they are preinstalled in the target environment; confirm.
!pip install pretty_midi tensorflow

# Import necessary modules
import os
import glob
import numpy as np
import pretty_midi
import tensorflow as tf
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, LayerNormalization,
    MultiHeadAttention, GlobalAveragePooling1D
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
)
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import random
import time
import pickle

# Suppress warnings and configure TensorFlow logging
import warnings
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')

# Verify GPU availability for TensorFlow
gpu_devices = tf.config.list_physical_devices('GPU')
print("Number of GPUs detected:", len(gpu_devices))
if not gpu_devices:
    print("Warning: No GPU detected. Training might be slow.")

# Activate mixed precision training if supported.
# NOTE(review): the policy is requested even when no GPU was detected above;
# confirm that float16 compute is actually wanted on CPU-only machines.
try:
    from tensorflow.keras import mixed_precision
    mixed_precision.set_global_policy('mixed_float16')
    print("Mixed precision training is enabled.")
except Exception as e:
    print("Mixed precision training is not supported on this device or encountered an error.")
    print(f"Error: {e}")

# Set seeds for reproducibility.
# The stdlib generator is seeded as well because the data generator shuffles
# its windows with random.shuffle; without it every run sees a different order.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

# Step 1: Acquire the MAESTRO dataset
dataset_directory = 'maestro_dataset_transformer'
if not os.path.exists(dataset_directory):
    print("\nInitiating download of the MAESTRO dataset...")
    !wget -q --show-progress https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip

    # Step 2: Extract the dataset
    print("Extracting the MAESTRO dataset...")
    !unzip -q maestro-v3.0.0-midi.zip -d maestro_dataset_transformer

# Step 3: Preprocess the MIDI files
midi_file_paths = glob.glob(os.path.join(dataset_directory, '**', '*.midi'), recursive=True)
print(f"\nTotal MIDI files located: {len(midi_file_paths)}")

# Optionally limit the number of MIDI files for quicker training.
# It's recommended to use as many MIDI files as possible to improve model performance.
midi_file_paths = midi_file_paths[:100]  # Adjust as needed
print(f"Selected {len(midi_file_paths)} MIDI files for model training.")

# Function to transform MIDI files into note event sequences
def extract_note_events(midi_path):
    """Load a MIDI file and list the notes of its non-drum tracks by onset.

    Args:
        midi_path (str): Path of the MIDI file to read.

    Returns:
        tuple: ``(midi_object, note_events)`` where ``note_events`` is a list of
        dicts with the keys ``'note'``, ``'onset'``, ``'offset'`` and
        ``'intensity'``, ordered by ``'onset'``. On any failure the error is
        printed and ``(None, None)`` is returned instead of raising.
    """
    try:
        midi_object = pretty_midi.PrettyMIDI(midi_path)
        # Drum tracks are left out; only pitched instruments contribute notes.
        pitched_tracks = (trk for trk in midi_object.instruments if not trk.is_drum)
        note_events = [
            {
                'note': played.pitch,
                'onset': played.start,
                'offset': played.end,
                'intensity': played.velocity
            }
            for trk in pitched_tracks
            for played in trk.notes
        ]
        # Stable in-place sort: notes sharing an onset keep their track order.
        note_events.sort(key=lambda event: event['onset'])
    except Exception as error:
        # Best-effort loader: report the problem and let the caller skip the file.
        print(f"Failed to process {midi_path}: {error}")
        return None, None
    return midi_object, note_events

# Function to discretize note events based on musical beats
def discretize_notes(note_events, midi_object, beat_interval=0.5):
    """Group note events into consecutive time slices bounded by beat times.

    Slice ``k`` spans ``[beats[k], beats[k + 1])`` and lists the pitch of every
    note that overlaps it (``onset < slice end`` and ``offset > slice start``).

    Args:
        note_events (list): Dicts with at least ``'note'``, ``'onset'`` and
            ``'offset'`` keys.
        midi_object: Object providing ``get_beats()`` and ``get_end_time()``.
        beat_interval (float): Spacing of the fallback grid that is used when
            ``get_beats()`` returns fewer than two entries.

    Returns:
        list: One list of pitches per slice (``len(beats) - 1`` slices); pitches
        appear in onset order.
    """
    beats = midi_object.get_beats()
    if len(beats) < 2:
        # Default to fixed intervals if beat detection fails
        beats = np.arange(0, midi_object.get_end_time(), beat_interval)

    # Sweep over the notes once instead of rescanning all of them per slice.
    # The stable sort is a no-op for input that is already onset-ordered and
    # makes the sweep valid for input that is not. The sweep relies on the beat
    # times being non-decreasing.
    ordered = sorted(note_events, key=lambda event: event['onset'])
    total = len(ordered)
    next_idx = 0      # first note that has not started before the current slice end
    sounding = []     # notes that have started and may still overlap the slice
    discretized = []
    for current_start, current_end in zip(beats[:-1], beats[1:]):
        # Admit every note whose onset falls before the end of this slice.
        while next_idx < total and ordered[next_idx]['onset'] < current_end:
            sounding.append(ordered[next_idx])
            next_idx += 1
        # Retire notes that ended at or before the slice start; slice starts
        # only move forward, so a retired note can never overlap again.
        sounding = [event for event in sounding if event['offset'] > current_start]
        discretized.append([event['note'] for event in sounding])
    return discretized

# Aggregate all discretized note sequences
all_discretized_sequences = []
sequence_length_records = []

print("\nCommencing MIDI file processing...")
start_timer = time.time()
for path in midi_file_paths:
    midi_obj, notes = extract_note_events(path)
    # An empty note list is skipped just like a failed load (both are falsy).
    if notes:
        discretized_sequence = discretize_notes(notes, midi_obj)
        sequence_length_records.append(len(discretized_sequence))
        all_discretized_sequences.append(discretized_sequence)
    else:
        print(f"Skipping {path} due to processing issues.")
processing_duration = time.time() - start_timer
print(f"Processing completed in {processing_duration:.2f} seconds.")

# Stop here with an actionable message: without this guard the statistics below
# fail with an opaque NumPy error about reducing a zero-size array.
if not all_discretized_sequences:
    raise RuntimeError(
        "No MIDI file could be converted into a note sequence; "
        "check that the dataset was downloaded and extracted correctly."
    )

print(f"\nTotal sequences generated: {len(all_discretized_sequences)}")
print(f"Average sequence length: {np.mean(sequence_length_records):.2f} steps")
print(f"Sequence lengths range from {np.min(sequence_length_records)} to {np.max(sequence_length_records)} steps")

# Build the pitch vocabulary: every pitch that occurs in any time step of any
# processed sequence, in ascending order.
all_pitch_values = [
    pitch
    for seq in all_discretized_sequences
    for timestep in seq
    for pitch in timestep
]
unique_pitches = sorted(set(all_pitch_values))

# Index <-> pitch lookup tables; index i corresponds to the i-th smallest pitch.
index_to_pitch = dict(enumerate(unique_pitches))
pitch_to_index = {pitch: idx for idx, pitch in index_to_pitch.items()}
total_pitches = len(unique_pitches)  # Size of the model's input/output vectors

print(f"\nUnique pitch count: {total_pitches}")
print(f"Pitch indexing configured for {total_pitches} pitches.")

# Persist both lookup tables so they can be reloaded later
for mapping_path, mapping in (
    ('transformer_pitch_to_index.pkl', pitch_to_index),
    ('transformer_index_to_pitch.pkl', index_to_pitch),
):
    with open(mapping_path, 'wb') as file:
        pickle.dump(mapping, file)

# Training hyperparameters
sequence_length = 64  # Window length in time steps; should match the value used for generation
batch_size = 256
num_epochs = 40
early_stop_patience = 5

# Split at the level of whole sequences (one per MIDI file) so that windows cut
# from the same file never end up in different subsets.
# First hold out 20% for testing, then 10% of the remainder for validation.
train_val_pool, testing_set = train_test_split(
    all_discretized_sequences,
    test_size=0.2,
    random_state=42,
)
training_set, validation_set = train_test_split(
    train_val_pool,
    test_size=0.1,
    random_state=42,
)

print(f"Number of training sequences: {len(training_set)}")
print(f"Number of validation sequences: {len(validation_set)}")
print(f"Number of testing sequences: {len(testing_set)}")

# Step 4: Define a Data Generator Class for Transformer
class TransformerDataGenerator(Sequence):
    """Batch generator that serves multi-hot next-step prediction windows.

    Every sequence is cut into overlapping windows of ``seq_len`` time steps.
    The target of a window is the same window moved forward by one step, so
    input step ``t + 1`` equals target step ``t``.
    """

    def __init__(self, sequences, seq_len, batch_sz, num_pitches, pitch_map, shuffle=True, **kwargs):
        """
        Initializes the data generator for Transformer training.

        Args:
            sequences (list): List of quantized note sequences; each sequence is
                a list of time steps and each time step a list of pitches.
            seq_len (int): Number of time steps per input/target window.
            batch_sz (int): Size of each batch.
            num_pitches (int): Total number of unique pitches.
            pitch_map (dict): Mapping from pitch to ID.
            shuffle (bool): Whether to shuffle data after each epoch.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # Let the base class initialise its own state; Keras warns when a
        # subclass skips this call.
        super().__init__(**kwargs)
        self.sequences = sequences
        self.seq_length = seq_len
        self.batch_size = batch_sz
        self.num_pitches = num_pitches
        self.pitch_mapping = pitch_map
        self.shuffle = shuffle
        self.inputs, self.targets = self._prepare_data()
        self.on_epoch_end()

    def _prepare_data(self):
        """Return ``(inputs, targets)``: parallel lists of windows.

        Sequences with fewer than ``seq_length + 1`` steps contribute nothing.
        """
        inputs = []
        targets = []
        width = self.seq_length
        for seq in self.sequences:
            # range() is empty when the sequence is too short for one window.
            for start in range(len(seq) - width):
                inputs.append(seq[start:start + width])
                targets.append(seq[start + 1:start + width + 1])  # Shifted by one time step
        return inputs, targets

    def _encode_windows(self, windows):
        """Multi-hot encode windows into a float32 array.

        The result has shape ``(len(windows), seq_length, num_pitches)``.
        Pitches missing from the pitch mapping are ignored.
        """
        encoded = np.zeros((len(windows), self.seq_length, self.num_pitches), dtype=np.float32)
        lookup = self.pitch_mapping.get
        for row, window in enumerate(windows):
            for step, pitch_set in enumerate(window):
                for pitch in pitch_set:
                    column = lookup(pitch)
                    if column is not None:
                        encoded[row, step, column] = 1.0
        return encoded

    def __len__(self):
        """Number of batches per epoch; the last batch may be smaller."""
        return int(np.ceil(len(self.inputs) / self.batch_size))

    def __getitem__(self, idx):
        """Return batch ``idx`` as a pair ``(X, Y)`` of multi-hot arrays."""
        first = idx * self.batch_size
        last = first + self.batch_size
        X = self._encode_windows(self.inputs[first:last])
        Y = self._encode_windows(self.targets[first:last])
        return X, Y

    def on_epoch_end(self):
        """Reshuffle the windows when shuffling is enabled.

        Inputs and targets are reordered with one shared permutation so the
        pairs stay aligned. Nothing happens when there are no windows, which
        previously raised while unpacking an empty ``zip``.
        """
        if not self.shuffle or not self.inputs:
            return
        order = list(range(len(self.inputs)))
        random.shuffle(order)
        # Rebuild as lists so the attributes keep their type across epochs.
        self.inputs = [self.inputs[pos] for pos in order]
        self.targets = [self.targets[pos] for pos in order]

# Step 5: Instantiate Data Generators for Training, Validation, and Testing
# All three generators share the same window length, batch size and pitch vocabulary.
shared_generator_args = dict(
    seq_len=sequence_length,
    batch_sz=batch_size,
    num_pitches=total_pitches,
    pitch_map=pitch_to_index,
)
# Only the training windows are reshuffled between epochs (shuffle defaults to True).
train_generator = TransformerDataGenerator(training_set, **shared_generator_args)
val_generator = TransformerDataGenerator(validation_set, shuffle=False, **shared_generator_args)
test_generator = TransformerDataGenerator(testing_set, shuffle=False, **shared_generator_args)

# Step 6: Define Dice Loss Function
def dice_loss(y_true, y_pred, smooth=1):
    """
    Computes the Dice Loss for multi-label classification.

    A Dice coefficient is computed over the last axis (the label vector) of
    every leading position, and the loss is one minus the mean coefficient.
    The label count is taken from the tensors themselves, so the function does
    not depend on any module-level variable.

    Args:
        y_true (tensor): Ground truth binary labels.
        y_pred (tensor): Predicted probabilities, same shape as ``y_true``.
        smooth (float): Smoothing factor to prevent division by zero.

    Returns:
        tensor: Scalar Dice loss value.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Align dtypes so the element-wise product below cannot fail on a mismatch.
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred, axis=-1)
    sum_labels = tf.reduce_sum(y_true, axis=-1)
    sum_preds = tf.reduce_sum(y_pred, axis=-1)
    dice_coeff = (2. * intersection + smooth) / (sum_labels + sum_preds + smooth)
    return 1 - tf.reduce_mean(dice_coeff)

# Step 7: Build the Transformer Model Architecture
def build_transformer_model(seq_len, num_pitches, embed_dim, num_heads, ff_dim, dropout=0.3, num_layers=4):
    """
    Constructs a Transformer-based model for music generation.

    The model maps a window of multi-hot pitch vectors to per-step pitch
    probabilities of the same length. Self-attention is causally masked, so
    the output at step ``t`` only depends on input steps ``<= t``. This matters
    for next-step targets: without the mask, step ``t`` could read input step
    ``t + 1``, which is exactly the value it is asked to predict.

    Args:
        seq_len (int): Length of input sequences.
        num_pitches (int): Number of unique pitches.
        embed_dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        ff_dim (int): Feed-forward network dimension.
        dropout (float): Dropout rate.
        num_layers (int): Number of stacked Transformer blocks.

    Returns:
        tf.keras.Model: The model, not yet compiled.
    """
    inputs = Input(shape=(seq_len, num_pitches), name='input_layer')  # Multi-hot encoded inputs

    # Project multi-hot vectors to embedding space
    x = Dense(embed_dim, activation='relu', name='input_projection')(inputs)

    # Transformer Blocks
    for i in range(num_layers):
        # Multi-Head Self-Attention, restricted to the current and earlier steps.
        # NOTE(review): key_dim is the size of each head, so every head is
        # embed_dim wide here rather than embed_dim // num_heads — confirm intended.
        attention_output = MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=dropout, name=f'mha_{i}'
        )(x, x, use_causal_mask=True)
        attention_output = Dropout(dropout)(attention_output)
        x = LayerNormalization(epsilon=1e-6)(x + attention_output)

        # Feed-Forward Network
        ff_output = Dense(ff_dim, activation='relu')(x)
        ff_output = Dense(embed_dim)(ff_output)
        ff_output = Dropout(dropout)(ff_output)
        x = LayerNormalization(epsilon=1e-6)(x + ff_output)

    # Dense acts on the last axis only, so the sequence axis is preserved
    x = Dense(embed_dim, activation='relu', name='time_distributed_dense')(x)

    # Output Layer with sigmoid activation for multi-label prediction.
    # dtype is pinned to float32 so the probabilities stay in full precision
    # whatever global dtype policy is active.
    outputs = Dense(num_pitches, activation='sigmoid', dtype='float32', name='output_layer')(x)

    model = Model(inputs=inputs, outputs=outputs, name='Transformer_Music_Generator')
    return model

print("\nBuilding the Transformer-based music generation model...")
# Architecture hyperparameters
embed_dim = 256
num_heads = 8
ff_dim = 512
dropout_rate = 0.5

transformer_model = build_transformer_model(
    seq_len=sequence_length,
    num_pitches=total_pitches,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_dim=ff_dim,
    dropout=dropout_rate,
)
transformer_model.summary()

# Step 8: Compile the Model with Dice Loss and Built-in Metrics
# The metrics get explicit names so their history keys are predictable.
tracked_metrics = [
    'binary_accuracy',
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.AUC(name='auc'),
]
transformer_model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss=dice_loss,  # Dice loss, chosen to match the GRU model per the original note
    metrics=tracked_metrics,
)

# Step 9: Configure Callbacks for Training
checkpoint_dir = './transformer_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Keep only the full model (not just weights) with the lowest validation loss.
checkpoint_cb = ModelCheckpoint(
    filepath=os.path.join(checkpoint_dir, 'best_transformer_model.keras'),
    save_weights_only=False,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1
)

# Stop once validation loss has stalled and roll back to the best weights.
early_stop_cb = EarlyStopping(
    monitor='val_loss',
    patience=early_stop_patience,
    verbose=1,
    restore_best_weights=True
)

# Halve the learning rate on a plateau. Its patience must be shorter than the
# early-stopping patience: with equal values both fire on the same epoch, so
# training ends before the lower learning rate is ever used.
reduce_lr_patience = max(1, early_stop_patience // 2)
reduce_lr_cb = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=reduce_lr_patience,
    verbose=1,
    min_lr=1e-6
)

callbacks = [checkpoint_cb, early_stop_cb, reduce_lr_cb]

# Step 10: Serialize and Save Training Sequences for Generation
def serialize_training_sequences(sequences, window_size, output_path='transformer_train_seeds.pkl'):
    """
    Saves the opening window of each training sequence to a pickle file for
    later use as generation seeds.

    The sequences are stored as given: each time step stays a list of pitch
    values, no index encoding is applied. Sequences shorter than
    ``window_size`` are skipped.

    Args:
        sequences (list): List of training sequences (lists of time steps).
        window_size (int): Number of leading time steps kept per sequence.
        output_path (str): Destination pickle file.

    Saves:
        ``output_path``: Pickle file containing the list of seed windows.
    """
    seeds = [seq[:window_size] for seq in sequences if len(seq) >= window_size]
    with open(output_path, 'wb') as file:
        pickle.dump(seeds, file)
    print(f"Training seeds have been serialized and saved to '{output_path}'.")

# Write the seed windows taken from the training split to disk
serialize_training_sequences(training_set, sequence_length)

# Step 11: Begin Model Training
print("\nCommencing model training...")
training_history = transformer_model.fit(
    train_generator,
    epochs=num_epochs,
    validation_data=val_generator,
    callbacks=callbacks,
    verbose=1,
)

# Step 12: Evaluate Model Performance on the Testing Dataset
print("\nAssessing model performance on the testing dataset...")
evaluation_results = transformer_model.evaluate(test_generator, verbose=1)
# The values are read by position: the loss first, then four metric values,
# reported under the labels printed below.
test_loss, test_binary_accuracy, test_precision, test_recall, test_auc = evaluation_results[:5]
print(f"Test Loss: {test_loss}")
print(f"Test Binary Accuracy: {test_binary_accuracy}")
print(f"Test Precision: {test_precision}")
print(f"Test Recall: {test_recall}")
print(f"Test AUC: {test_auc}")

# Step 13: Persist the Final Trained Model for Future Use
transformer_model.save('final_transformer_music_model.keras')
print("The trained Transformer model has been saved as 'final_transformer_music_model.keras'.")

# Step 14: Function to Visualize Training and Validation Metrics
def visualize_training_metrics(history):
    """
    Plots the training and validation metrics over each epoch.

    One subplot per metric is drawn on a 3x2 grid. A metric that the history
    does not contain is reported and skipped, and the validation curve is only
    drawn when a matching ``val_<metric>`` entry exists, so histories from runs
    without validation data can be plotted too.

    Args:
        history (History): Keras History object containing training metrics.
    """
    metrics_to_plot = ['loss', 'binary_accuracy', 'precision', 'recall', 'auc']
    recorded = history.history

    plt.figure(figsize=(25, 15))

    for idx, metric in enumerate(metrics_to_plot, 1):
        if metric not in recorded:
            print(f"Metric '{metric}' was not recorded during training; skipping its plot.")
            continue
        plt.subplot(3, 2, idx)
        plt.plot(recorded[metric], label=f'Training {metric}')
        val_key = f'val_{metric}'
        if val_key in recorded:
            plt.plot(recorded[val_key], label=f'Validation {metric}')
        pretty_name = metric.replace("_", " ").title()
        plt.title(f'{pretty_name} Progress')
        plt.xlabel('Epoch')
        plt.ylabel(pretty_name)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Generate plots for training history
# (training_history is the History object returned by transformer_model.fit above)
visualize_training_metrics(training_history)
