chore: LLM - Refactored the Distillation feature to support new distillation pipeline

* Extracted the `_TuningJob` base class from the `_LanguageModelTuningJob` class. This allows creating jobs that return the tuned model as something other than `LanguageModel`.
* Extracted the `_tuning.submit_distillation_pipeline_job` function from the `DistillationMixin` class. This allows submitting distillation jobs for models that cannot be wrapped in a `LanguageModel` class. (A minimal usage sketch follows this list.)
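
Taken together, the two extractions decouple pipeline submission from the `LanguageModel` wrapper. A minimal sketch of the resulting call pattern, with an illustrative project, illustrative model IDs, and a placeholder Cloud Storage URI (none of these values come from the commit itself):

import vertexai
from vertexai.language_models import _distillation

vertexai.init(project="my-project", location="us-central1")  # placeholder project

# Submits the distillation pipeline directly and returns the underlying
# aiplatform.PipelineJob, with no LanguageModel wrapper involved.
pipeline_job = _distillation.submit_distillation_pipeline_job(
    teacher_model="text-bison@002",  # illustrative teacher model ID
    student_model="text-bison-001",  # illustrative student model ID
    dataset="gs://my-bucket/distillation_train_data.jsonl",
    train_steps=200,
)
pipeline_job.wait()  # block until the pipeline run completes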

PiperOrigin-RevId: 629232787
Ark-kun authored and Copybara-Service committed Apr 30, 2024
1 parent aab9c3e commit 3ce0126
Showing 2 changed files with 126 additions and 66 deletions.
133 changes: 81 additions & 52 deletions vertexai/language_models/_distillation.py
@@ -6,11 +6,12 @@
from vertexai.language_models import _language_models as tuning


class DistillationMixin:
    _DISTILLATION_PIPELINE_URI = (
        "https://us-kfp.pkg.dev/ml-pipeline/distillation/distillation/v1.0.0"
    )
_DISTILLATION_PIPELINE_URI = (
    "https://us-kfp.pkg.dev/ml-pipeline/distillation/distillation/v1.0.0"
)


class DistillationMixin:
    def distill_from(
        self,
        *,
@@ -59,57 +60,85 @@ def distill_from(
        else:
            raise RuntimeError(f"Unsupported teacher model type: {teacher_model}")

        pipeline_arguments = {
            "teacher_model_reference": teacher_short_model_id,
            "student_model_reference": student_short_model_id,
            "dataset_uri": dataset,
            "project": aiplatform_initializer.global_config.project,
            "location": aiplatform_initializer.global_config.location,
        }
        if train_steps is not None:
            pipeline_arguments["train_steps"] = train_steps
        if learning_rate_multiplier is not None:
            pipeline_arguments["learning_rate_multiplier"] = learning_rate_multiplier
        if evaluation_spec is not None:
            pipeline_arguments["evaluation_data_uri"] = evaluation_spec.evaluation_data
            pipeline_arguments[
                "evaluation_interval"
            ] = evaluation_spec.evaluation_interval
            pipeline_arguments[
                "enable_early_stopping"
            ] = evaluation_spec.enable_early_stopping
            pipeline_arguments[
                "enable_checkpoint_selection"
            ] = evaluation_spec.enable_checkpoint_selection
            pipeline_arguments["tensorboard_resource_id"] = evaluation_spec.tensorboard
            # pipeline_parameter_values["evaluation_output_root_dir"] = ...
        if accelerator_type is not None:
            pipeline_arguments["accelerator_type"] = accelerator_type
        if aiplatform_initializer.global_config.encryption_spec_key_name is not None:
            pipeline_arguments[
                "encryption_spec_key_name"
            ] = aiplatform_initializer.global_config.encryption_spec_key_name
        if max_context_length is not None:
            pipeline_arguments["max_context_length"] = max_context_length
        if model_display_name is None:
            model_display_name = (
                f"{student_short_model_id}"
                f" distilled from {teacher_short_model_id}"
            )
        pipeline_arguments["model_display_name"] = model_display_name
        # # Not exposing these parameters:
        # temperature: Optional[float] = None,
        # tpu_training_skip_cmek: Optional[bool] = None,
        # api_endpoint: Optional[str] = None,
        # version: Optional[str] = None,
        pipeline_job = aiplatform.PipelineJob(
            template_path=self._DISTILLATION_PIPELINE_URI,
            display_name=None,
            parameter_values=pipeline_arguments,
        pipeline_job = submit_distillation_pipeline_job(
            teacher_model=teacher_short_model_id,
            student_model=student_short_model_id,
            dataset=dataset,
            train_steps=train_steps,
            learning_rate_multiplier=learning_rate_multiplier,
            evaluation_spec=evaluation_spec,
            accelerator_type=accelerator_type,
            model_display_name=model_display_name,
            max_context_length=max_context_length,
        )
        pipeline_job.submit()
        tuning_job = tuning._LanguageModelTuningJob(
            base_model=self,
            job=pipeline_job,
        )
        return tuning_job


def submit_distillation_pipeline_job(
    *,
    teacher_model: str,
    student_model: str,
    dataset: str,
    train_steps: Optional[int] = None,
    learning_rate_multiplier: Optional[float] = None,
    evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
    accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
    model_display_name: Optional[str] = None,
    max_context_length: Optional[str] = None,
):
    teacher_short_model_id = teacher_model.split("/")[-1]
    student_short_model_id = student_model.split("/")[-1]
    pipeline_arguments = {
        "teacher_model_reference": teacher_model,
        "student_model_reference": student_model,
        "dataset_uri": dataset,
        "project": aiplatform_initializer.global_config.project,
        "location": aiplatform_initializer.global_config.location,
    }
    if train_steps is not None:
        pipeline_arguments["train_steps"] = train_steps
    if learning_rate_multiplier is not None:
        pipeline_arguments["learning_rate_multiplier"] = learning_rate_multiplier
    if evaluation_spec is not None:
        pipeline_arguments["evaluation_data_uri"] = evaluation_spec.evaluation_data
        pipeline_arguments[
            "evaluation_interval"
        ] = evaluation_spec.evaluation_interval
        pipeline_arguments[
            "enable_early_stopping"
        ] = evaluation_spec.enable_early_stopping
        pipeline_arguments[
            "enable_checkpoint_selection"
        ] = evaluation_spec.enable_checkpoint_selection
        pipeline_arguments["tensorboard_resource_id"] = evaluation_spec.tensorboard
        # pipeline_parameter_values["evaluation_output_root_dir"] = ...
    if accelerator_type is not None:
        pipeline_arguments["accelerator_type"] = accelerator_type
    if aiplatform_initializer.global_config.encryption_spec_key_name is not None:
        pipeline_arguments[
            "encryption_spec_key_name"
        ] = aiplatform_initializer.global_config.encryption_spec_key_name
    if max_context_length is not None:
        pipeline_arguments["max_context_length"] = max_context_length
    if model_display_name is None:
        model_display_name = (
            f"{student_short_model_id}"
            f" distilled from {teacher_short_model_id}"
        )
    pipeline_arguments["model_display_name"] = model_display_name
    # # Not exposing these parameters:
    # temperature: Optional[float] = None,
    # tpu_training_skip_cmek: Optional[bool] = None,
    # api_endpoint: Optional[str] = None,
    # version: Optional[str] = None,
    pipeline_job = aiplatform.PipelineJob(
        template_path=_DISTILLATION_PIPELINE_URI,
        display_name=None,
        parameter_values=pipeline_arguments,
    )
    pipeline_job.submit()
    return pipeline_job
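
As a usage note, evaluation settings flow through the same `TuningEvaluationSpec` dataclass that `distill_from` already accepts. A hedged sketch with placeholder URIs and an assumed interval value (not taken from the commit):

from vertexai.language_models import _distillation
from vertexai.language_models import _language_models as tuning

# The spec fields map onto the pipeline_arguments keys above
# (evaluation_data_uri, evaluation_interval, enable_early_stopping).
eval_spec = tuning.TuningEvaluationSpec(
    evaluation_data="gs://my-bucket/eval_data.jsonl",  # placeholder URI
    evaluation_interval=20,
    enable_early_stopping=True,
)
job = _distillation.submit_distillation_pipeline_job(
    teacher_model="text-bison@002",  # illustrative model IDs
    student_model="text-bison-001",
    dataset="gs://my-bucket/train_data.jsonl",
    evaluation_spec=eval_spec,
)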
59 changes: 45 additions & 14 deletions vertexai/language_models/_language_models.py
@@ -3491,33 +3491,41 @@ def _get_invalid_rlhf_model_msg(
    )


class _LanguageModelTuningJob:
    """LanguageModelTuningJob represents a fine-tuning job."""
class _TuningJob:
    """TuningJob represents a fine-tuning job."""

    def __init__(
        self,
        base_model: _LanguageModel,
        job: aiplatform.PipelineJob,
    ):
        """Internal constructor. Do not call directly."""
        self._base_model = base_model
        self._job = job
        self._model: Optional[_LanguageModel] = None
        self._tuned_model_name: Optional[str] = None

    def get_tuned_model(self) -> "_LanguageModel":
        """Blocks until the tuning is complete and returns a `LanguageModel` object."""
        if self._model:
            return self._model
    def _get_tuned_model_name(self) -> str:
        """Extracts the tuned model name from the tuning pipeline job.

        This method is used for tuning, RLHF, and distillation.

        Returns:
            The Vertex Model resource name of the tuned model.
        """
        if self._tuned_model_name:
            return self._tuned_model_name
        self._job.wait()

        # Getting tuned model from the pipeline.
        model_task = None
        # Searching for the model uploading task first.
        # Note: Distillation does not have pipeline outputs yet.
        upload_model_task_names = [
            "upload-llm-model",  # Most tuning pipelines
            "upload-model",  # New distillation pipeline uses "upload-model"
        ]
        upload_model_tasks = [
            task_info
            for task_info in self._job.gca_resource.job_detail.task_details
            if task_info.task_name == "upload-llm-model"
            if task_info.task_name in upload_model_task_names
        ]
        if len(upload_model_tasks) == 1:
            model_task = upload_model_tasks[0]
@@ -3539,10 +3547,8 @@ def get_tuned_model(self) -> "_LanguageModel":
"output:model_resource_name"
].strip()
_LOGGER.info(f"Tuning has completed. Created Vertex Model: {vertex_model_name}")
self._model = type(self._base_model).get_tuned_model(
tuned_model_name=vertex_model_name
)
return self._model
self._tuned_model_name = vertex_model_name
return vertex_model_name

    @property
    def _status(self) -> Optional[aiplatform_types.pipeline_state.PipelineState]:
@@ -3554,6 +3560,31 @@ def _cancel(self):
        self._job.cancel()


class _LanguageModelTuningJob(_TuningJob):
    """LanguageModelTuningJob represents a fine-tuning job."""

    def __init__(
        self,
        base_model: _LanguageModel,
        job: aiplatform.PipelineJob,
    ):
        """Internal constructor. Do not call directly."""
        super().__init__(job=job)
        self._base_model = base_model
        self._model: Optional[_LanguageModel] = None

    def get_tuned_model(self) -> "_LanguageModel":
        """Blocks until the tuning is complete and returns a `LanguageModel` object."""
        if self._model:
            return self._model
        vertex_model_name = self._get_tuned_model_name()
        _LOGGER.info(f"Tuning has completed. Created Vertex Model: {vertex_model_name}")
        self._model = type(self._base_model).get_tuned_model(
            tuned_model_name=vertex_model_name
        )
        return self._model


def _get_tuned_models_dir_uri(model_id: str) -> str:
    if aiplatform_initializer.global_config.staging_bucket:
        return (
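The `_TuningJob` extraction is what the first bullet of the commit message refers to: a subclass can now resolve the tuned model as something other than a `LanguageModel`. A hypothetical sketch of such a subclass (not part of this commit):

from google.cloud import aiplatform

class _ModelResourceTuningJob(_TuningJob):
    """Hypothetical job that yields the raw Vertex Model resource."""

    def get_tuned_model(self) -> aiplatform.Model:
        # _get_tuned_model_name() blocks until the pipeline finishes and
        # returns the Vertex Model resource name it produced.
        return aiplatform.Model(model_name=self._get_tuned_model_name())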
