Skip to content

Commit

Permalink
feat: GenAI - Tuning - Added support for CMEK
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642696819
  • Loading branch information
Ark-kun authored and Copybara-Service committed Jun 12, 2024
1 parent e832a8a commit eb651bc
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/unit/vertexai/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

import copy
import datetime
import importlib
from typing import Dict, Iterable
from unittest import mock
import uuid

from google.cloud import aiplatform
import vertexai
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import initializer
Expand Down Expand Up @@ -150,6 +152,10 @@ class TestgenerativeModelTuning:
"""Unit tests for generative model tuning."""

def setup_method(self):
importlib.reload(initializer)
importlib.reload(aiplatform)
importlib.reload(vertexai)

vertexai.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
Expand Down Expand Up @@ -197,3 +203,18 @@ def test_genai_tuning_service_supervised_tuning_tune_model(self):
assert sft_tuning_job._experiment_name
assert sft_tuning_job.tuned_model_name
assert sft_tuning_job.tuned_model_endpoint_name

@mock.patch.object(
target=tuning.TuningJob,
attribute="client_class",
new=MockTuningJobClientWithOverride,
)
def test_genai_tuning_service_encryption_spec(self):
"""Test that the global encryption spec propagates to the tuning job."""
vertexai.init(encryption_spec_key_name="test-key")

sft_tuning_job = supervised_tuning.train(
source_model="gemini-1.0-pro-001",
train_dataset="gs://some-bucket/some_dataset.jsonl",
)
assert sft_tuning_job.encryption_spec.kms_key_name == "test-key"
5 changes: 5 additions & 0 deletions vertexai/tuning/_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ def _create(
else:
raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")

if aiplatform_initializer.global_config.encryption_spec_key_name:
gca_tuning_job.encryption_spec.kms_key_name = (
aiplatform_initializer.global_config.encryption_spec_key_name
)

tuning_job: TuningJob = cls._construct_sdk_resource_from_gapic(
gapic_resource=gca_tuning_job,
project=project,
Expand Down

0 comments on commit eb651bc

Please sign in to comment.