feat(vertexai): Add a few more parameters to the model request (#12824)
* pass a few more parameters through to the model request

* requires google_generative_ai 0.4.1
cynthiajoan authored May 24, 2024
1 parent 2e410a2 commit 35ad8d4
Showing 2 changed files with 38 additions and 22 deletions.
@@ -59,8 +59,7 @@ final class GenerativeModel {
List<Tool>? tools,
Content? systemInstruction,
ToolConfig? toolConfig,
- }) : _firebaseApp = app,
-      _googleAIModel = createModelWithBaseUri(
+ }) : _googleAIModel = createModelWithBaseUri(
model: _normalizeModelName(model),
apiKey: app.options.apiKey,
baseUri: _vertexUri(app, location),
@@ -75,7 +74,6 @@ final class GenerativeModel {
: [],
toolConfig: toolConfig?.toGoogleAI(),
);
- final FirebaseApp _firebaseApp;
final google_ai.GenerativeModel _googleAIModel;

static const _modelsPrefix = 'models/';
@@ -92,15 +90,6 @@ final class GenerativeModel {
);
}

- static google_ai.GenerationConfig _convertGenerationConfig(
-     GenerationConfig? config, FirebaseApp app) {
-   if (config == null) {
-     return google_ai.GenerationConfig();
-   } else {
-     return config.toGoogleAI();
-   }
- }
-
static FutureOr<Map<String, String>> Function() _firebaseTokens(
FirebaseAppCheck? appCheck, FirebaseAuth? auth) {
return () async {
@@ -135,16 +124,21 @@ final class GenerativeModel {
/// ```
Future<GenerateContentResponse> generateContent(Iterable<Content> prompt,
{List<SafetySetting>? safetySettings,
- GenerationConfig? generationConfig}) async {
+ GenerationConfig? generationConfig,
+ List<Tool>? tools,
+ ToolConfig? toolConfig}) async {
Iterable<google_ai.Content> googlePrompt =
prompt.map((content) => content.toGoogleAI());
List<google_ai.SafetySetting> googleSafetySettings = safetySettings != null
? safetySettings.map((setting) => setting.toGoogleAI()).toList()
: [];
final response = await _googleAIModel.generateContent(googlePrompt,
safetySettings: googleSafetySettings,
- generationConfig:
-     _convertGenerationConfig(generationConfig, _firebaseApp));
+ generationConfig: generationConfig?.toGoogleAI(),
+ tools: tools != null
+     ? tools.map((tool) => tool.toGoogleAI()).toList()
+     : [],
+ toolConfig: toolConfig?.toGoogleAI());
return response.toVertex();
}
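With this change, per-request function-calling tools can be passed straight through generateContent. A minimal usage sketch, assuming `model` is an already-configured `GenerativeModel` from this package and that `Tool`, `FunctionDeclaration`, `Schema`, `ToolConfig`, `FunctionCallingConfig`, and `FunctionCallingMode` mirror the google_generative_ai 0.4.x surface; the declaration below is an illustrative placeholder, not part of this diff:

// Hypothetical per-request tool wiring.
final weatherTool = Tool(functionDeclarations: [
  FunctionDeclaration(
    'getWeather',
    'Returns the current weather for a city.',
    Schema.object(properties: {'city': Schema.string()}),
  ),
]);

final response = await model.generateContent(
  [Content.text('What is the weather in Boston right now?')],
  tools: [weatherTool],
  toolConfig: ToolConfig(
    functionCallingConfig:
        FunctionCallingConfig(mode: FunctionCallingMode.auto),
  ),
);
print(response.text);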

@@ -163,13 +157,19 @@ final class GenerativeModel {
Stream<GenerateContentResponse> generateContentStream(
Iterable<Content> prompt,
{List<SafetySetting>? safetySettings,
- GenerationConfig? generationConfig}) {
+ GenerationConfig? generationConfig,
+ List<Tool>? tools,
+ ToolConfig? toolConfig}) {
return _googleAIModel
.generateContentStream(prompt.map((content) => content.toGoogleAI()),
safetySettings: safetySettings != null
? safetySettings.map((setting) => setting.toGoogleAI()).toList()
: [],
- generationConfig: generationConfig?.toGoogleAI())
+ generationConfig: generationConfig?.toGoogleAI(),
+ tools: tools != null
+     ? tools.map((tool) => tool.toGoogleAI()).toList()
+     : [],
+ toolConfig: toolConfig?.toGoogleAI())
.map((r) => r.toVertex());
}
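The streaming variant accepts the same request-level options, so generation settings and tool behaviour can be tuned per stream. A sketch under the same assumptions as above:

// Hypothetical streaming call with per-request generation settings.
final stream = model.generateContentStream(
  [Content.text('Draft a three-day itinerary for Kyoto.')],
  generationConfig: GenerationConfig(temperature: 0.4, maxOutputTokens: 1024),
  toolConfig: ToolConfig(
    functionCallingConfig:
        FunctionCallingConfig(mode: FunctionCallingMode.none),
  ),
);
await for (final chunk in stream) {
  print(chunk.text);
}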

@@ -190,9 +190,23 @@ final class GenerativeModel {
/// print(response.text);
/// }
/// ```
- Future<CountTokensResponse> countTokens(Iterable<Content> contents) async {
+ Future<CountTokensResponse> countTokens(
+   Iterable<Content> contents, {
+   List<SafetySetting>? safetySettings,
+   GenerationConfig? generationConfig,
+   List<Tool>? tools,
+   ToolConfig? toolConfig,
+ }) async {
return _googleAIModel
- .countTokens(contents.map((e) => e.toGoogleAI()))
+ .countTokens(contents.map((e) => e.toGoogleAI()),
+     safetySettings: safetySettings != null
+         ? safetySettings.map((setting) => setting.toGoogleAI()).toList()
+         : [],
+     generationConfig: generationConfig?.toGoogleAI(),
+     tools: tools != null
+         ? tools.map((tool) => tool.toGoogleAI()).toList()
+         : [],
+     toolConfig: toolConfig?.toGoogleAI())
.then((r) => r.toVertex());
}
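Since countTokens now forwards the same optional parameters, a token count can reflect the full request shape rather than the prompt alone. A sketch, again assuming a configured `model`:

// Hypothetical token count that includes request-level generation settings.
final count = await model.countTokens(
  [Content.text('Summarize the attached quarterly report.')],
  generationConfig: GenerationConfig(maxOutputTokens: 256),
);
print('Total tokens: ${count.totalTokens}');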

@@ -207,10 +221,12 @@ final class GenerativeModel {
/// (await model.embedContent([Content.text(prompt)])).embedding.values;
/// ```
Future<EmbedContentResponse> embedContent(Content content,
- {TaskType? taskType, String? title}) async {
+ {TaskType? taskType, String? title, int? outputDimensionality}) async {
return _googleAIModel
.embedContent(content.toGoogleAI(),
- taskType: taskType?.toGoogleAI(), title: title)
+ taskType: taskType?.toGoogleAI(),
+ title: title,
+ outputDimensionality: outputDimensionality)
.then((r) => r.toVertex());
}
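embedContent gains an outputDimensionality option for requesting truncated embeddings. A sketch, assuming `model` wraps an embedding-capable model and that `TaskType.retrievalDocument` is available:

// Hypothetical request for a 256-dimension embedding.
final result = await model.embedContent(
  Content.text('Vertex AI in Firebase'),
  taskType: TaskType.retrievalDocument,
  outputDimensionality: 256,
);
print(result.embedding.values.length); // expected: 256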

2 changes: 1 addition & 1 deletion packages/firebase_vertexai/firebase_vertexai/pubspec.yaml
@@ -14,7 +14,7 @@ dependencies:
firebase_core_platform_interface: ^5.0.0
flutter:
sdk: flutter
- google_generative_ai: ^0.4.0
+ google_generative_ai: ^0.4.1

dev_dependencies:
flutter_lints: ^3.0.0
