In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/prompt_optmizer/vertex_ai_prompt_optimizer_ui.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fprompts%2Fprompt_optmizer%2Fvertex_ai_prompt_optimizer_ui.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/prompts/prompt_optmizer/vertex_ai_prompt_optimizer_ui.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/prompt_optmizer/vertex_ai_prompt_optimizer_ui.ipynb">
      <img width="32px" src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

# Overview
Welcome to Vertex AI Prompt Optimizer (VAPO)! This Notebook showcases VAPO, a tool that iteratively optimizes prompts to suit a target model (e.g., `gemini-1.5-pro`) using target-specific metric(s).

Key Use Cases:

* Prompt Optimization: Enhance the quality of an initial prompt by refining its structure and content to match the target model's optimal input characteristics.

* Prompt Translation: Adapt prompts optimized for one model to work effectively with a different target model.

# Step 0: Install packages and libraries

In [None]:
! pip3 install -U google-cloud-aiplatform -q

import datetime
import os
import time

from IPython.display import HTML, display
from google.auth import default
from google.cloud import aiplatform, storage
from google.colab import auth, output
import gspread
import ipywidgets as widgets
import jinja2
from jinja2 import BaseLoader, Environment
import jinja2.meta
import pandas as pd
import tensorflow.io.gfile as gfile

output.enable_custom_widget_manager()
from io import StringIO
import json
import re


def authenticate():
    """Authenticate the Colab user and return an authorized gspread client.

    Runs the Colab user authentication flow, fetches the default Google
    credentials and hands them to `gspread.authorize`.
    """
    auth.authenticate_user()
    credentials, _project = default()
    sheets_client = gspread.authorize(credentials)
    return sheets_client


def is_target_required_metric(eval_metric: str) -> bool:
    """Return True when `eval_metric` is one of the metrics listed below.

    The listed metrics are the ones this notebook treats as needing a
    `target` value in every example.
    """
    target_based_metrics = (
        "bleu",
        "exact_match",
        "question_answering_correctness",
        "rouge_1",
        "rouge_2",
        "rouge_l",
        "rouge_l_sum",
        "tool_call_valid",
        "tool_name_match",
        "tool_parameter_key_match",
        "tool_parameter_kv_match",
    )
    return any(eval_metric == name for name in target_based_metrics)


def is_run_target_required(eval_metric_types: list[str], source_model: str) -> bool:
    """Tell whether the examples of a run must carry a target value.

    A run with a source model never requires targets. Otherwise targets are
    required as soon as one of the metrics is a target-based metric.
    """
    if not source_model:
        return any(is_target_required_metric(name) for name in eval_metric_types)
    return False


# Key that holds the expected answer in every dataset example.
_TARGET_KEY = "target"


def validate_prompt_and_data(
    template: str,
    dataset_path: str,
    placeholder_to_content: str,
    label_enforced: bool,
) -> None:
    """Validates the prompt template against every example of the dataset.

    Args:
        template: Jinja2 template text; each of its undeclared variables must
            be provided by every example.
        dataset_path: Path, readable through `gfile`, of a file holding one
            JSON object per line.
        placeholder_to_content: JSON object encoded as a string; its entries
            are merged into every example before the checks.
        label_enforced: When True, every example must have a non-empty
            `target` entry.

    Raises:
        ValueError: If the template cannot be parsed, an example lacks a
            template variable, or a required target is missing or empty.
        Warning: Raised (not merely emitted) when the examples contain keys
            that the template does not use.
    """
    placeholders = json.loads(placeholder_to_content)
    with gfile.GFile(dataset_path, "r") as f:
        # Blank lines (for instance a trailing empty line) are not examples;
        # skipping them avoids a JSON decoding error on otherwise valid files.
        data = [json.loads(line) for line in f.readlines() if line.strip()]

    env = jinja2.Environment()
    try:
        parsed_content = env.parse(template)
    except jinja2.exceptions.TemplateSyntaxError as e:
        raise ValueError(f"Invalid template: {template}") from e

    template_variables = jinja2.meta.find_undeclared_variables(parsed_content)
    extra_keys = set()
    for ex in data:
        # Placeholders count as example content for the purpose of validation.
        ex.update(placeholders)
        missing_keys = [key for key in template_variables if key not in ex]
        extra_keys.update([key for key in ex if key not in template_variables])
        if label_enforced:
            if _TARGET_KEY not in ex:
                raise ValueError(
                    f"The example {ex} doesn't have a key corresponding to the target"
                    f" var: {_TARGET_KEY}"
                )
            if not ex[_TARGET_KEY]:
                raise ValueError(f"The following example has an empty target: {ex}")
        if missing_keys:
            raise ValueError(
                f"The example {ex} doesn't have a key corresponding to following"
                f" template vars: {missing_keys}"
            )
    if extra_keys:
        raise Warning(
            "Warning: extra keys in the examples not used in the context/task"
            f" template {extra_keys}"
        )


