Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [vertexai] adding system instruction support #10775

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/** This class holds a generative model that can complete what you provided. */
public final class GenerativeModel {
Expand All @@ -45,6 +46,7 @@ public final class GenerativeModel {
private final GenerationConfig generationConfig;
private final ImmutableList<SafetySetting> safetySettings;
private final ImmutableList<Tool> tools;
private final Optional<Content> systemInstruction;

/**
* Constructs a GenerativeModel instance.
Expand All @@ -53,7 +55,7 @@ public final class GenerativeModel {
* "models/gemini-pro", "publishers/google/models/gemini-pro", where "gemini-pro" is the model
* name. Valid model names can be found at
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
*/
public GenerativeModel(String modelName, VertexAI vertexAi) {
Expand All @@ -62,6 +64,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
GenerationConfig.getDefaultInstance(),
ImmutableList.of(),
ImmutableList.of(),
Optional.empty(),
vertexAi);
}

Expand All @@ -76,14 +79,15 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
* that will be used by default for generating response
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} instances that can be used by
* the model as auxiliary tools to generate content.
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
*/
private GenerativeModel(
String modelName,
GenerationConfig generationConfig,
ImmutableList<SafetySetting> safetySettings,
ImmutableList<Tool> tools,
Optional<Content> systemInstruction,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a setSystemInstruction method to the GenerativeModel.Builder class, and also a getter method to GenerativeModel?

VertexAI vertexAi) {
checkArgument(
!Strings.isNullOrEmpty(modelName),
Expand All @@ -105,6 +109,7 @@ private GenerativeModel(
this.generationConfig = generationConfig;
this.safetySettings = safetySettings;
this.tools = tools;
this.systemInstruction = systemInstruction;
}

/** Builder class for {@link GenerativeModel}. */
Expand All @@ -114,20 +119,22 @@ public static class Builder {
private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance();
private ImmutableList<SafetySetting> safetySettings = ImmutableList.of();
private ImmutableList<Tool> tools = ImmutableList.of();
private Optional<Content> systemInstruction = Optional.empty();

public GenerativeModel build() {
checkArgument(
!Strings.isNullOrEmpty(modelName),
"modelName is required. Please call setModelName() before building.");
checkNotNull(vertexAi, "vertexAi is required. Please call setVertexAi() before building.");
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
return new GenerativeModel(
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
}

/**
* Sets the name of the generative model. This is required for building a GenerativeModel
* instance. Supported format: "gemini-pro", "models/gemini-pro",
* "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
* names can be found at
* names can be found in the Gemini models documentation
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
*/
@CanIgnoreReturnValue
Expand Down Expand Up @@ -187,6 +194,19 @@ public Builder setTools(List<Tool> tools) {
this.tools = ImmutableList.copyOf(tools);
return this;
}

/**
 * Sets a system instruction that will be used by default to interact with the generative model.
 *
 * <p>Calling this method is optional; a builder that never calls it keeps an empty system
 * instruction.
 *
 * @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} containing the system
 *     instruction; must not be null.
 * @return this builder.
 * @throws NullPointerException if {@code systemInstruction} is null
 */
@CanIgnoreReturnValue
public Builder setSystemInstruction(Content systemInstruction) {
  // The parameter is a Content, not an Optional, so the message points callers at the only way
  // to leave the instruction unset: not calling this setter at all.
  checkNotNull(
      systemInstruction,
      "system instruction can't be null. "
          + "Don't call setSystemInstruction() if no system instruction should be provided.");
  this.systemInstruction = Optional.of(systemInstruction);
  return this;
}
}

/**
Expand All @@ -197,7 +217,13 @@ public Builder setTools(List<Tool> tools) {
* @return a new {@link GenerativeModel} instance with the specified GenerationConfig.
*/
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
return new GenerativeModel(
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
}

/**
Expand All @@ -209,19 +235,46 @@ public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
*/
public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
return new GenerativeModel(
modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi);
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
}

/**
* Creates a copy of the current model with updated tools.
*
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in
* the new model.
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in the new
* model.
* @return a new {@link GenerativeModel} instance with the specified tools.
*/
public GenerativeModel withTools(List<Tool> tools) {
return new GenerativeModel(
modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi);
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
}

/**
 * Creates a copy of the current model with an updated system instruction.
 *
 * @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} containing system
 *     instructions; must not be null.
 * @return a new {@link GenerativeModel} instance with the specified system instruction.
 * @throws NullPointerException if {@code systemInstruction} is null
 */
public GenerativeModel withSystemInstruction(Content systemInstruction) {
  // Fail with a descriptive message, consistent with Builder#setSystemInstruction, instead of
  // the message-less NullPointerException that Optional.of() would raise.
  checkNotNull(systemInstruction, "system instruction can't be null.");
  return new GenerativeModel(
      modelName,
      generationConfig,
      ImmutableList.copyOf(safetySettings),
      ImmutableList.copyOf(tools),
      Optional.of(systemInstruction),
      vertexAi);
}

/**
Expand Down Expand Up @@ -453,13 +506,20 @@ private ApiFuture<GenerateContentResponse> generateContentAsync(GenerateContentR
*/
private GenerateContentRequest buildGenerateContentRequest(List<Content> contents) {
checkArgument(contents != null && !contents.isEmpty(), "contents can't be null or empty.");
return GenerateContentRequest.newBuilder()
.setModel(resourceName)
.addAllContents(contents)
.setGenerationConfig(generationConfig)
.addAllSafetySettings(safetySettings)
.addAllTools(tools)
.build();

GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder()
.setModel(resourceName)
.addAllContents(contents)
.setGenerationConfig(generationConfig)
.addAllSafetySettings(safetySettings)
.addAllTools(tools);

if (systemInstruction.isPresent()) {
requestBuilder.setSystemInstruction(systemInstruction.get());
}

return requestBuilder.build();
}

/** Returns the model name of this generative model. */
Expand All @@ -475,8 +535,7 @@ public GenerationConfig getGenerationConfig() {
}

/**
* Returns a list of {@link com.google.cloud.vertexai.api.SafetySettings} of this generative
* model.
* Returns a list of {@link com.google.cloud.vertexai.api.SafetySetting} of this generative model.
*/
public ImmutableList<SafetySetting> getSafetySettings() {
return safetySettings;
Expand All @@ -487,6 +546,11 @@ public ImmutableList<Tool> getTools() {
return tools;
}

/**
 * Returns the system instruction of this generative model wrapped in an {@link Optional}, which
 * is empty when no system instruction has been set.
 */
public Optional<Content> getSystemInstruction() {
  return this.systemInstruction;
}

/** Creates a new {@link ChatSession} constructed from this generative model. */
public ChatSession startChat() {
  ChatSession chatSession = new ChatSession(this);
  return chatSession;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,28 @@ public void testGenerateContentwithContents() throws Exception {
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
}

@Test
public void testGenerateContentwithSystemInstructions() throws Exception {
  // Derive a model from the default one, carrying the system instruction under test.
  String systemInstructionText =
      "You're a helpful assistant that starts all its answers with: \"COOL\"";
  Content instructionContent = ContentMaker.fromString(systemInstructionText);
  GenerativeModel baseModel = new GenerativeModel(MODEL_NAME, vertexAi);
  model = baseModel.withSystemInstruction(instructionContent);

  // Stub the mocked client and callable so the call returns the canned response.
  when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
  when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
      .thenReturn(mockGenerateContentResponse);

  Content userContent = ContentMaker.fromString(TEXT);
  GenerateContentResponse unused = model.generateContent(Arrays.asList(userContent));

  // Capture the request handed to the callable and check the system instruction text on it.
  ArgumentCaptor<GenerateContentRequest> requestCaptor =
      ArgumentCaptor.forClass(GenerateContentRequest.class);
  verify(mockUnaryCallable).call(requestCaptor.capture());
  GenerateContentRequest sentRequest = requestCaptor.getValue();
  assertThat(sentRequest.getSystemInstruction().getParts(0).getText())
      .isEqualTo(systemInstructionText);
}

@Test
public void testGenerateContentwithDefaultGenerationConfig() throws Exception {
model =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class ITGenerativeModelIntegrationTest {

// Tested content
private static final String TEXT = "What do you think about Google Pixel?";
private static final String PIRATE_INSTRUCTION = "Speak like a pirate when answering questions.";
private static final String IMAGE_INQUIRY = "Please describe this image: ";
private static final String IMAGE_URL = "https://picsum.photos/id/1/200/300";
private static final String VIDEO_INQUIRY = "Please summarize this video: ";
Expand Down Expand Up @@ -259,4 +260,15 @@ public void countTokens_withPlainText_returnsNonZeroTokens() throws IOException
logger.info(String.format("Print number of tokens:\n%s", tokens));
assertThat(tokens.getTotalTokens()).isGreaterThan(0);
}

@Test
public void generateContent_withSystemInstruction() throws Exception {
  GenerativeModel modelWithInstruction =
      textModel.withSystemInstruction(ContentMaker.fromString(PIRATE_INSTRUCTION));

  // Generated output varies between runs, so the response is always logged and the check is
  // limited to what assertNonEmptyAndLogResponse verifies rather than matching specific text.
  GenerateContentResponse response = modelWithInstruction.generateContent(TEXT);
  String testName = name.getMethodName();
  assertNonEmptyAndLogResponse(testName, TEXT, response);
}
}