Skip to content

Commit

Permalink
feat: [vertexai] Allow tuned model to be used. (#10825)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 632509605

Co-authored-by: Zhenyi Qi <zhenyiqi@google.com>
  • Loading branch information
copybara-service[bot] and ZhenyiQ committed May 14, 2024
1 parent 88ee863 commit 9081269
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@

/** A class that holds all constants for vertexai/generativeai. */
public final class Constants {
public static final String MODEL_NAME_PREFIX_PROJECTS = "projects/";
public static final String MODEL_NAME_PREFIX_PUBLISHERS = "publishers/";
public static final String MODEL_NAME_PREFIX_MODELS = "models/";
public static final ImmutableSet<String> MODEL_NAME_PREFIXES =
ImmutableSet.of("publishers/google/models/", "models/");
ImmutableSet.of(
MODEL_NAME_PREFIX_PROJECTS, MODEL_NAME_PREFIX_PUBLISHERS, MODEL_NAME_PREFIX_MODELS);

private Constants() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,9 @@ private GenerativeModel(
checkNotNull(safetySettings, "ImmutableList<SafetySettings> can't be null.");
checkNotNull(tools, "ImmutableList<Tool> can't be null.");

modelName = reconcileModelName(modelName);
this.modelName = modelName;
this.resourceName =
String.format(
"projects/%s/locations/%s/publishers/google/models/%s",
vertexAi.getProjectId(), vertexAi.getLocation(), modelName);
this.resourceName = getResourceName(modelName, vertexAi);
// reconcileModelName should be called after getResourceName.
this.modelName = reconcileModelName(modelName);
this.vertexAi = vertexAi;
this.generationConfig = generationConfig;
this.safetySettings = safetySettings;
Expand Down Expand Up @@ -157,7 +154,7 @@ public Builder setModelName(String modelName) {
+ " https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models"
+ " to find the right model name.");

this.modelName = reconcileModelName(modelName);
this.modelName = modelName;
return this;
}

Expand Down Expand Up @@ -584,10 +581,28 @@ public ChatSession startChat() {
private static String reconcileModelName(String modelName) {
for (String prefix : Constants.MODEL_NAME_PREFIXES) {
if (modelName.startsWith(prefix)) {
modelName = modelName.substring(prefix.length());
modelName = modelName.substring(modelName.lastIndexOf('/') + 1);
break;
}
}
return modelName;
}

/**
* Computes resourceName based on original modelName. Note: this should happen before the
* modelName is reconciled.
*/
private static String getResourceName(String modelName, VertexAI vertexAi) {
if (modelName.startsWith(Constants.MODEL_NAME_PREFIX_PROJECTS)) {
return modelName;
} else if (modelName.startsWith(Constants.MODEL_NAME_PREFIX_PUBLISHERS)) {
return String.format(
"projects/%s/locations/%s/%s",
vertexAi.getProjectId(), vertexAi.getLocation(), modelName);
} else {
return String.format(
"projects/%s/locations/%s/publishers/google/models/%s",
vertexAi.getProjectId(), vertexAi.getLocation(), reconcileModelName(modelName));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@ public final class GenerativeModelTest {
.setVertexAiSearch(
VertexAISearch.newBuilder()
.setDatastore(
String.format(
"projects/%s/locations/%s/collections/%s/dataStores/%s",
PROJECT, "global", "default_collection", "test_123")))
"projects/test_project/locations/global/collections/default_collection/dataStores/test_123"))
.setDisableAttribution(false))
.build();

Expand Down Expand Up @@ -164,6 +162,19 @@ public void testInstantiateGenerativeModel() {
assertThat(model.getTools()).isEmpty();
}

@Test
public void
testInstantiateGenerativeModel_withModelNameStartingFromProjects_modelNameIsCorrect() {
model =
new GenerativeModel(
"projects/test_project/locations/test_location/publishers/google/models/gemini-pro",
vertexAi);
assertThat(model.getModelName()).isEqualTo(MODEL_NAME);
assertThat(model.getGenerationConfig()).isEqualTo(GenerationConfig.getDefaultInstance());
assertThat(model.getSafetySettings()).isEmpty();
assertThat(model.getTools()).isEmpty();
}

@Test
public void testInstantiateGenerativeModelwithBuilder() {
model = new GenerativeModel.Builder().setModelName(MODEL_NAME).setVertexAi(vertexAi).build();
Expand Down Expand Up @@ -286,6 +297,32 @@ public void testGenerateContentwithText() throws Exception {
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable).call(request.capture());
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
assertThat(request.getValue().getModel())
.isEqualTo(
"projects/test_project/locations/test_location/publishers/google/models/gemini-pro");
}

@Test
public void testGenerateContentwithText_withFullModelName_requestHasCorrectResourceName()
throws Exception {
model =
new GenerativeModel(
"projects/another_project/locations/europe-west4/publishers/google/models/another_model",
vertexAi);

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockGenerateContentResponse);

GenerateContentResponse unused = model.generateContent(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable).call(request.capture());
assertThat(request.getValue().getModel())
.isEqualTo(
"projects/another_project/locations/europe-west4/publishers/google/models/another_model");
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
}

@Test
Expand Down

0 comments on commit 9081269

Please sign in to comment.