def run_custom_job(
    display_name: str,
    container_uri: str,
    container_args: dict[str, str],
) -> "aiplatform.CustomJob":
    """Creates and submits a single-replica Vertex AI custom job.

    Args:
        display_name: Display name given to the custom job.
        container_uri: URI of the container image to run.
        container_args: Mapping turned into `--key=value` container arguments.

    Returns:
        The `aiplatform.CustomJob` after `submit()` has been called on it.
    """
    worker_pool_specs = [
        {
            "replica_count": 1,
            "container_spec": {
                "image_uri": container_uri,
                "args": [f"--{k}={v}" for k, v in container_args.items()],
            },
            "machine_spec": {
                "machine_type": "n1-standard-4",
            },
        }
    ]

    custom_job = aiplatform.CustomJob(
        display_name=display_name,
        worker_pool_specs=worker_pool_specs,
    )
    custom_job.submit()
    return custom_job


def run_apd(
    config: dict[str, str], bucket_uri: str, display_name: str
) -> "aiplatform.CustomJob":
    """Launches the Vertex AI prompt optimizer as a custom job.

    The configuration is written as JSON to
    `{bucket_uri}/{display_name}/input_config.json` and the path of that file
    is passed to the optimizer container through its `--config` argument.

    Args:
        config: Optimizer configuration; `project` and `target_model_location`
            are also used to initialize the Vertex AI SDK.
        bucket_uri: Base path under which the config file and the staging
            folder of the job are placed.
        display_name: Name of the job, also used as the folder name.

    Returns:
        The submitted custom job, as returned by `run_custom_job`.
    """
    print(f"\n\nJob display name: {display_name}")
    version = "preview_v1_0"
    container_uri = "us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/apd"
    config_path = f"{bucket_uri}/{display_name}/input_config.json"

    with gfile.GFile(config_path, "w") as f:
        json.dump(config, f)

    aiplatform.init(
        project=config["project"],
        location=config["target_model_location"],
        staging_bucket=f"{bucket_uri}/{display_name}",
    )

    return run_custom_job(
        display_name=display_name,
        container_uri=f"{container_uri}:{version}",
        container_args={"config": config_path},
    )


def update_best_display(
    df: pd.DataFrame,
    textarea: widgets.Textarea,
    best_score_label: widgets.Label,
    eval_metric: str,
) -> None:
    """Update the widgets showing the best template found so far.

    Args:
        df: Templates table with a `prompt` column and a
            `metrics.{eval_metric}/mean` column. A `score` column is added to
            it in place.
        textarea: Widget whose `value` receives the best template.
        best_score_label: Widget whose `value` receives the score summary.
        eval_metric: Name of the metric used to rank the templates.
    """
    df["score"] = df[f"metrics.{eval_metric}/mean"]

    # `argmax` is positional, so the lookups are positional too: this works
    # whatever the index labels of `df` are.
    best_position = df["score"].argmax()
    best_template = df["prompt"].iloc[best_position]
    best_score = df["score"].iloc[best_position]
    # The first row is used as the baseline for the improvement.
    original_score = df["score"].iloc[0]

    # Show a bare `llm()` call instead of the `store('answer', llm())` wrapper.
    best_template = best_template.replace("store('answer', llm())", "llm()")
    textarea.value = best_template
    improvement = best_score - original_score
    no_improvement_str = "\nNo better template is found yet." if not improvement else ""
    best_score_label.value = (
        f"Score: {best_score}" f" Improvement: {improvement: .3f} {no_improvement_str}"
    )


def generate_dataframe(filename: str) -> pd.DataFrame:
    """Generates a pandas dataframe from a json file.

    Args:
        filename: Path of the JSON file, readable through `gfile`.

    Returns:
        The normalized content of the file, or an empty dataframe when the
        file does not exist or its content cannot be loaded.
    """
    if not gfile.exists(filename):
        return pd.DataFrame()

    with gfile.GFile(filename, "r") as f:
        try:
            data = json.load(f)
        except Exception:
            # Best effort: an unreadable file is reported as "no data".
            # `Exception` (rather than a bare `except`) lets KeyboardInterrupt
            # and SystemExit through, so a polling loop stays interruptible.
            return pd.DataFrame()
    return pd.json_normalize(data)


