Skip to content

Commit

Permalink
fix: Fixed the vertexai.init partial initialization issues
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625560910
  • Loading branch information
Ark-kun authored and Copybara-Service committed Apr 17, 2024
1 parent 0654c35 commit 636a654
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,23 +192,22 @@ def init(
ValueError:
If experiment_description is provided but experiment is not.
"""

if api_endpoint is not None:
self._api_endpoint = api_endpoint

# This method mutates state, so we need to be careful with the validation
# First, we need to validate all passed values
if api_transport:
VALID_TRANSPORT_TYPES = ["grpc", "rest"]
if api_transport not in VALID_TRANSPORT_TYPES:
raise ValueError(
f"{api_transport} is not a valid transport type. "
+ f"Valid transport types: {VALID_TRANSPORT_TYPES}"
)
if location:
utils.validate_region(location)
if experiment_description and experiment is None:
raise ValueError(
"Experiment needs to be set in `init` in order to add experiment descriptions."
)

if experiment_tensorboard and not isinstance(experiment_tensorboard, bool):
metadata._experiment_tracker.set_tensorboard(
tensorboard=experiment_tensorboard,
project=project,
location=location,
credentials=credentials,
)

# reset metadata_service config if project or location is updated.
if (project and project != self._project) or (
location and location != self._location
Expand All @@ -217,10 +216,14 @@ def init(
logging.info("project/location updated, reset Experiment config.")
metadata._experiment_tracker.reset()

# Then we change the main state
if api_endpoint is not None:
self._api_endpoint = api_endpoint
if api_transport:
self._api_transport = api_transport
if project:
self._project = project
if location:
utils.validate_region(location)
self._location = location
if staging_bucket:
self._staging_bucket = staging_bucket
Expand All @@ -233,22 +236,22 @@ def init(
if service_account is not None:
self._service_account = service_account

# Finally, perform secondary state updates
if experiment_tensorboard and not isinstance(experiment_tensorboard, bool):
metadata._experiment_tracker.set_tensorboard(
tensorboard=experiment_tensorboard,
project=project,
location=location,
credentials=credentials,
)

if experiment:
metadata._experiment_tracker.set_experiment(
experiment=experiment,
description=experiment_description,
backing_tensorboard=experiment_tensorboard,
)

if api_transport:
VALID_TRANSPORT_TYPES = ["grpc", "rest"]
if api_transport not in VALID_TRANSPORT_TYPES:
raise ValueError(
f"{api_transport} is not a valid transport type. "
+ f"Valid transport types: {VALID_TRANSPORT_TYPES}"
)
self._api_transport = api_transport

def get_encryption_spec(
self,
encryption_spec_key_name: Optional[str],
Expand Down

0 comments on commit 636a654

Please sign in to comment.