# Transformer Music Generation






In [None]:
# Install essential packages if not already present
!pip install pretty_midi tensorflow

# Import necessary modules
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
import pickle
import pretty_midi
import random

# Quieten console output so the generation progress messages stay readable.
import warnings
# NOTE: with no category/module filter this silences *every* Python warning,
# not only the ones raised by TensorFlow.
warnings.filterwarnings('ignore')
# Restrict TensorFlow's own logger to messages at ERROR level or above.
tf.get_logger().setLevel('ERROR')

# 1. Function to Load the Trained Transformer Model and Pitch Mappings

def load_trained_transformer_model(model_path='final_transformer_music_model.keras',
                                   pitch_to_index_path='transformer_pitch_to_index.pkl',
                                   index_to_pitch_path='transformer_index_to_pitch.pkl'):
    """
    Read the saved Transformer model and both pitch lookup tables from disk.

    Args:
        model_path (str): Location of the saved Keras model file.
        pitch_to_index_path (str): Pickle file holding the pitch -> index mapping.
        index_to_pitch_path (str): Pickle file holding the index -> pitch mapping.

    Returns:
        tuple: ``(model, pitch_to_idx, idx_to_pitch)`` in that order.

    Raises:
        ValueError: If the last dimension of the model's output shape differs
            from the number of entries in the pitch -> index mapping.
        Exception: Any error raised while loading the model or the pickles is
            printed and then re-raised unchanged.
    """
    def _unpickle(path):
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at mapping files you created yourself.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    print("Loading the pre-trained Transformer model...")
    try:
        # compile=False: the model is only used for inference here, so no
        # optimizer/loss state is restored.
        transformer_model = load_model(model_path, compile=False)
        print("Transformer model successfully loaded.\n")
    except Exception as load_error:
        print(f"Error loading the model: {load_error}")
        raise

    print("Retrieving pitch mappings...")
    try:
        pitch_to_idx = _unpickle(pitch_to_index_path)
        idx_to_pitch = _unpickle(index_to_pitch_path)
        print("Pitch mappings successfully retrieved.\n")
    except Exception as mapping_error:
        print(f"Error loading pitch mappings: {mapping_error}")
        raise

    # Sanity check: the width of the output layer must equal the vocabulary size.
    output_width = transformer_model.output_shape[-1]
    vocabulary_size = len(pitch_to_idx)
    if output_width != vocabulary_size:
        raise ValueError(f"Mismatch between model's output pitches ({output_width}) and pitch mapping size ({vocabulary_size}).")

    print(f"Model's output layer has {output_width} pitches, matching the pitch mappings.\n")
    return transformer_model, pitch_to_idx, idx_to_pitch

# 2. Function to Load Seed Sequences for Initialization

def load_seed_sequences(seeds_file='transformer_train_seeds.pkl'):
    """
    Read the pickled seed sequences used to start music generation.

    Args:
        seeds_file (str): Path to the pickle file containing the seed sequences.

    Returns:
        The unpickled object; the rest of this file treats it as a list of
        seed sequences.

    Raises:
        Exception: Any failure while opening, unpickling or measuring the
            loaded object is printed and then re-raised unchanged.
    """
    print("Loading serialized training sequences for seed selection...")
    try:
        # NOTE: pickle can execute arbitrary code while loading; only use
        # seed files from a trusted source.
        with open(seeds_file, 'rb') as handle:
            training_seeds = pickle.load(handle)
        seed_count = len(training_seeds)
        print(f"Loaded {seed_count} training seed sequences.\n")
    except Exception as load_error:
        print(f"Error loading seed sequences: {load_error}")
        raise
    return training_seeds

# 3. Function to Generate New Music Sequences Using the Transformer Model