def left_aligned_df_html(df: pd.DataFrame) -> "HTML":
    """Builds a left-aligned HTML rendering of a Pandas DataFrame.

    Nothing is displayed here: the returned `HTML` object is meant to be
    passed to a display call by the caller.

    Args:
        df: Dataframe to render; its index is omitted from the table.

    Returns:
        An `HTML` object holding the table and the CSS that left-aligns it.
    """

    # Convert to HTML table, but keep the HTML in a variable
    html_table = df.to_html(index=False, classes="left-aligned")

    # Add CSS styling to left-align table data cells and override default styles
    styled_html = f"""
    <style>
        .left-aligned td, .left-aligned th {{ text-align: left !important; }}
    </style>
    {html_table}
    """

    return HTML(styled_html)


def extract_top_level_function_name(source_code: str) -> str | None:
    match = re.search(r"^def\s+([a-zA-Z_]\w*)\s*\(", source_code, re.MULTILINE)
    if match:
        return match.group(1)
    return None


class ProgressForm:
    """A class to display the progress of the optimization job.

    `monitor_progress` is meant to be called repeatedly: each call refreshes
    the job state, the tables of generated templates and the best template
    found so far, and tells the caller whether to keep polling.
    """

    def __init__(self):
        # Widgets and data of the instruction optimization section.
        self.instruction_progress_bar = None
        self.instruction_display = None
        self.instruction_best = None
        self.instruction_score = None

        # Widgets and data of the demonstration optimization section.
        self.demo_progress_bar = None
        self.demo_display = None
        self.demo_best = None
        self.demo_score = None

        self.job_state_display = None
        # Assigned by init(); declared here so that every attribute exists on
        # a freshly constructed instance.
        self.status_display = None
        self.eval_metric = None
        self.output_path = None

        self.instruction_df = None
        self.demo_df = None

        self.started = False

    def init(self, params: dict[str, str]):
        """Initialize the progress form.

        Creates the job-state and status displays plus one progress section
        per optimization mode enabled in `params["optimization_mode"]`, and
        records the metric name and output path used by later refreshes.
        """
        self.job_state_display = display(
            HTML("<span>Job State: Not Started!</span>"), display_id=True
        )
        self.status_display = display(HTML(""), display_id=True)

        if params["optimization_mode"] in ["instruction", "instruction_and_demo"]:
            (
                self.instruction_progress_bar,
                self.instruction_display,
                self.instruction_best,
                self.instruction_score,
            ) = self.create_progress_ui("Instruction", params["num_steps"])

        if params["optimization_mode"] in ["demonstration", "instruction_and_demo"]:
            (
                self.demo_progress_bar,
                self.demo_display,
                self.demo_best,
                self.demo_score,
            ) = self.create_progress_ui(
                "Demonstration", params["num_demo_set_candidates"]
            )

        # Several metrics are reported under the "composite_metric" name.
        eval_metric = "composite_metric"
        if len(params["eval_metrics_types"]) == 1:
            eval_metric = params["eval_metrics_types"][0]

        if eval_metric != "composite_metric" and "custom_metric_source_code" in params:
            # A custom metric is reported under the name of its function.
            self.eval_metric = extract_top_level_function_name(
                params["custom_metric_source_code"]
            )
        else:
            self.eval_metric = eval_metric

        self.output_path = params["output_path"]
        self.started = True

    def update_progress(
        self,
        progress_bar: widgets.IntProgress,
        templates_file: str,
        df: pd.DataFrame | None,
        df_display: display,
        best_textarea: widgets.Textarea,
        best_score: widgets.Label,
        eval_metric: str,
    ):
        """Update the progress of the optimization job.

        Reloads `templates_file` and, when it contains a step newer than the
        last one seen in `df`, refreshes the table, the best-template widgets
        and the progress bar.

        Returns:
            The freshly loaded dataframe, or an empty dataframe when the
            section is disabled (`progress_bar` is None) or `df` is None.
        """

        def get_last_step(df: pd.DataFrame):
            if df.empty:
                return -1
            return int(df["step"].max())

        if progress_bar is None or df is None:
            return pd.DataFrame()

        new_df = generate_dataframe(templates_file)

        last_step = get_last_step(df)
        new_last_step = get_last_step(new_df)
        if new_last_step > last_step:
            df_display.update(left_aligned_df_html(new_df))
            update_best_display(new_df, best_textarea, best_score, eval_metric)
            progress_bar.value = progress_bar.value + new_last_step - last_step

        return new_df

    def create_progress_ui(
        self, opt_mode: str, num_opt_steps: int
    ) -> tuple[widgets.IntProgress, display, widgets.Textarea, widgets.Label]:
        """Create the progress UI for a specific optimization mode.

        Returns:
            The progress bar, the updatable display of the generated
            templates, the best-template text area and the score label.
        """
        print(f"\n\n{opt_mode} Optimization")
        progress_bar = widgets.IntProgress(
            value=0, min=0, max=num_opt_steps, step=1, description="Progress"
        )
        display(progress_bar)
        print("\nGenerated Templates:")
        templates_display = display("No template is evaluated yet!", display_id=True)

        print("\nBest Template so far:")
        best_textarea = widgets.Textarea(
            value="NA",
            disabled=False,
            layout=widgets.Layout(width="80%", height="150px"),
        )
        display(best_textarea)

        best_score = widgets.Label(value="Score: NA Improvement: NA")
        display(best_score)

        return progress_bar, templates_display, best_textarea, best_score

    def _refresh_templates(self) -> None:
        """Reload the templates files and refresh both optimization sections."""
        instruction_templates_file = f"{self.output_path}/instruction/templates.json"
        demo_templates_file = f"{self.output_path}/demonstration/templates.json"

        self.instruction_df = self.update_progress(
            self.instruction_progress_bar,
            instruction_templates_file,
            self.instruction_df,
            self.instruction_display,
            self.instruction_best,
            self.instruction_score,
            self.eval_metric,
        )
        self.demo_df = self.update_progress(
            self.demo_progress_bar,
            demo_templates_file,
            self.demo_df,
            self.demo_display,
            self.demo_best,
            self.demo_score,
            self.eval_metric,
        )

    def monitor_progress(self, job: aiplatform.CustomJob, params: dict[str, str]):
        """Monitor the progress of the optimization job.

        Returns:
            True while the job is still running (the caller should poll
            again), False once the job is done and the final status shown.
        """
        if not self.started:
            self.init(params)

        self.job_state_display.update(HTML(f"<span>Job State: {job.state.name}</span>"))

        # Read the completion status before refreshing: results written just
        # before the job finished are then still picked up by this last
        # refresh instead of never being displayed.
        job_done = job.done()
        self._refresh_templates()
        if not job_done:
            return True

        if job.state.name != "JOB_STATE_SUCCEEDED":
            errors = [f"Error: Job failed with error {job.error}."]
            for err_file in [
                f"{self.output_path}/instruction/error.json",
                f"{self.output_path}/demonstration/error.json",
            ]:
                if gfile.exists(err_file):
                    with gfile.GFile(err_file, "r") as f:
                        error_json = json.load(f)
                    errors.append(f"Detailed error: {error_json}")
                    errors.append(
                        f"Please feel free to send {err_file} to the VAPO team to help"
                        " resolving the issue."
                    )

            errors.append(
                "All the templates found before failure can be found under"
                f" {self.output_path}"
            )
            errors.append(
                "Please consider rerunning to make sure the failure is intransient."
            )
            # The messages are rendered as HTML, where "\n" does not break
            # lines; <br> keeps one message per line.
            err = "<br>".join(errors)
            self.status_display.update(HTML(f'<span style="color: red;">{err}</span>'))
        else:
            self.status_display.update(
                HTML(
                    '<span style="color: green;">Job succeeded!</span> <span>All the'
                    f" artifacts can be found under {self.output_path}</span>"
                )
            )
        return False


