Skip to content

Commit

Permalink
fix: ensure model starts with publishers/ when users provide resource…
Browse files Browse the repository at this point in the history
… path from models/

PiperOrigin-RevId: 640914707
  • Loading branch information
Zhenyi Qi authored and Copybara-Service committed Jun 6, 2024
1 parent bd4c09c commit d689331
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 4 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,24 +422,28 @@ def test_generative_model_constructor_model_name(
model1._prediction_resource_name
== project_location_prefix + "publishers/google/models/" + model_name1
)
assert model1._model_name == "publishers/google/models/gemini-pro"

model_name2 = "models/gemini-pro"
model2 = generative_models.GenerativeModel(model_name2)
assert (
model2._prediction_resource_name
== project_location_prefix + "publishers/google/" + model_name2
)
assert model2._model_name == "publishers/google/models/gemini-pro"

model_name3 = "publishers/some_publisher/models/some_model"
model3 = generative_models.GenerativeModel(model_name3)
assert model3._prediction_resource_name == project_location_prefix + model_name3
assert model3._model_name == "publishers/some_publisher/models/some_model"

model_name4 = (
f"projects/{_TEST_PROJECT2}/locations/{_TEST_LOCATION2}/endpoints/endpoint1"
)
model4 = generative_models.GenerativeModel(model_name4)
assert model4._prediction_resource_name == model_name4
assert _TEST_LOCATION2 in model4._prediction_client._api_endpoint
assert model4._model_name == model_name4

with pytest.raises(ValueError):
generative_models.GenerativeModel("foo/bar/models/gemini-pro")
Expand Down
2 changes: 1 addition & 1 deletion vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _reconcile_model_name(model_name: str, project: str, location: str) -> str:
if "/" not in model_name:
return f"publishers/google/models/{model_name}"
elif model_name.startswith("models/"):
return f"projects/{project}/locations/{location}/publishers/google/{model_name}"
return f"publishers/google/{model_name}"
elif model_name.startswith("publishers/") or model_name.startswith("projects/"):
return model_name
else:
Expand Down

0 comments on commit d689331

Please sign in to comment.