Skip to content

Commit

Permalink
feat: Use colab enterprise enviroment variables to infer project_id a…
Browse files Browse the repository at this point in the history
…nd region

PiperOrigin-RevId: 615076478
  • Loading branch information
matthew29tang authored and Copybara-Service committed Mar 12, 2024
1 parent e004e87 commit 5baf5f8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
6 changes: 4 additions & 2 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def _set_project_as_env_var_or_google_auth_default(self):
# See https://github.com/googleapis/google-auth-library-python/issues/924
# TODO: Remove when google.auth.default() learns the
# CLOUD_ML_PROJECT_ID env variable or Vertex AI starts setting GOOGLE_CLOUD_PROJECT env variable.
project_number = os.environ.get("CLOUD_ML_PROJECT_ID")
project_number = os.environ.get("GOOGLE_CLOUD_PROJECT") or os.environ.get(
"CLOUD_ML_PROJECT_ID"
)
if project_number:
if not self._credentials:
credentials, _ = google.auth.default()
Expand Down Expand Up @@ -312,7 +314,7 @@ def location(self) -> str:
if self._location:
return self._location

location = os.getenv("CLOUD_ML_REGION")
location = os.getenv("GOOGLE_CLOUD_REGION") or os.getenv("CLOUD_ML_REGION")
if location:
utils.validate_region(location)
return location
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/aiplatform/test_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,28 @@ def mock_get_project_id(project_number: str, **_):
):
assert initializer.global_config.project == _TEST_PROJECT

def test_infer_project_id_with_precedence(self):
lower_precedence_cloud_project_number = "456"
higher_precedence_cloud_project_number = "123"

def mock_get_project_id(project_number: str, **_):
assert project_number == higher_precedence_cloud_project_number
return _TEST_PROJECT

with mock.patch.object(
target=resource_manager_utils,
attribute="get_project_id",
new=mock_get_project_id,
), mock.patch.dict(
os.environ,
{
"GOOGLE_CLOUD_PROJECT": higher_precedence_cloud_project_number,
"CLOUD_ML_PROJECT_ID": lower_precedence_cloud_project_number,
},
clear=True,
):
assert initializer.global_config.project == _TEST_PROJECT

def test_init_location_sets_location(self):
initializer.global_config.init(location=_TEST_LOCATION)
assert initializer.global_config.location == _TEST_LOCATION
Expand Down

0 comments on commit 5baf5f8

Please sign in to comment.