fix(components): Only run preview.llm.bulk_inference after tuning third-party models with RLHF

PiperOrigin-RevId: 601226133
Googler committed Jan 24, 2024
1 parent f65bb0f commit b9e08de
Showing 4 changed files with 57 additions and 14 deletions.
1 change: 1 addition & 0 deletions components/google-cloud/RELEASE.md
@@ -6,6 +6,7 @@
 * Add Vertex model get component (`v1.model.ModelGetOp`).
 * Migrate to Protobuf 4 (`protobuf>=4.21.1,<5`). Require `kfp>=2.6.0`.
 * Support setting version aliases in (`v1.model.ModelUploadOp`).
+* Only run `preview.llm.bulk_inference` pipeline after RLHF tuning for third-party models when `eval_dataset` is provided.

 ## Release 2.8.0
 * Release AutoSxS pipeline to preview.
@@ -17,7 +17,7 @@
 _DEFAULT_AUTOSXS_IMAGE_TAG = '20240116_0507_RC00'

 def get_private_image_tag() -> str:
-  return os.getenv('PRIVATE_IMAGE_TAG') or '20231213_0507_RC00'
+  return os.getenv('PRIVATE_IMAGE_TAG') or '20240124_0507_RC00'


 def get_autosxs_image_tag() -> str:
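
For context, the changed function is a plain override-or-default lookup: an environment variable wins, otherwise the tag pinned in this commit is used. A minimal sketch of that behavior, assuming `PRIVATE_IMAGE_TAG` is not already set and using a hypothetical override value:

import os

# With no override set, the pinned tag from this commit is returned.
assert (os.getenv('PRIVATE_IMAGE_TAG') or '20240124_0507_RC00') == '20240124_0507_RC00'

# A hypothetical local override takes precedence over the pinned default.
os.environ['PRIVATE_IMAGE_TAG'] = '20240125_0000_RC00'
assert (os.getenv('PRIVATE_IMAGE_TAG') or '20240124_0507_RC00') == '20240125_0000_RC00'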
@@ -572,3 +572,31 @@ def get_uri(artifact: dsl.Input[dsl.Artifact], is_dir: bool = False) -> str:  #
 @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
 def get_empty_string() -> str:
   return ''
+
+
+@dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
+def validate_rlhf_inputs(
+    large_model_reference: str,
+    eval_dataset: Optional[str] = None,
+) -> None:
+  """Checks user-provided arguments are valid for the RLHF pipeline."""
+  models_that_support_bulk_inference = {
+      't5-small',
+      't5-large',
+      't5-xl',
+      't5-xxl',
+      'llama-2-7b',
+      'llama-2-7b-chat',
+      'llama-2-13b',
+      'llama-2-13b-chat',
+  }
+  if (
+      eval_dataset
+      and large_model_reference not in models_that_support_bulk_inference
+  ):
+    raise ValueError(
+        f'eval_dataset not supported for {large_model_reference}. '
+        'Please set this value to None when tuning this model. '
+        'This model can be evaluated after tuning using Batch or Online '
+        'Prediction.'
+    )
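
To illustrate the guard this component adds, here is a standalone plain-Python sketch of the same check, outside the compiled KFP component wrapper. The supported model names come from the diff; the `text-bison@001` model name and the gs:// paths are illustrative placeholders only:

from typing import Optional

SUPPORTED_BULK_INFERENCE_MODELS = {
    't5-small', 't5-large', 't5-xl', 't5-xxl',
    'llama-2-7b', 'llama-2-7b-chat', 'llama-2-13b', 'llama-2-13b-chat',
}

def validate(large_model_reference: str, eval_dataset: Optional[str] = None) -> None:
  # eval_dataset is only accepted for third-party models that support
  # bulk inference; for anything else it must be left unset.
  if eval_dataset and large_model_reference not in SUPPORTED_BULK_INFERENCE_MODELS:
    raise ValueError(f'eval_dataset not supported for {large_model_reference}.')

validate('llama-2-7b', 'gs://my-bucket/eval.jsonl')  # accepted: supported model
validate('text-bison@001')                           # accepted: no eval_dataset
try:
  validate('text-bison@001', 'gs://my-bucket/eval.jsonl')
except ValueError as e:
  print(e)  # eval_dataset not supported for text-bison@001.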
@@ -68,7 +68,7 @@ def rlhf_pipeline(
     kl_coeff: Coefficient for KL penalty. This regularizes the policy model and penalizes if it diverges from its initial distribution. If set to 0, the reference language model is not loaded into memory. Default value is 0.1.
     instruction: This field lets the model know what task it needs to perform. Base models have been trained over a large set of varied instructions. You can give a simple and intuitive description of the task and the model will follow it, e.g. "Classify this movie review as positive or negative" or "Translate this sentence to Danish". Do not specify this if your dataset already prepends the instruction to the inputs field.
     deploy_model: Whether to deploy the model to an endpoint in `us-central1`. Default is True.
-    eval_dataset: Optional Cloud storage path to an evaluation dataset. If provided, inference will be performed on this dataset after training. The dataset format is jsonl. Each example in the dataset must contain a field `input_text` that contains the prompt.
+    eval_dataset: Optional Cloud storage path to an evaluation dataset. Note, eval dataset can only be provided for third-party models. If provided, inference will be performed on this dataset after training. The dataset format is jsonl. Each example in the dataset must contain a field `input_text` that contains the prompt.
     project: Project used to run custom jobs. If not specified the project used to run the pipeline will be used.
     location: Location used to run custom jobs. If not specified the location used to run the pipeline will be used.
     tensorboard_resource_id: Optional tensorboard resource id in format `projects/{project_number}/locations/{location}/tensorboards/{tensorboard_id}`. If provided, tensorboard metrics will be uploaded to this location.
@@ -78,6 +78,12 @@ def rlhf_pipeline(
     endpoint_resource_name: Path the Online Prediction Endpoint. This will be an empty string if the model was not deployed.
   """
   # fmt: on
+
+  function_based.validate_rlhf_inputs(
+      large_model_reference=large_model_reference,
+      eval_dataset=eval_dataset,
+  ).set_display_name('Validate Inputs')
+
   reward_model_pipeline = (
       reward_model_graph.pipeline(
           preference_dataset=preference_dataset,
@@ -110,22 +116,30 @@ def rlhf_pipeline(
       tensorboard_resource_id=tensorboard_resource_id,
   ).set_display_name('Reinforcement Learning')

-  should_perform_inference = function_based.value_exists(
+  has_inference_dataset = function_based.value_exists(
       value=eval_dataset
   ).set_display_name('Resolve Inference Dataset')
   with kfp.dsl.Condition(
-      should_perform_inference.output == True, name='Perform Inference'  # pylint: disable=singleton-comparison
+      has_inference_dataset.output == True,  # pylint: disable=singleton-comparison
+      name='Perform Inference',
   ):
-    component.infer_pipeline(
-        project=project,
-        location=location,
-        large_model_reference=large_model_reference,
-        model_checkpoint=rl_model_pipeline.outputs['output_model_path'],
-        prompt_dataset=eval_dataset,
-        prompt_sequence_length=prompt_sequence_length,
-        target_sequence_length=target_sequence_length,
-        instruction=instruction,
-    )
+    has_model_checkpoint = function_based.value_exists(
+        value=rl_model_pipeline.outputs['output_model_path']
+    ).set_display_name('Resolve Model Checkpoint')
+    with kfp.dsl.Condition(
+        has_model_checkpoint.output == True,  # pylint: disable=singleton-comparison
+        name='Test Model Checkpoint Exists',
+    ):
+      component.infer_pipeline(
+          project=project,
+          location=location,
+          large_model_reference=large_model_reference,
+          model_checkpoint=rl_model_pipeline.outputs['output_model_path'],
+          prompt_dataset=eval_dataset,
+          prompt_sequence_length=prompt_sequence_length,
+          target_sequence_length=target_sequence_length,
+          instruction=instruction,
+      )

   llm_model_handler = deployment_graph.pipeline(
       output_adapter_path=rl_model_pipeline.outputs['output_adapter_path'],
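
For reference, the `eval_dataset` that gates this branch is a Cloud Storage JSONL file in which every record supplies the prompt under `input_text`, per the docstring above. A minimal illustrative file (both prompts are hypothetical):

{"input_text": "Classify this movie review as positive or negative: I loved it."}
{"input_text": "Translate this sentence to Danish: The weather is nice today."}

With the nested conditions, bulk inference now runs only when an `eval_dataset` was supplied and the reinforcement-learning step actually produced a model checkpoint, rather than on the dataset check alone.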
