Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(component): Migrate AutoSxS pipeline to preview and move related…
… files to _implementation/llm directory to help Model Eval team use side by side metrics as part of their pipeline PiperOrigin-RevId: 588917968
- Loading branch information
Googler
committed
Dec 13, 2023
1 parent
efeed83
commit ef95981
Showing
10 changed files
with
737 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
137 changes: 137 additions & 0 deletions
137
...s/google-cloud/google_cloud_pipeline_components/_implementation/llm/arbiter_preprocess.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright 2023 The Kubeflow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""KFP Container component for preprocessing predictions for the Arbiter.""" | ||
|
||
import os | ||
from typing import Dict, List | ||
|
||
from google_cloud_pipeline_components import _placeholders | ||
from google_cloud_pipeline_components import utils as gcpc_utils | ||
from google_cloud_pipeline_components._implementation.llm import utils | ||
from kfp import dsl | ||
|
||
|
||
def _resolve_image() -> str: | ||
"""Determines the image URI to create a container from.""" | ||
return ( | ||
os.environ.get('AUTOSXS_IMAGE_OVERRIDE') | ||
or utils.get_default_image_uri('autosxs')) | ||
|
||
|
||
# pylint: disable=unused-argument,dangerous-default-value | ||
@dsl.container_component | ||
def arbiter_preprocess( | ||
evaluation_dataset: str, | ||
id_columns: List[str], | ||
response_column_a: str, | ||
response_column_b: str, | ||
task: str, | ||
is_bp_output_a: bool, | ||
is_bp_output_b: bool, | ||
autorater_prompt_parameters: Dict[str, Dict[str, str]], | ||
preprocessed_evaluation_dataset: dsl.Output[dsl.Dataset], # pylint: disable=unused-argument # pytype: disable=unsupported-operands | ||
preprocessed_evaluation_dataset_uri: dsl.OutputPath(str), # pylint: disable=unused-argument # pytype: disable=invalid-annotation | ||
gcp_resources: dsl.OutputPath(str), # pytype: disable=invalid-annotation | ||
prediction_uris_a: str = '', | ||
prediction_uris_b: str = '', | ||
model_a_prompt_parameters: Dict[str, Dict[str, str]] = {}, | ||
model_b_prompt_parameters: Dict[str, Dict[str, str]] = {}, | ||
human_preference_column: str = '', | ||
) -> dsl.ContainerSpec: # pylint: disable=g-doc-args | ||
"""Preprocesses predictions tables for the AutoSxS Arbiter. | ||
Args: | ||
evaluation_dataset: GCS or BigQuery URIs representing a dataset of prompts | ||
and responses. | ||
id_columns: The columns which distinguish unique evaluation examples. | ||
response_column_a: The column containing responses for model a. | ||
response_column_b: The column containing responses for model a. | ||
task: Task to evaluate. | ||
output_path: Path to write the path where preprocessed predictions are | ||
stored. | ||
is_bp_output_a: If True, the prediction URIs will be parsed as if they came | ||
from Vertex Batch Prediction, where response_column_a represents a field | ||
in the model output containing the response. If False, the expected format | ||
will be a table containing all model_prompt_parameters and the | ||
response_column. | ||
is_bp_output_b: If True, the prediction URIs will be parsed as if they came | ||
from Vertex Batch Prediction, where response_column_b represents a field | ||
in the model output containing the response. If False, the expected format | ||
will be a table containing all model_prompt_parameters and the | ||
response_column. | ||
prediction_uris: A list of GCS or BigQuery URIs representing a dataset of | ||
prompts and responses for model a. | ||
prediction_uris: A list of GCS or BigQuery URIs representing a dataset of | ||
prompts and responses for model b. | ||
model_a_prompt_parameters: Map of model A prompt template parameters to | ||
columns or templates. | ||
model_b_prompt_parameters: Map of model B prompt template parameters to | ||
columns or templates. | ||
autorater_prompt_parameters: Map of autorater prompt template parameters to | ||
columns or templates. | ||
human_preference_column: The column containing ground truths. The default | ||
value is an empty string if not be provided by users. | ||
Returns: | ||
preprocessed_evaluation_dataset: Dataset of the table containing the inputs | ||
expected by the Arbiter. | ||
preprocessed_evaluation_dataset_uri: URI of the table containing the inputs | ||
expected by the Arbiter. | ||
gcp_resources: Tracker for GCP resources created by this component. | ||
""" | ||
return gcpc_utils.build_serverless_customjob_container_spec( | ||
project=_placeholders.PROJECT_ID_PLACEHOLDER, | ||
location=_placeholders.LOCATION_PLACEHOLDER, | ||
custom_job_payload=utils.build_payload( | ||
display_name='arbiter_preprocess', | ||
machine_type='n1-standard-4', | ||
image_uri=_resolve_image(), | ||
args=[ | ||
'--', # Used to mark the start of component flags. | ||
'arbiter_preprocess', | ||
f'--evaluation_dataset={evaluation_dataset}', | ||
f'--prediction_uris_a={prediction_uris_a}', | ||
f'--prediction_uris_b={prediction_uris_b}', | ||
( | ||
'--id_columns=' | ||
"{{$.inputs.parameters['id_columns'].json_escape[0]}}" | ||
), | ||
( | ||
'--autorater_prompt_parameters=' | ||
"{{$.inputs.parameters['autorater_prompt_parameters']" | ||
'.json_escape[0]}}' | ||
), | ||
( | ||
'--model_a_prompt_parameters=' | ||
"{{$.inputs.parameters['model_a_prompt_parameters']" | ||
'.json_escape[0]}}' | ||
), | ||
( | ||
'--model_b_prompt_parameters=' | ||
"{{$.inputs.parameters['model_b_prompt_parameters']" | ||
'.json_escape[0]}}' | ||
), | ||
f'--response_column_a={response_column_a}', | ||
f'--response_column_b={response_column_b}', | ||
f'--human_preference_column={human_preference_column}', | ||
f'--task={task}', | ||
f'--is_batch_prediction_output_a={is_bp_output_a}', | ||
f'--is_batch_prediction_output_b={is_bp_output_b}', | ||
f'--output_dir={dsl.PIPELINE_ROOT_PLACEHOLDER}', | ||
f'--preprocessed_evaluation_dataset_uri={preprocessed_evaluation_dataset_uri}', | ||
'--executor_input={{$.json_escape[1]}}', | ||
], | ||
), | ||
gcp_resources=gcp_resources, | ||
) |
105 changes: 105 additions & 0 deletions
105
...ents/google-cloud/google_cloud_pipeline_components/_implementation/llm/autosxs_arbiter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Copyright 2023 The Kubeflow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""KFP Container component that performs AutoSxS.""" | ||
|
||
import os | ||
from typing import Any, Dict, List | ||
|
||
from google_cloud_pipeline_components import _placeholders | ||
from google_cloud_pipeline_components import utils as gcpc_utils | ||
from google_cloud_pipeline_components._implementation.llm import utils | ||
from kfp import dsl | ||
|
||
|
||
def _resolve_image() -> str: | ||
"""Determines the image URI to create a container from.""" | ||
return ( | ||
os.environ.get('AUTOSXS_IMAGE_OVERRIDE') | ||
or utils.get_default_image_uri('autosxs')) | ||
|
||
|
||
def _get_prediction_endpoint_overrides() -> str: | ||
"""Used for integration tests to override the prediction endpoint.""" | ||
return os.environ.get('PREDICTION_ENDPOINT_OVERRIDES', '') | ||
|
||
|
||
@dsl.container_component | ||
def autosxs_arbiter( | ||
inference_output_uri: str, | ||
id_columns: List[str], | ||
task: str, | ||
judgments: dsl.Output[dsl.Dataset], # pylint: disable=unused-argument # pytype: disable=unsupported-operands | ||
judgments_uri: dsl.OutputPath(str), # pytype: disable=invalid-annotation | ||
gcp_resources: dsl.OutputPath(str), | ||
metadata: dsl.OutputPath(str), | ||
human_preference_column: str = '', | ||
judgments_format: str = 'jsonl', | ||
bigquery_destination_prefix: str = '', | ||
experimental_args: Dict[str, Any] = {}, | ||
) -> dsl.ContainerSpec: # pylint: disable=g-doc-args | ||
"""Evaluate two models using an autorater. | ||
Args: | ||
inference_output_uri: Directory of model A's inference output. | ||
id_columns: The columns which distinguish unique evaluation examples. | ||
human_preference_column: Human preference column included in our inference | ||
output. | ||
task: Evaluation task in the form {task}@{version}. task can be one of | ||
"summarization", "question_answer". Version is an integer with 3 digits or | ||
"latest". Ex: summarization@001 or question_answer@latest. | ||
judgments_format: The format to write judgments to. Can be either 'json' or | ||
'bigquery'. | ||
bigquery_destination_prefix: BigQuery table to write judgments to if the | ||
specified format is 'bigquery'. | ||
experimental_args: Experimentally released arguments. Subject to change. | ||
Returns: | ||
judgments: Individual judgments used to calculate the win rates. | ||
judgments_uri: URI of the Judgments Artifact. | ||
gcp_resources: Tracker for GCP resources created by this component. | ||
metadata: Computed runtime metrics metadata from this component. | ||
""" | ||
return gcpc_utils.build_serverless_customjob_container_spec( | ||
project=_placeholders.PROJECT_ID_PLACEHOLDER, | ||
# Hardcode location to us-central1 for text-bison availability. | ||
location='us-central1', | ||
custom_job_payload=utils.build_payload( | ||
display_name='autosxs_arbiter', | ||
machine_type='n1-standard-4', | ||
image_uri=_resolve_image(), | ||
args=[ | ||
'--', # Used to mark the start of component flags. | ||
'arbiter', | ||
f'--inference_output_uri={inference_output_uri}', | ||
f'--human_preference_column={human_preference_column}', | ||
f'--task={task}', | ||
f'--prediction_endpoint_overrides={_get_prediction_endpoint_overrides()}', | ||
f'--output_dir={dsl.PIPELINE_ROOT_PLACEHOLDER}', | ||
f'--judgments_uri={judgments_uri}', | ||
f'--judgments_format={judgments_format}', | ||
f'--bigquery_destination_prefix={bigquery_destination_prefix}', | ||
( | ||
'--id_columns=' | ||
"{{$.inputs.parameters['id_columns'].json_escape[0]}}" | ||
), | ||
( | ||
'--experimental_args=' | ||
"{{$.inputs.parameters['experimental_args'].json_escape[0]}}" | ||
), | ||
'--executor_input={{$.json_escape[1]}}', | ||
f'--metadata_path={metadata}', | ||
], | ||
), | ||
gcp_resources=gcp_resources, | ||
) |
66 changes: 66 additions & 0 deletions
66
...le-cloud/google_cloud_pipeline_components/_implementation/llm/autosxs_metrics_computer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright 2023 The Kubeflow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""KFP Container component for computing AutoSXS metrics.""" | ||
|
||
import os | ||
|
||
from google_cloud_pipeline_components import _placeholders | ||
from google_cloud_pipeline_components import utils as gcpc_utils | ||
from google_cloud_pipeline_components._implementation.llm import utils | ||
from kfp import dsl | ||
|
||
|
||
def _resolve_image() -> str: | ||
"""Determines the image URI to create a container from.""" | ||
return os.environ.get( | ||
'AUTOSXS_IMAGE_OVERRIDE' | ||
) or utils.get_default_image_uri('autosxs') | ||
|
||
|
||
@dsl.container_component | ||
def autosxs_metrics_computer( | ||
judgments_dir: str, | ||
has_human_preference: bool, | ||
autosxs_metrics: dsl.Output[dsl.Metrics], # pylint: disable=unused-argument # pytype: disable=unsupported-operands | ||
gcp_resources: dsl.OutputPath(str), # pytype: disable=invalid-annotation | ||
) -> dsl.ContainerSpec: # pylint: disable=g-doc-args | ||
"""Compute AutoSXS metrics using judgments outputs from Arbiter. | ||
Args: | ||
judgments_dir: Path where store the Judgments. | ||
has_human_preference: Boolean value. True if users provided human preference | ||
data, otherwise false. | ||
Returns: | ||
autosxs_metrics: Autosxs win rate metrics and human alignment metrics. | ||
gcp_resources: Tracker for GCP resources created by this component. | ||
""" | ||
return gcpc_utils.build_serverless_customjob_container_spec( | ||
project=_placeholders.PROJECT_ID_PLACEHOLDER, | ||
# Hardcode location to us-central1 for text-bison availability. | ||
location='us-central1', | ||
custom_job_payload=utils.build_payload( | ||
display_name='autosxs_metrics_computer', | ||
machine_type='n1-standard-4', | ||
image_uri=_resolve_image(), | ||
args=[ | ||
'--', # Used to mark the start of component flags. | ||
'autosxs_metrics', | ||
f'--judgments_dir={judgments_dir}', | ||
f'--has_human_preference={has_human_preference}', | ||
'--executor_input={{$.json_escape[1]}}', | ||
], | ||
), | ||
gcp_resources=gcp_resources, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.