def display_dataframe(df: pd.DataFrame) -> None:
    """Display a pandas dataframe in Colab, one scrollable box per cell.

    Args:
        df: Dataframe to show; its index is not rendered.
    """

    # Function to wrap text in a scrollable div
    def wrap_in_scrollable_div(text):
        return f'<div class="scrollable">{text}</div>'

    # `Styler.to_html` has no `index` option (an unknown keyword is silently
    # ignored), so the index is hidden through the Styler itself.
    styler = df.style.format(wrap_in_scrollable_div).hide(axis="index")

    # Display the HTML in the notebook
    display(HTML(styler.to_html()))


def split_gcs_path(gcs_path: str) -> tuple[str, str]:
    """Break a `gs://bucket/prefix` path into its bucket name and prefix.

    The prefix is everything after the first "/" following the bucket name,
    or an empty string when there is none.

    Raises:
        ValueError: If the path does not start with `gs://`.
    """
    scheme = "gs://"
    if not gcs_path.startswith(scheme):
        raise ValueError("Invalid GCS path. Must start with 'gs://'")

    bucket_name, _, prefix = gcs_path[len(scheme):].partition("/")
    return bucket_name, prefix


def list_gcs_objects(full_gcs_path: str) -> list[str]:
    """Return the names of the objects whose name starts with the path prefix.

    The bucket and the prefix are both taken from `full_gcs_path`.
    """
    bucket_name, prefix = split_gcs_path(full_gcs_path)
    bucket = storage.Client().bucket(bucket_name)

    object_names = []
    for blob in bucket.list_blobs(prefix=prefix):
        object_names.append(blob.name)
    return object_names


