Skip to content

Commit

Permalink
feat: GenAI - Tuning - Supervised - Added support for the `adapter_si…
Browse files Browse the repository at this point in the history
…ze` parameter

PiperOrigin-RevId: 631251312
  • Loading branch information
Ark-kun authored and Copybara-Service committed May 7, 2024
1 parent 20b1866 commit 88188d2
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion vertexai/tuning/_supervised_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
#

from typing import Optional, Union
from typing import Literal, Optional, Union

from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types

Expand All @@ -29,6 +29,7 @@ def train(
tuned_model_display_name: Optional[str] = None,
epochs: Optional[int] = None,
learning_rate_multiplier: Optional[float] = None,
adapter_size: Optional[Literal[1, 4, 8, 16]] = None,
) -> "SupervisedTuningJob":
"""Tunes a model using supervised training.
Expand All @@ -44,6 +45,7 @@ def train(
be up to 128 characters long and can consist of any UTF-8 characters.
epochs: Number of training epoches for this tuning job.
learning_rate_multiplier: Learning rate multiplier for tuning.
adapter_size: Adapter size for tuning.
Returns:
A `TuningJob` object.
Expand All @@ -54,6 +56,7 @@ def train(
hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters(
epoch_count=epochs,
learning_rate_multiplier=learning_rate_multiplier,
adapter_size=adapter_size,
),
)

Expand Down

0 comments on commit 88188d2

Please sign in to comment.