diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java index 151bf81e242e..1b8d444659a0 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java @@ -36,6 +36,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.Optional; /** This class holds a generative model that can complete what you provided. */ public final class GenerativeModel { @@ -45,6 +46,7 @@ public final class GenerativeModel { private final GenerationConfig generationConfig; private final ImmutableList safetySettings; private final ImmutableList tools; + private final Optional systemInstruction; /** * Constructs a GenerativeModel instance. @@ -53,7 +55,7 @@ public final class GenerativeModel { * "models/gemini-pro", "publishers/google/models/gemini-pro", where "gemini-pro" is the model * name. Valid model names can be found at * https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models - * @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs + * @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs * for the generative model */ public GenerativeModel(String modelName, VertexAI vertexAi) { @@ -62,6 +64,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) { GenerationConfig.getDefaultInstance(), ImmutableList.of(), ImmutableList.of(), + Optional.empty(), vertexAi); } @@ -76,7 +79,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) { * that will be used by default for generating response * @param tools a list of {@link com.google.cloud.vertexai.api.Tool} instances that can be used by * the model as auxiliary tools to generate content. - * @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs + * @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs * for the generative model */ private GenerativeModel( @@ -84,6 +87,7 @@ private GenerativeModel( GenerationConfig generationConfig, ImmutableList safetySettings, ImmutableList tools, + Optional systemInstruction, VertexAI vertexAi) { checkArgument( !Strings.isNullOrEmpty(modelName), @@ -105,6 +109,7 @@ private GenerativeModel( this.generationConfig = generationConfig; this.safetySettings = safetySettings; this.tools = tools; + this.systemInstruction = systemInstruction; } /** Builder class for {@link GenerativeModel}. */ @@ -114,20 +119,22 @@ public static class Builder { private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance(); private ImmutableList safetySettings = ImmutableList.of(); private ImmutableList tools = ImmutableList.of(); + private Optional systemInstruction = Optional.empty(); public GenerativeModel build() { checkArgument( !Strings.isNullOrEmpty(modelName), "modelName is required. Please call setModelName() before building."); checkNotNull(vertexAi, "vertexAi is required. Please call setVertexAi() before building."); - return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi); + return new GenerativeModel( + modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi); } /** * Sets the name of the generative model. This is required for building a GenerativeModel * instance. Supported format: "gemini-pro", "models/gemini-pro", * "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model - * names can be found at + * names can be found in the Gemini models documentation * https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models */ @CanIgnoreReturnValue @@ -187,6 +194,19 @@ public Builder setTools(List tools) { this.tools = ImmutableList.copyOf(tools); return this; } + + /** + * Sets a system instruction that will be used by default to interact with the generative model. + */ + @CanIgnoreReturnValue + public Builder setSystemInstruction(Content systemInstruction) { + checkNotNull( + systemInstruction, + "system instruction can't be null. " + + "Use Optional.empty() if no system instruction should be provided."); + this.systemInstruction = Optional.of(systemInstruction); + return this; + } } /** @@ -197,7 +217,13 @@ public Builder setTools(List tools) { * @return a new {@link GenerativeModel} instance with the specified GenerationConfig. */ public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) { - return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi); + return new GenerativeModel( + modelName, + generationConfig, + ImmutableList.copyOf(safetySettings), + ImmutableList.copyOf(tools), + systemInstruction, + vertexAi); } /** @@ -209,19 +235,46 @@ public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) { */ public GenerativeModel withSafetySettings(List safetySettings) { return new GenerativeModel( - modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi); + modelName, + generationConfig, + ImmutableList.copyOf(safetySettings), + ImmutableList.copyOf(tools), + systemInstruction, + vertexAi); } /** * Creates a copy of the current model with updated tools. * - * @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in - * the new model. + * @param tools a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in the new + * model. * @return a new {@link GenerativeModel} instance with the specified tools. */ public GenerativeModel withTools(List tools) { return new GenerativeModel( - modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi); + modelName, + generationConfig, + ImmutableList.copyOf(safetySettings), + ImmutableList.copyOf(tools), + systemInstruction, + vertexAi); + } + + /** + * Creates a copy of the current model with updated system instructions. + * + * @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} containing system + * instructions. + * @return a new {@link GenerativeModel} instance with the specified tools. + */ + public GenerativeModel withSystemInstruction(Content systemInstruction) { + return new GenerativeModel( + modelName, + generationConfig, + ImmutableList.copyOf(safetySettings), + ImmutableList.copyOf(tools), + Optional.of(systemInstruction), + vertexAi); } /** @@ -453,13 +506,20 @@ private ApiFuture generateContentAsync(GenerateContentR */ private GenerateContentRequest buildGenerateContentRequest(List contents) { checkArgument(contents != null && !contents.isEmpty(), "contents can't be null or empty."); - return GenerateContentRequest.newBuilder() - .setModel(resourceName) - .addAllContents(contents) - .setGenerationConfig(generationConfig) - .addAllSafetySettings(safetySettings) - .addAllTools(tools) - .build(); + + GenerateContentRequest.Builder requestBuilder = + GenerateContentRequest.newBuilder() + .setModel(resourceName) + .addAllContents(contents) + .setGenerationConfig(generationConfig) + .addAllSafetySettings(safetySettings) + .addAllTools(tools); + + if (systemInstruction.isPresent()) { + requestBuilder.setSystemInstruction(systemInstruction.get()); + } + + return requestBuilder.build(); } /** Returns the model name of this generative model. */ @@ -475,8 +535,7 @@ public GenerationConfig getGenerationConfig() { } /** - * Returns a list of {@link com.google.cloud.vertexai.api.SafetySettings} of this generative - * model. + * Returns a list of {@link com.google.cloud.vertexai.api.SafetySetting} of this generative model. */ public ImmutableList getSafetySettings() { return safetySettings; @@ -487,6 +546,11 @@ public ImmutableList getTools() { return tools; } + /** Returns the optional system instruction of this generative model. */ + public Optional getSystemInstruction() { + return systemInstruction; + } + public ChatSession startChat() { return new ChatSession(this); } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java index 06dc96ad9fa9..175ed2e7d3ab 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java @@ -324,6 +324,28 @@ public void testGenerateContentwithContents() throws Exception { assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT); } + @Test + public void testGenerateContentwithSystemInstructions() throws Exception { + String systemInstructionText = + "You're a helpful assistant that starts all its answers with: \"COOL\""; + Content systemInstructions = ContentMaker.fromString(systemInstructionText); + + model = new GenerativeModel(MODEL_NAME, vertexAi).withSystemInstruction(systemInstructions); + + when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); + when(mockUnaryCallable.call(any(GenerateContentRequest.class))) + .thenReturn(mockGenerateContentResponse); + + Content content = ContentMaker.fromString(TEXT); + GenerateContentResponse unused = model.generateContent(Arrays.asList(content)); + + ArgumentCaptor request = + ArgumentCaptor.forClass(GenerateContentRequest.class); + verify(mockUnaryCallable).call(request.capture()); + assertThat(request.getValue().getSystemInstruction().getParts(0).getText()) + .isEqualTo(systemInstructionText); + } + @Test public void testGenerateContentwithDefaultGenerationConfig() throws Exception { model = diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java index 805f228b53c8..960037c77a1b 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java @@ -58,6 +58,7 @@ public class ITGenerativeModelIntegrationTest { // Tested content private static final String TEXT = "What do you think about Google Pixel?"; + private static final String PIRATE_INSTRUCTION = "Speak like a pirate when answering questions."; private static final String IMAGE_INQUIRY = "Please describe this image: "; private static final String IMAGE_URL = "https://picsum.photos/id/1/200/300"; private static final String VIDEO_INQUIRY = "Please summarize this video: "; @@ -259,4 +260,15 @@ public void countTokens_withPlainText_returnsNonZeroTokens() throws IOException logger.info(String.format("Print number of tokens:\n%s", tokens)); assertThat(tokens.getTotalTokens()).isGreaterThan(0); } + + @Test + public void generateContent_withSystemInstruction() throws Exception { + GenerativeModel pirateModel = + textModel.withSystemInstruction(ContentMaker.fromString(PIRATE_INSTRUCTION)); + + // GenAI output is flaky so we always print out the response. + // For the same reason, we don't do assertions much. + GenerateContentResponse pirateResponse = pirateModel.generateContent(TEXT); + assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, pirateResponse); + } }