def find_directories_with_files(
    full_gcs_path: str, required_files: list[str]
) -> list[str]:
    """Finds directories containing specific files under the given full GCS path.

    Args:
        full_gcs_path: `gs://` path whose objects are inspected.
        required_files: File names that a directory must all contain.

    Returns:
        The `gs://` paths of the matching directories, sorted so that the
        result (and any default picked from it) is the same on every call.
    """
    bucket_name, prefix = split_gcs_path(full_gcs_path)
    all_paths = list_gcs_objects(f"gs://{bucket_name}/{prefix}")

    # Map every directory to the set of file names found directly in it.
    file_presence: dict[str, set[str]] = {}
    for path in all_paths:
        directory, _, filename = path.rpartition("/")
        # Objects at the root of the bucket have no directory part.
        if directory:
            file_presence.setdefault(directory, set()).add(filename)

    required = set(required_files)
    return sorted(
        f"gs://{bucket_name}/{directory}"
        for directory, files in file_presence.items()
        if required <= files
    )


def extract_metric_name(metric_string: str):
    """Return the word found between a "." and the next "/" in the string.

    The input is returned unchanged when it contains no such pattern.
    """
    found = re.search(r"\.(\w+)/", metric_string)
    if found is None:
        return metric_string
    return found.group(1)


def read_file_from_gcs(filename: str):
    """Return the whole content of `filename`, opened for reading via `gfile`."""
    with gfile.GFile(filename, "r") as handle:
        content = handle.read()
    return content


def process_results(df: pd.DataFrame) -> pd.DataFrame:
    """Return `df` without the columns that could confuse the reader.

    A column is dropped when its name contains "confidence" or
    "raw_eval_resp", or is exactly "instruction" or "context".
    """

    def is_confusing(col) -> bool:
        return (
            "confidence" in col
            or "raw_eval_resp" in col
            or col in ("instruction", "context")
        )

    return df.drop(columns=[col for col in df.columns if is_confusing(col)])


class ResultsUI:
    """A UI to display the results of a VAPO run.

    A first dropdown selects a run directory and a second one selects a
    template of that run; the template and its evaluation results are
    rendered in an output widget.
    """

    def __init__(self, path: str):
        """Find the runs stored under `path` and load the first one.

        Raises:
            ValueError: If no directory under `path` holds the files of a run.
        """
        required_files = ["eval_results.json", "templates.json"]
        runs = find_directories_with_files(path, required_files)
        if not runs:
            # Fail with an actionable message rather than on `runs[0]`.
            raise ValueError(
                f"No VAPO run found under {path}: expected a directory containing"
                f" {required_files}"
            )

        self.run_label = widgets.Label("Select Run:")
        self.run_dropdrown = widgets.Dropdown(
            options=runs, value=runs[0], layout=widgets.Layout(width="200px")
        )
        self.run_dropdrown.observe(self.display_run_handler, names="value")

        # Create a label widget for the description
        self.dropdown_description = widgets.Label("Select Template:")
        self.template_dropdown = widgets.Dropdown(
            options=[],
            value=None,
            layout=widgets.Layout(width="400px"),
            disabled=True,
        )
        self.template_dropdown.observe(self.display_template_handler, names="value")
        self.results_output = widgets.Output(
            layout=widgets.Layout(
                height="600px", overflow="auto", margin="20px 0px 0px 0px"
            )
        )
        self.display_run(runs[0])

    def display_template_handler(self, change: dict[str, str]) -> None:
        """Display the template and the corresponding evaluation results."""
        if change["new"] is None:
            return
        # Options look like "Template <index> <metrics>"; the index is the
        # second space-separated token.
        df_index = int(change["new"].split(" ")[1])
        self.display_eval_results(df_index)

    def display_run_handler(self, change) -> None:
        """Load the run whose path was just selected in the run dropdown."""
        if change["new"] is None:
            return

        path = change["new"]
        self.display_run(path)

    def display_run(self, path: str) -> None:
        """Display the results of a VAPO run.

        Loads `eval_results.json` and `templates.json` from `path` and fills
        the template dropdown with one option per template.

        Raises:
            ValueError: If the number of templates matches neither the number
                of evaluation results nor that number plus one.
        """
        self.run_dropdrown.disabled = True
        try:
            filename = f"{path}/eval_results.json"
            eval_results = json.loads(read_file_from_gcs(filename))

            filename = f"{path}/templates.json"
            templates = json.loads(read_file_from_gcs(filename))

            if len(templates) == len(eval_results):
                offset = 0
            elif len(templates) == len(eval_results) + 1:
                # In some setups it is possible to have 1 more template than results.
                offset = 1
            else:
                raise ValueError(
                    "Number of templates doesn't match number of eval results"
                    f" {len(templates)} vs {len(eval_results)}"
                )
            self.templates = [
                pd.json_normalize(template) for template in templates[offset:]
            ]
            # An empty run has no template to read the metric columns from.
            metric_columns = (
                [col for col in self.templates[0].columns if "metric" in col]
                if self.templates
                else []
            )

            self.eval_results = [
                process_results(pd.read_json(StringIO(result["metrics_table"])))
                for result in eval_results
            ]
            options = []
            for i, template in enumerate(self.templates):
                metrics = []
                for col in metric_columns:
                    value = template[col].tolist()[0]
                    short_col = extract_metric_name(col)
                    metrics.append(f"{short_col}: {value}")
                metrics_str = " ".join(metrics)
                options.append(f"Template {i} {metrics_str}")

            self.template_dropdown.disabled = False
            self.template_dropdown.options = options
        finally:
            # Re-enable the selector even when loading fails, so that another
            # run can still be chosen.
            self.run_dropdrown.disabled = False

    def display_eval_results(self, index: int) -> None:
        """Display the evaluation results for a specific template."""
        with self.results_output:
            self.results_output.clear_output(wait=True)  # Clear previous output
            display_dataframe(self.templates[index])
            print()
            display_dataframe(self.eval_results[index])

    def get_container(self) -> widgets.VBox:
        """Get the container widget for the results UI."""
        return widgets.VBox(
            [
                self.run_label,
                self.run_dropdrown,
                self.dropdown_description,
                self.template_dropdown,
                self.results_output,
            ]
        )

