Skip to content

Commit

Permalink
feat(components): add resolve_reference_model_metadata to rlhf_prepro…
Browse files Browse the repository at this point in the history
…cessor component

PiperOrigin-RevId: 618957077
  • Loading branch information
Googler committed Apr 16, 2024
1 parent da80440 commit 6311354
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
def pipeline(
output_adapter_path: str,
large_model_reference: str,
policy_model_reference: str,
model_display_name: Optional[str] = None,
deploy_model: bool = True,
encryption_spec_key_name: str = '',
Expand All @@ -45,6 +46,7 @@ def pipeline(
Args:
output_adapter_path: Path to the trained model adapter if LoRA tuning was used.
large_model_reference: Name of the base model. Supported values are `text-bison@001`, `t5-small`, `t5-large`, `t5-xl` and `t5-xxl`. `text-bison@001` and `t5-small` are supported in `us-central1` and `europe-west4`. `t5-large`, `t5-xl` and `t5-xxl` are only supported in `europe-west4`.
policy_model_reference: The name of the model for deployment. The name should be in capitalized snake case format.
model_display_name: Name of the fine-tuned model shown in the Model Registry. If not provided, a default name will be created.
deploy_model: Whether to deploy the model to an endpoint in `us-central1`. Default is True.
encryption_spec_key_name: Customer-managed encryption key. If this is set, then all resources created by the CustomJob will be encrypted with the provided encryption key. Note that this is not supported for TPU at the moment.
Expand All @@ -68,14 +70,8 @@ def pipeline(
.set_display_name('Resolve Model Display Name')
)

reference_model_metadata = function_based.resolve_reference_model_metadata(
large_model_reference=large_model_reference,
).set_display_name('Resolve Model Metadata')

upload_model = function_based.resolve_upload_model(
large_model_reference=reference_model_metadata.outputs[
'large_model_reference'
]
large_model_reference=policy_model_reference,
).set_display_name('Resolve Upload Model')
upload_task = upload_llm_model.refined_upload_llm_model(
project=_placeholders.PROJECT_ID_PLACEHOLDER,
Expand All @@ -90,9 +86,7 @@ def pipeline(
).set_display_name('Upload Model')
deploy_model = function_based.resolve_deploy_model(
deploy_model=deploy_model,
large_model_reference=reference_model_metadata.outputs[
'large_model_reference'
],
large_model_reference=policy_model_reference,
).set_display_name('Resolve Deploy Model')
deploy_task = deploy_llm_model.deploy_llm_model(
project=_placeholders.PROJECT_ID_PLACEHOLDER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
DO NOT EDIT - This file is generated, manual changes will be overridden.
"""

IMAGE_TAG = '20240407_1707'
IMAGE_TAG = '20240414_0507'
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def pipeline(
input_reward_adapter_path: str,
input_preference_dataset_path: str,
large_model_reference: str,
reward_model_reference: str,
policy_model_reference: str,
policy_model_path: str,
prompt_sequence_length: int = 512,
target_sequence_length: int = 64,
lora_dim: int = 1,
Expand All @@ -64,6 +67,9 @@ def pipeline(
input_reward_adapter_path: Path to the reward LoRA adapter to use during reinforcement learning.
input_preference_dataset_path: Path to preference dataset used by the reward model.
large_model_reference: Name of the base model. Supported values are `text-bison@001`, `t5-small`, `t5-large`, `t5-xl` and `t5-xxl`. `text-bison@001` and `t5-small` are supported in `us-central1` and `europe-west4`. `t5-large`, `t5-xl` and `t5-xxl` are only supported in `europe-west4`.
reward_model_reference: Name of the reward model. The name should be in capitalized snake case format.
policy_model_reference: Name of the policy model. The name should be in capitalized snake case format.
policy_model_path: The model checkpoint path to the reinforcer model.
prompt_sequence_length: Maximum tokenized sequence length for input text. Higher values increase memory overhead. This value should be at most 8192. Default value is 512.
target_sequence_length: Maximum tokenized sequence length for target text. Higher values increase memory overhead. This value should be at most 1024. Default value is 64.
lora_dim: The rank of the LoRA adapter. If >0, then use LoRA-tuning. If =0, then use full-tuning. Default is 1.
Expand All @@ -90,10 +96,6 @@ def pipeline(
use_test_spec=env.get_use_test_machine_spec(),
).set_display_name('Resolve Machine Spec')

reference_model_metadata = function_based.resolve_reference_model_metadata(
large_model_reference=large_model_reference,
).set_display_name('Resolve Model Metadata')

processed_dataset = preprocess_chat_dataset.preprocess_chat_dataset(
large_model_reference=large_model_reference,
input_dataset_uri=prompt_dataset,
Expand All @@ -109,9 +111,7 @@ def pipeline(
# Target field name does not matter because this field is not used.
targets_field_name='non_existent_targets_field_name',
output_split_name=env.TRAIN_SPLIT,
large_model_reference=reference_model_metadata.outputs[
'large_model_reference'
],
large_model_reference=policy_model_reference,
instruction=instruction,
encryption_spec_key_name=encryption_spec_key_name,
)
Expand All @@ -122,17 +122,13 @@ def pipeline(
accelerator_type=machine_spec.outputs['accelerator_type'],
).set_display_name('Resolve Reinforcer Image URI')
num_microbatches = function_based.resolve_num_microbatches(
large_model_reference=reference_model_metadata.outputs[
'large_model_reference'
]
large_model_reference=policy_model_reference,
).set_display_name('Resolve Number of Microbatches')
rl_model = (
reinforcer.reinforcer(
project=project,
location=machine_spec.outputs['tuning_location'],
input_reference_model_path=reference_model_metadata.outputs[
'reference_model_path'
],
input_reference_model_path=policy_model_path,
input_reward_model_path=input_reward_model_path,
input_reward_adapter_path=input_reward_adapter_path,
input_dataset_path=prompt_dataset_importer.outputs[
Expand All @@ -142,12 +138,8 @@ def pipeline(
train_steps=reinforcement_learning_train_steps,
accelerator_type=machine_spec.outputs['accelerator_type'],
accelerator_count=machine_spec.outputs['accelerator_count'],
large_model_reference=reference_model_metadata.outputs[
'large_model_reference'
],
reward_model_reference=reference_model_metadata.outputs[
'reward_model_reference'
],
large_model_reference=policy_model_reference,
reward_model_reference=reward_model_reference,
machine_type=machine_spec.outputs['machine_type'],
image_uri=rl_image_uri.output,
inputs_sequence_length=prompt_sequence_length,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

PipelineOutput = NamedTuple(
'Outputs',
reward_model_base_path=str,
reward_model_adapter_path=str,
reward_dataset_path=str,
)
Expand All @@ -39,6 +38,8 @@
def pipeline(
preference_dataset: str,
large_model_reference: str,
reward_model_reference: str,
reward_model_path: str,
prompt_sequence_length: int = 512,
target_sequence_length: int = 64,
batch_size: int = 64,
Expand All @@ -59,6 +60,8 @@ def pipeline(
Args:
preference_dataset: Cloud storage path to a human preference JSONL dataset used to train a reward model. Each example in a preference dataset must contain `candidate_0` and `candidate_1` fields that contain candidate responses, `choice` that specifies the preferred candidate and either `input_text` (if tuning a text model) or `messages` (if tuning a chat model). Chat datasets must contain at least 1 message in a `messages` field. Each message must be valid JSON that contains `author` and `content` fields, where valid `author` values are `user` and `assistant` and `content` must be non-empty. Each row may contain multiple messages, but the first and last author must be the `user`. An optional `context` field may be provided for each example in a chat dataset. If provided, the `context` will preprended to the message `content`. The `instruction` serves as the default context. (Useful if most messages use the same system-level context.) Any context provided in the example will override the default value.
large_model_reference: Name of the base model. Supported values are `text-bison@001`, `t5-small`, `t5-large`, `t5-xl` and `t5-xxl`. `text-bison@001` and `t5-small` are supported in `us-central1` and `europe-west4`. `t5-large`, `t5-xl` and `t5-xxl` are only supported in `europe-west4`.
reward_model_reference: Name of the base model. The name should be in capitalized snake case format.
reward_model_path: The model checkpoint path for the reward model.
prompt_sequence_length: Maximum tokenized sequence length for input text. Higher values increase memory overhead. This value should be at most 8192. Default value is 512.
target_sequence_length: Maximum tokenized sequence length for target text. Higher values increase memory overhead. This value should be at most 1024. Default value is 64.
batch_size: Number of examples in each finetuning step. Default is 64.
Expand All @@ -73,7 +76,6 @@ def pipeline(
encryption_spec_key_name: Customer-managed encryption key. If this is set, then all resources created by the CustomJob will be encrypted with the provided encryption key. Note that this is not supported for TPU at the moment.
Returns:
reward_model_base_path: Path to the base model used by the reward model.
reward_model_adapter_path: Path to the output LoRA adapter.
reward_dataset_path: Preference dataset use for tuning the reward model.
"""
Expand All @@ -86,10 +88,6 @@ def pipeline(
use_test_spec=env.get_use_test_machine_spec(),
).set_display_name('Resolve Machine Spec')

reference_model_metadata = function_based.resolve_reference_model_metadata(
large_model_reference=large_model_reference,
).set_display_name('Resolve Model Metadata')

processed_preference_dataset = (
preprocess_chat_dataset.preprocess_chat_dataset(
large_model_reference=large_model_reference,
Expand All @@ -113,9 +111,7 @@ def pipeline(
comma_separated_candidates_field_names=comma_separated_candidates_field_names.output,
choice_field_name=choice_column,
split=env.TRAIN_SPLIT,
large_model_reference=reference_model_metadata.outputs[
'reward_model_reference'
],
large_model_reference=reward_model_reference,
instruction=instruction,
encryption_spec_key_name=encryption_spec_key_name,
)
Expand All @@ -132,9 +128,7 @@ def pipeline(
comma_separated_candidates_field_names=comma_separated_candidates_field_names.output,
choice_field_name=choice_column,
split=env.TRAIN_SPLIT,
large_model_reference=reference_model_metadata.outputs[
'reward_model_reference'
],
large_model_reference=reward_model_reference,
instruction=instruction,
encryption_spec_key_name=encryption_spec_key_name,
)
Expand All @@ -146,17 +140,13 @@ def pipeline(
accelerator_type=machine_spec.outputs['accelerator_type'],
).set_display_name('Resolve Reward Model Image URI')
num_microbatches = function_based.resolve_num_microbatches(
large_model_reference=reference_model_metadata.outputs[
'reward_model_reference'
]
large_model_reference=reward_model_reference,
).set_display_name('Resolve Number of Microbatches')
reward_model = (
reward_model_trainer.reward_model_trainer(
project=project,
location=machine_spec.outputs['tuning_location'],
input_model_path=reference_model_metadata.outputs[
'reward_model_path'
],
input_model_path=reward_model_path,
input_dataset_path=preference_dataset_importer.outputs[
'output_dataset_path'
],
Expand All @@ -166,9 +156,7 @@ def pipeline(
train_steps=reward_model_train_steps,
accelerator_type=machine_spec.outputs['accelerator_type'],
accelerator_count=machine_spec.outputs['accelerator_count'],
large_model_reference=reference_model_metadata.outputs[
'reward_model_reference'
],
large_model_reference=reward_model_reference,
machine_type=machine_spec.outputs['machine_type'],
image_uri=reward_model_image_uri.output,
inputs_sequence_length=prompt_sequence_length,
Expand All @@ -185,9 +173,6 @@ def pipeline(
)

return PipelineOutput(
reward_model_base_path=reference_model_metadata.outputs[
'reward_model_path'
],
reward_model_adapter_path=reward_model.outputs['output_adapter_path'],
reward_dataset_path=preference_dataset_importer.outputs[
'output_dataset_path'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,37 @@

@dsl.container_component
def rlhf_preprocessor(
large_model_reference: str,
gcp_resources: dsl.OutputPath(str), # pytype: disable=invalid-annotation
has_tensorboard_id: dsl.OutputPath(bool), # pytype: disable=invalid-annotation
has_inference_dataset: dsl.OutputPath(bool), # pytype: disable=invalid-annotation
metadata_large_model_reference: dsl.OutputPath(str), # pytype: disable=invalid-annotation
metadata_reference_model_path: dsl.OutputPath(str), # pytype: disable=invalid-annotation
metadata_reward_model_reference: dsl.OutputPath(str), # pytype: disable=invalid-annotation
metadata_reward_model_path: dsl.OutputPath(str), # pytype: disable=invalid-annotation
evaluation_dataset: str = '',
tensorboard_resource_id: str = '',
input_reference_model_path: str = '',
image_uri: str = utils.get_default_image_uri('refined_cpu', ''),
) -> dsl.ContainerSpec: # pylint: disable=g-doc-args
# fmt: off
"""Preprocess RLHF pipeline inputs.
Args:
large_model_reference: The model for fine tuning.
evaluation_dataset: Path to evaluation data.
tensorboard_resource_id: TensorBoard resource id.
metadata_large_model_reference: The base model for fine tuning. The name should be in capitalized snake case format.
metadata_reference_model_path: The model checkpoint path for the reinforcer model
metadata_reward_model_reference: The base model for training reward model. The name should be in capitalized snake case format.
metadata_reward_model_path: The model checkpoint path for the reward model.
Returns:
gcp_resources: GCP resources that can be used to track the custom job.
has_tensorboard_id: Whether a tensorboard id is provided.
has_inference_dataset: Whether inference data are provided.
"""
# fmt: on
return gcpc_utils.build_serverless_customjob_container_spec(
project=_placeholders.PROJECT_ID_PLACEHOLDER,
location=_placeholders.LOCATION_PLACEHOLDER,
Expand All @@ -52,8 +65,14 @@ def rlhf_preprocessor(
'--app_name=rlhf_preprocessor',
f'--evaluation_dataset={evaluation_dataset}',
f'--tensorboard_resource_id={tensorboard_resource_id}',
f'--large_model_reference={large_model_reference}',
f'--input_reference_model_path={input_reference_model_path}',
f'--has_tensorboard_id_path={has_tensorboard_id}',
f'--has_inference_dataset_path={has_inference_dataset}',
f'--metadata_large_model_reference_path={metadata_large_model_reference}',
f'--metadata_reference_model_path_path={metadata_reference_model_path}',
f'--metadata_reward_model_reference_path={metadata_reward_model_reference}',
f'--metadata_reward_model_path_path={metadata_reward_model_path}',
],
),
gcp_resources=gcp_resources,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google_cloud_pipeline_components._implementation.llm import function_based
from google_cloud_pipeline_components._implementation.llm import reinforcement_learning_graph
from google_cloud_pipeline_components._implementation.llm import reward_model_graph
from google_cloud_pipeline_components._implementation.llm import rlhf_preprocessor
from google_cloud_pipeline_components._implementation.llm import validate_pipeline
from google_cloud_pipeline_components.preview.llm.infer import component
import kfp
Expand Down Expand Up @@ -94,11 +95,23 @@ def rlhf_pipeline(
eval_dataset=eval_dataset,
).set_display_name('Validate Inputs')

preprocess_metadata = rlhf_preprocessor.rlhf_preprocessor(
large_model_reference=large_model_reference,
evaluation_dataset=eval_dataset,
tensorboard_resource_id=tensorboard_resource_id,
).set_display_name('Preprocess Inputs')

reward_model_pipeline = (
(
reward_model_graph.pipeline(
preference_dataset=preference_dataset,
large_model_reference=large_model_reference,
reward_model_reference=preprocess_metadata.outputs[
'metadata_reward_model_reference'
],
reward_model_path=preprocess_metadata.outputs[
'metadata_reward_model_path'
],
prompt_sequence_length=prompt_sequence_length,
target_sequence_length=target_sequence_length,
eval_dataset=validate_pipeline_task.outputs[
Expand All @@ -120,8 +133,8 @@ def rlhf_pipeline(
)
rl_model_pipeline = reinforcement_learning_graph.pipeline(
prompt_dataset=prompt_dataset,
input_reward_model_path=reward_model_pipeline.outputs[
'reward_model_base_path'
input_reward_model_path=preprocess_metadata.outputs[
'metadata_reward_model_path'
],
input_reward_adapter_path=reward_model_pipeline.outputs[
'reward_model_adapter_path'
Expand All @@ -130,6 +143,15 @@ def rlhf_pipeline(
'reward_dataset_path'
],
large_model_reference=large_model_reference,
reward_model_reference=preprocess_metadata.outputs[
'metadata_reward_model_reference'
],
policy_model_reference=preprocess_metadata.outputs[
'metadata_large_model_reference'
],
policy_model_path=preprocess_metadata.outputs[
'metadata_reference_model_path'
],
prompt_sequence_length=prompt_sequence_length,
target_sequence_length=target_sequence_length,
reinforcement_learning_rate_multiplier=reinforcement_learning_rate_multiplier,
Expand Down Expand Up @@ -174,6 +196,9 @@ def rlhf_pipeline(
llm_model_handler = deployment_graph.pipeline(
output_adapter_path=rl_model_pipeline.outputs['output_adapter_path'],
large_model_reference=large_model_reference,
policy_model_reference=preprocess_metadata.outputs[
'metadata_large_model_reference'
],
model_display_name=model_display_name,
deploy_model=deploy_model,
encryption_spec_key_name=encryption_spec_key_name,
Expand Down

0 comments on commit 6311354

Please sign in to comment.