def generate_music_sequence_transformer(model, seeds, idx_to_pitch, seq_length=64,
                                       generation_steps=500, temperature=1.0, activation_threshold=0.5):
    """
    Autoregressively extend a randomly chosen seed with newly predicted time steps.

    Each iteration multi-hot encodes the most recent ``seq_length`` time steps,
    asks the model for a prediction, rescales it by ``temperature`` and keeps
    every pitch whose probability exceeds ``activation_threshold`` (falling
    back to the single most probable pitch when none does).

    Args:
        model: Object exposing ``predict(array, verbose=0)``. The result is
            indexed with ``[0, -1]``, so it is assumed to be shaped
            ``(1, time, num_pitches)`` -- TODO confirm against the trained model.
        seeds (sequence): Candidate seed sequences; one is picked at random.
            Each seed is a sequence of time steps, each time step an iterable
            of integer pitch indices.
        idx_to_pitch (dict): Index -> pitch mapping; only its length is used
            here, as the width of the multi-hot encoding.
        seq_length (int): Number of most recent time steps fed to the model.
            Seeds shorter than this are left-padded with silent (all-zero) steps.
        generation_steps (int): Number of new time steps to generate.
        temperature (float): Divisor applied to the log-probabilities; values
            <= 0 are replaced by 1.0.
        activation_threshold (float): Probability a pitch must exceed to be
            switched on.

    Returns:
        list: The seed's time steps followed by ``generation_steps`` generated
        time steps, each generated step being a list of ``int`` pitch indices.

    Raises:
        IndexError: If ``seeds`` is empty (raised by ``random.choice``).
    """
    num_pitches = len(idx_to_pitch)

    # Loop-invariant guard: a non-positive temperature would divide by zero
    # or invert the distribution, so fall back to the neutral value.
    if temperature <= 0:
        temperature = 1.0

    # Randomly select a seed sequence from the training data
    initial_seed = random.choice(seeds)
    print("Randomly selected an initial seed sequence for generation.\n")

    # Work on a copy so the caller's seed is never mutated; list() also
    # accepts tuples or arrays of time steps.
    generated_sequence = list(initial_seed)

    print("Commencing music generation...")
    for step in range(generation_steps):
        # Take the last `seq_length` steps. When fewer are available the
        # window is right-aligned, leaving the leading rows as silence.
        window_start = max(len(generated_sequence) - seq_length, 0)
        window = generated_sequence[window_start:]
        first_row = seq_length - len(window)

        input_array = np.zeros((1, seq_length, num_pitches), dtype=np.float32)
        for row, step_pitches in enumerate(window, start=first_row):
            for pitch_idx in step_pitches:
                if 0 <= pitch_idx < num_pitches:
                    input_array[0, row, pitch_idx] = 1.0  # Multi-hot encoding

        # Prediction for the final position of the window
        predictions = model.predict(input_array, verbose=0)[0, -1]

        # Temperature-scaled softmax over the log-probabilities. Subtracting
        # the maximum leaves the result mathematically unchanged but prevents
        # exp() from underflowing to all zeros (-> NaN) at low temperatures.
        scaled_preds = np.log(predictions + 1e-8) / temperature
        scaled_preds = scaled_preds - np.max(scaled_preds)
        exp_scaled = np.exp(scaled_preds)
        probability_distribution = exp_scaled / np.sum(exp_scaled)

        # NOTE(review): the distribution sums to 1, so with a threshold of 0.5
        # or more at most one pitch can be selected per step. If the model
        # emits independent per-pitch (sigmoid) probabilities this
        # normalisation may be unintended -- confirm against the training setup.
        active_pitches = [
            int(idx)
            for idx in np.flatnonzero(probability_distribution > activation_threshold)
            if idx < num_pitches
        ]

        if not active_pitches:
            # Guarantee at least one sounding pitch per step
            active_pitches = [int(np.argmax(probability_distribution))]

        generated_sequence.append(active_pitches)

        # Log progress at intervals
        if (step + 1) % 100 == 0:
            print(f"Progress Update: Generated {step + 1} steps...")

    print(f"\nMusic generation completed: {generation_steps} new steps generated.\n")
    return generated_sequence

# 4. Function to Convert Generated Sequences to a MIDI File