# Step 1: Configure your prompt template
Prompts consist of two key parts:
* System Instruction (SI) Template: A fixed instruction shared across all queries for a given task.
* Task/Context Template: A dynamic part that changes based on the task.

VAPO enables the translation and optimization of the System Instruction Template, while the Task/Context Template remains essential for evaluating different SI templates.

In [None]:
# Fixed instruction shared by all queries of the task.
SYSTEM_INSTRUCTION = "Answer the following question. Let's think step by step.\n"  # @param {type:"string"}
# Task/context template; its {{...}} variables are checked against the keys of
# every input example. The assignment is kept on a single line so that the
# `# @param` annotation sits on the assignment line itself.
PROMPT_TEMPLATE = "Question: {{question}}\n\nAnswer:{{target}}"  # @param {type:"string"}

# Step 2: Input your data
To optimize the prompt, provide a CSV or JSONL file containing labeled validation samples.
* Focus on examples that specifically demonstrate the issues you want to address.
* Recommendation: Use 50-100 distinct samples for reliable results. However, the tool can still be effective with as few as 5 samples.

For prompt translation:
* Consider using the source model to label examples that the target model struggles with, helping to identify areas for improvement.


In [None]:
# @markdown **Project setup**: <br/>
PROJECT_ID = "[YOUR_PROJECT]"  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}
OUTPUT_PATH = "[OUTPUT_PATH]"  # @param {type:"string"}
# @markdown * GCS path of your bucket, e.g., gs://prompt_translation_demo, used to store all artifacts.
INPUT_DATA_PATH = "[INPUT_DATA_PATH]"  # @param {type:"string"}
# @markdown * Specify a GCS path for the input data, e.g., gs://prompt_translation_demo/input_data.jsonl.

# Step 3: Configure optimization settings
The optimization configs default to the most commonly used values, which we recommend as a starting point.

In [None]:
# Main optimization choices, forwarded to the optimizer by the run cell.
TARGET_MODEL = "gemini-1.5-flash-001"  # @param ["gemini-1.0-pro-001", "gemini-1.0-pro-002", "gemini-1.5-flash-001", "gemini-1.5-pro-001", "gemini-1.0-ultra-001"]
# A non-empty SOURCE_MODEL also disables the check that every example has a target.
SOURCE_MODEL = ""  # @param ["", "gemini-1.0-pro-001", "gemini-1.0-pro-002", "gemini-1.5-flash-001", "gemini-1.5-pro-001", "gemini-1.0-ultra-001", "text-bison@001", "text-bison@002", "text-bison32k@002", "text-unicorn@001"]
# @markdown * If set, it will be used to generate ground truth responses for the input examples. This is useful to migrate the prompt from a source model.
OPTIMIZATION_MODE = "instruction_and_demo"  # @param ["instruction", "demonstration", "instruction_and_demo"]
# Used on its own unless OPTIMIZATION_METRIC_1 (advanced settings) is not "NA".
OPTIMIZATION_METRIC = "question_answering_correctness"  # @param ["bleu", "coherence", "exact_match", "fluency", "groundedness", "text_quality", "verbosity", "rouge_1", "rouge_2", "rouge_l", "rouge_l_sum", "safety", "question_answering_correctness", "question_answering_quality", "summarization_quality", "tool_name_match", "tool_parameter_key_match", "tool_parameter_kv_match", "tool_call_valid"] {type:"string"}

# Step 4: Configure advanced optimization settings [Optional]

In [None]:
# @markdown **Instruction Optimization Configs**: <br/>
NUM_INST_OPTIMIZATION_STEPS = 10  # @param {type:"integer"}
NUM_TEMPLATES_PER_STEP = 2  # @param {type:"integer"}
# @markdown * Number of prompt templates generated and evaluated at each optimization step.

# @markdown **Demonstration Optimization Configs**: <br/>
NUM_DEMO_OPTIMIZATION_STEPS = 10  # @param {type:"integer"}
NUM_DEMO_PER_PROMPT = 3  # @param {type:"integer"}
# @markdown * Number of the demonstrations to include in each prompt.

# @markdown **Model Configs**: <br/>
TARGET_MODEL_QPS = 3  # @param {type:"integer"}
SOURCE_MODEL_QPS = 3  # @param {type:"integer"}
OPTIMIZER_MODEL = "gemini-1.5-flash-001"  # @param ["gemini-1.0-pro-001", "gemini-1.0-pro-002", "gemini-1.5-flash-001", "gemini-1.5-pro-001", "gemini-1.0-ultra-001", "text-bison@001", "text-bison@002", "text-bison32k@002", "text-unicorn@001"]
# @markdown * The model used to generated alternative prompts in the instruction optimization mode.
OPTIMIZER_MODEL_QPS = 3  # @param {type:"integer"}
EVAL_MODEL_QPS = 3  # @param {type:"integer"}
# @markdown * The QPS for calling the eval model, which is currently gemini-1.5-pro-001.

# @markdown **Multi-metric Configs**: <br/>
# @markdown Use this section only if you need more than one metric for optimization. This will override the metric you picked above.
OPTIMIZATION_METRIC_1 = "NA"  # @param ["NA", "bleu", "coherence", "exact_match", "fluency", "groundedness", "text_quality", "verbosity", "rouge_1", "rouge_2", "rouge_l", "rouge_l_sum", "safety", "question_answering_correctness", "question_answering_quality", "summarization_quality", "tool_name_match", "tool_parameter_key_match", "tool_parameter_kv_match", "tool_call_valid"] {type:"string"}
OPTIMIZATION_METRIC_1_WEIGHT = 0.0  # @param {type:"number"}
OPTIMIZATION_METRIC_2 = "NA"  # @param ["NA", "bleu", "coherence", "exact_match", "fluency", "groundedness", "text_quality", "verbosity", "rouge_1", "rouge_2", "rouge_l", "rouge_l_sum", "safety", "question_answering_correctness", "question_answering_quality", "summarization_quality", "tool_name_match", "tool_parameter_key_match", "tool_parameter_kv_match", "tool_call_valid"] {type:"string"}
OPTIMIZATION_METRIC_2_WEIGHT = 0.0  # @param {type:"number"}
OPTIMIZATION_METRIC_3 = "NA"  # @param ["NA", "bleu", "coherence", "exact_match", "fluency", "groundedness", "text_quality", "verbosity", "rouge_1", "rouge_2", "rouge_l", "rouge_l_sum", "safety", "question_answering_correctness", "question_answering_quality", "summarization_quality", "tool_name_match", "tool_parameter_key_match", "tool_parameter_kv_match", "tool_call_valid"] {type:"string"}
OPTIMIZATION_METRIC_3_WEIGHT = 0.0  # @param {type:"number"}
METRIC_AGGREGATION_TYPE = "weighted_sum"  # @param ["weighted_sum", "weighted_average"]

# @markdown **Misc Configs**: <br/>
PLACEHOLDER_TO_VALUE = "{}"  # @param
# @markdown * This variable is used for long prompt optimization to not optimize parts of prompt identified by placeholders. It provides a mapping from the placeholder variables to their content. See link for details.
RESPONSE_MIME_TYPE = "application/json"  # @param ["text/plain", "application/json"]
# @markdown * This variable determines the format of the output for the target model. See link for details.
TARGET_LANGUAGE = "English"  # @param ["English", "French", "German", "Hebrew", "Hindi", "Japanese", "Korean", "Portuguese", "Simplified Chinese", "Spanish", "Traditional Chinese"]
# @markdown * The language of the system instruction.

# Step 5: Run Prompt Optimizer

In [None]:
# Timestamped name shared by the Vertex AI job and its output folder.
timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
display_name = f"pt_{timestamp}"