def convert_sequence_to_midi_transformer(generated_sequence, idx_to_pitch, output_file='new_composition_transformer.mid',
                                        step_duration=0.3, note_volume=70):
    """
    Render a sequence of per-step pitch indices as a single-piano MIDI file.

    A pitch that stays active across consecutive time steps is written as one
    sustained note rather than being re-struck at every step.

    Args:
        generated_sequence (list): Time steps, each an iterable of pitch indices.
        idx_to_pitch (dict): Index -> MIDI pitch mapping; indices that are not
            keys of this mapping are skipped.
        output_file (str): Filename the MIDI data is written to.
        step_duration (float): Length of one time step. It is used directly as
            the offset between note start/end times, which pretty_midi
            interprets as seconds.
        note_volume (int): Velocity assigned to every note.

    Returns:
        str: ``output_file``, after the MIDI file has been written.
    """
    print(f"Transforming the generated sequence into MIDI format and saving as '{output_file}'...")

    score = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(
        program=pretty_midi.instrument_name_to_program('Acoustic Grand Piano')
    )

    clock = 0.0
    sounding = {}  # MIDI pitch -> Note that has started but is not yet finalised

    for step_indices in generated_sequence:
        # Translate this step's indices into MIDI pitches, dropping unknown indices
        pitches_now = [idx_to_pitch[i] for i in step_indices if i in idx_to_pitch]
        step_end = clock + step_duration

        # Finalise notes whose pitch is no longer present in this step
        for pitch in set(sounding.keys()) - set(pitches_now):
            finished = sounding.pop(pitch)
            finished.end = clock
            piano.notes.append(finished)

        # Sustain pitches that are already sounding; start the new ones
        for pitch in pitches_now:
            if pitch in sounding:
                sounding[pitch].end = step_end
            else:
                sounding[pitch] = pretty_midi.Note(
                    velocity=note_volume,
                    pitch=pitch,
                    start=clock,
                    end=step_end
                )

        clock = step_end

    # Close every note still sounding when the sequence ends
    for leftover in sounding.values():
        leftover.end = clock
        piano.notes.append(leftover)

    score.instruments.append(piano)
    score.write(output_file)
    print(f"MIDI file '{output_file}' has been successfully saved.\n")

    return output_file

# 5. Function to Execute the Complete Music Generation Workflow

def perform_music_generation_transformer(model_path='final_transformer_music_model.keras',
                                       pitch_to_idx_path='transformer_pitch_to_index.pkl',
                                       index_to_pitch_path='transformer_index_to_pitch.pkl',
                                       seeds_file='transformer_train_seeds.pkl',
                                       output_midi='new_composition_transformer.mid',
                                       seq_length=64,
                                       generation_steps=300,
                                       temperature=1.0,
                                       activation_threshold=0.5,
                                       step_duration=0.3,
                                       note_volume=70):
    """
    Run the full pipeline: load artefacts, generate a sequence, write a MIDI file.

    Args:
        model_path (str): Saved Transformer model file.
        pitch_to_idx_path (str): Pickle with the pitch -> index mapping.
        index_to_pitch_path (str): Pickle with the index -> pitch mapping.
        seeds_file (str): Pickle with the seed sequences.
        output_midi (str): Filename for the resulting MIDI file.
        seq_length (int): Number of previous time steps fed to the model.
        generation_steps (int): Number of new time steps to generate.
        temperature (float): Randomness control forwarded to the generator.
        activation_threshold (float): Probability a pitch must exceed to be active.
        step_duration (float): Length of one time step, forwarded to the MIDI
            writer where it becomes note start/end offsets (seconds in pretty_midi).
        note_volume (int): Velocity of the written notes.

    Returns:
        None
    """
    # Stage 1: model and lookup tables (the pitch -> index table is not needed below)
    transformer_model, _, idx_to_pitch = load_trained_transformer_model(
        model_path, pitch_to_idx_path, index_to_pitch_path
    )

    # Stage 2: seed material for the generator
    seed_pool = load_seed_sequences(seeds_file)

    # Stage 3: autoregressive generation
    generation_options = {
        'seq_length': seq_length,
        'generation_steps': generation_steps,
        'temperature': temperature,
        'activation_threshold': activation_threshold,
    }
    composed_steps = generate_music_sequence_transformer(
        model=transformer_model,
        seeds=seed_pool,
        idx_to_pitch=idx_to_pitch,
        **generation_options
    )

    # Stage 4: write the result out as MIDI
    saved_path = convert_sequence_to_midi_transformer(
        generated_sequence=composed_steps,
        idx_to_pitch=idx_to_pitch,
        output_file=output_midi,
        step_duration=step_duration,
        note_volume=note_volume
    )

    print(f"Music generation process completed. MIDI file available at: {saved_path}")

# Execute the music generation process
if __name__ == "__main__":
    # All run settings are gathered in one place so they are easy to tweak.
    generation_settings = {
        'model_path': 'final_transformer_music_model.keras',
        'pitch_to_idx_path': 'transformer_pitch_to_index.pkl',
        'index_to_pitch_path': 'transformer_index_to_pitch.pkl',
        'seeds_file': 'transformer_train_seeds.pkl',
        'output_midi': 'new_composition_transformer.mid',
        'seq_length': 64,
        'generation_steps': 300,        # Number of new time steps to generate
        'temperature': 1.0,             # Adjust to control randomness
        'activation_threshold': 0.5,    # Probability a pitch must exceed to sound
        'step_duration': 0.3,           # Length of each time step
        'note_volume': 70,              # Velocity of the generated notes
    }
    perform_music_generation_transformer(**generation_settings)