# GOOGLE_CLOUD_PROJECT in the environment is taken as the sign of Colab
# Enterprise, in which case the explicit Colab authentication is skipped.
in_colab_enterprise = "GOOGLE_CLOUD_PROJECT" in os.environ
if not in_colab_enterprise:
    gc = authenticate()

# Resolve the metrics that are actually optimized before validating the data.
# The multi-metric section overrides the single metric as soon as its first
# entry is set, and each metric is paired with its own weight.
if OPTIMIZATION_METRIC_1 == "NA":
    metrics = [OPTIMIZATION_METRIC]
    weights = [1.0]
else:
    metrics = []
    weights = []
    for metric, weight in zip(
        [OPTIMIZATION_METRIC_1, OPTIMIZATION_METRIC_2, OPTIMIZATION_METRIC_3],
        [
            OPTIMIZATION_METRIC_1_WEIGHT,
            OPTIMIZATION_METRIC_2_WEIGHT,
            OPTIMIZATION_METRIC_3_WEIGHT,
        ],
    ):
        if metric == "NA":
            break
        metrics.append(metric)
        weights.append(weight)

# Targets are only required by the metrics that are really in use.
label_enforced = is_run_target_required(metrics, SOURCE_MODEL)
input_data_path = f"{INPUT_DATA_PATH}"
validate_prompt_and_data(
    "\n".join([SYSTEM_INSTRUCTION, PROMPT_TEMPLATE]),
    input_data_path,
    PLACEHOLDER_TO_VALUE,
    label_enforced,
)

output_path = f"{OUTPUT_PATH}/{display_name}"

params = {
    "project": PROJECT_ID,
    "num_steps": NUM_INST_OPTIMIZATION_STEPS,
    "prompt_template": SYSTEM_INSTRUCTION,
    "demo_and_query_template": PROMPT_TEMPLATE,
    "target_model": TARGET_MODEL,
    "target_model_qps": TARGET_MODEL_QPS,
    "target_model_location": LOCATION,
    "source_model": SOURCE_MODEL,
    "source_model_qps": SOURCE_MODEL_QPS,
    "source_model_location": LOCATION,
    "eval_model_qps": EVAL_MODEL_QPS,
    "eval_model_location": LOCATION,
    "optimization_mode": OPTIMIZATION_MODE,
    "num_demo_set_candidates": NUM_DEMO_OPTIMIZATION_STEPS,
    "demo_set_size": NUM_DEMO_PER_PROMPT,
    "aggregation_type": METRIC_AGGREGATION_TYPE,
    "data_limit": 50,
    "optimizer_model": OPTIMIZER_MODEL,
    "optimizer_model_qps": OPTIMIZER_MODEL_QPS,
    "optimizer_model_location": LOCATION,
    "num_template_eval_per_step": NUM_TEMPLATES_PER_STEP,
    "input_data_path": input_data_path,
    "output_path": output_path,
    "response_mime_type": RESPONSE_MIME_TYPE,
    "language": TARGET_LANGUAGE,
    "placeholder_to_content": json.loads(PLACEHOLDER_TO_VALUE),
    "eval_metrics_types": metrics,
    "eval_metrics_weights": weights,
}

job = run_apd(params, OUTPUT_PATH, display_name)
print(f"Job ID: {job.name}")

# Poll every 5 seconds until monitor_progress reports that the job is done.
progress_form = ProgressForm()
while progress_form.monitor_progress(job, params):
    time.sleep(5)

# Step 6: Inspect the Results
You can use the following cell to inspect all the predictions made by all the
generated templates during one or multiple VAPO runs.

In [None]:
RESULT_PATH = "[GCS_PATH]"  # @param {type:"string"}
# @markdown * Specify a GCS path that contains artifacts of a single or multiple VAPO runs.

# Building the UI looks up the runs stored under RESULT_PATH and loads the first one.
results_ui = ResultsUI(RESULT_PATH)

# Style sheet for the result tables: `.scrollable` is the class put around
# every cell by display_dataframe; the other rules color the rows and headers.
results_df_html = """
<style>
  .scrollable {
    width: 100%;
    height: 80px;
    overflow-y: auto;
    overflow-x: hidden;  /* Hide horizontal scrollbar */
  }
  tr:nth-child(odd) {
    background: var(--colab-highlighted-surface-color);
  }
  tr:nth-child(even) {
    background-color: var(--colab-primary-surface-color);
  }
  th {
    background-color: var(--colab-highlighted-surface-color);
  }
</style>
"""

# The style sheet is displayed first, then the widgets of the results UI.
display(HTML(results_df_html))
display(results_ui.get_container())