Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@
"type": "dart",
"program": "lib/logging/logging.dart",
},
{
"name": "function calls",
"cwd": "example",
"request": "launch",
"type": "dart",
"program": "lib/function_calls/function_calls.dart",
},
{
"name": "recipes",
"cwd": "example",
Expand All @@ -86,4 +93,4 @@
"program": "lib/recipes/recipes.dart",
},
]
}
}
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
## 0.8.1
* added support for tool calls to the Gemini and Vertex providers. Check out the
new `function_calls` example to see it in action. Thanks to @toshiossada for
[the inspiration](https://github.com/flutter/ai/pull/99). Fixes
[#98](https://github.com/flutter/ai/issues/98): How Can I get functionCalls?

* fixed [#95](https://github.com/flutter/ai/issues/95): Image Attachment
Disappears After Audio Recording
Expand Down
3 changes: 3 additions & 0 deletions example/android/app/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
plugins {
id("com.android.application")
// START: FlutterFire Configuration
id("com.google.gms.google-services")
// END: FlutterFire Configuration
id("kotlin-android")
// The Flutter Gradle Plugin must be applied after the Android and Kotlin Gradle plugins.
id("dev.flutter.flutter-gradle-plugin")
Expand Down
3 changes: 3 additions & 0 deletions example/android/settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ pluginManagement {
plugins {
id("dev.flutter.flutter-plugin-loader") version "1.0.0"
id("com.android.application") version "8.7.0" apply false
// START: FlutterFire Configuration
id("com.google.gms.google-services") version("4.3.15") apply false
// END: FlutterFire Configuration
id("org.jetbrains.kotlin.android") version "1.8.22" apply false
}

Expand Down
4 changes: 4 additions & 0 deletions example/ios/Runner.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
331C808B294A63AB00263BE5 /* RunnerTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 331C807B294A618700263BE5 /* RunnerTests.swift */; };
3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */ = {isa = PBXBuildFile; fileRef = 3B3967151E833CAA004F5970 /* AppFrameworkInfo.plist */; };
74858FAF1ED2DC5600515810 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 74858FAE1ED2DC5600515810 /* AppDelegate.swift */; };
96B60D02F354EA628523F83C /* GoogleService-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 041275B288B850F18666D4D3 /* GoogleService-Info.plist */; };
97C146FC1CF9000F007C117D /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FA1CF9000F007C117D /* Main.storyboard */; };
97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FD1CF9000F007C117D /* Assets.xcassets */; };
97C147011CF9000F007C117D /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FF1CF9000F007C117D /* LaunchScreen.storyboard */; };
Expand Down Expand Up @@ -43,6 +44,7 @@

/* Begin PBXFileReference section */
00907B8AC1679C03DF3ECF3B /* Pods_RunnerTests.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_RunnerTests.framework; sourceTree = BUILT_PRODUCTS_DIR; };
041275B288B850F18666D4D3 /* GoogleService-Info.plist */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.plist.xml; name = "GoogleService-Info.plist"; path = "Runner/GoogleService-Info.plist"; sourceTree = "<group>"; };
1498D2321E8E86230040F4C2 /* GeneratedPluginRegistrant.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = GeneratedPluginRegistrant.h; sourceTree = "<group>"; };
1498D2331E8E89220040F4C2 /* GeneratedPluginRegistrant.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = GeneratedPluginRegistrant.m; sourceTree = "<group>"; };
1D6994B85815BA0993EDD5A4 /* Pods-RunnerTests.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-RunnerTests.debug.xcconfig"; path = "Target Support Files/Pods-RunnerTests/Pods-RunnerTests.debug.xcconfig"; sourceTree = "<group>"; };
Expand Down Expand Up @@ -129,6 +131,7 @@
331C8082294A63A400263BE5 /* RunnerTests */,
67F7FFA188394B87CE9E2A89 /* Pods */,
F15FFF544ED27837A1711D1A /* Frameworks */,
041275B288B850F18666D4D3 /* GoogleService-Info.plist */,
);
sourceTree = "<group>";
};
Expand Down Expand Up @@ -264,6 +267,7 @@
3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */,
97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */,
97C146FC1CF9000F007C117D /* Main.storyboard in Resources */,
96B60D02F354EA628523F83C /* GoogleService-Info.plist in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down
66 changes: 66 additions & 0 deletions example/lib/function_calls/function_calls.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2024 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

import 'package:flutter/material.dart';
import 'package:flutter_ai_toolkit/flutter_ai_toolkit.dart';
import 'package:google_generative_ai/google_generative_ai.dart';

import '../gemini_api_key.dart';

void main() => runApp(const App());

class App extends StatelessWidget {
static const title = 'Example: Function Calls';

const App({super.key});

@override
Widget build(BuildContext context) =>
const MaterialApp(title: title, home: ChatPage());
}

class ChatPage extends StatelessWidget {
const ChatPage({super.key});

@override
Widget build(BuildContext context) => Scaffold(
appBar: AppBar(title: const Text(App.title)),
body: LlmChatView(
provider: GeminiProvider(
model: GenerativeModel(
model: 'gemini-2.0-flash',
apiKey: geminiApiKey,
tools: [
Tool(
functionDeclarations: [
FunctionDeclaration(
'get_temperature',
'Get the current local temperature',
Schema.object(properties: {}),
),
FunctionDeclaration(
'get_time',
'Get the current local time',
Schema.object(properties: {}),
),
],
),
],
),
onFunctionCall: _onFunctionCall,
),
),
);

Future<Map<String, Object?>?> _onFunctionCall(
FunctionCall functionCall,
) async {
// note: just as an example, we're not actually calling any external APIs
return switch (functionCall.name) {
'get_temperature' => {'temperature': 60, 'unit': 'F'},
'get_time' => {'time': DateTime(1970, 1, 1).toIso8601String()},
_ => throw Exception('Unknown function call: ${functionCall.name}'),
};
}
}
4 changes: 4 additions & 0 deletions example/macos/Runner.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
33CC10F32044A3C60003C045 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 33CC10F22044A3C60003C045 /* Assets.xcassets */; };
33CC10F62044A3C60003C045 /* MainMenu.xib in Resources */ = {isa = PBXBuildFile; fileRef = 33CC10F42044A3C60003C045 /* MainMenu.xib */; };
33CC11132044BFA00003C045 /* MainFlutterWindow.swift in Sources */ = {isa = PBXBuildFile; fileRef = 33CC11122044BFA00003C045 /* MainFlutterWindow.swift */; };
83C8768E218628E214E0C7F2 /* GoogleService-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 29DB9CE91787729FA5EF4A32 /* GoogleService-Info.plist */; };
D9168A3AC46A7BD217B5C7C1 /* Pods_RunnerTests.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = B40A05188DFAAEDDB9FB89BA /* Pods_RunnerTests.framework */; };
/* End PBXBuildFile section */

Expand Down Expand Up @@ -63,6 +64,7 @@

/* Begin PBXFileReference section */
214AF9134787478DD9FCBE9B /* Pods-Runner.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Runner.release.xcconfig"; path = "Target Support Files/Pods-Runner/Pods-Runner.release.xcconfig"; sourceTree = "<group>"; };
29DB9CE91787729FA5EF4A32 /* GoogleService-Info.plist */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.plist.xml; name = "GoogleService-Info.plist"; path = "Runner/GoogleService-Info.plist"; sourceTree = "<group>"; };
2D628886CA8B87FB4043E6B3 /* Pods-RunnerTests.profile.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-RunnerTests.profile.xcconfig"; path = "Target Support Files/Pods-RunnerTests/Pods-RunnerTests.profile.xcconfig"; sourceTree = "<group>"; };
331C80D5294CF71000263BE5 /* RunnerTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = RunnerTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
331C80D7294CF71000263BE5 /* RunnerTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = RunnerTests.swift; sourceTree = "<group>"; };
Expand Down Expand Up @@ -152,6 +154,7 @@
33CC10EE2044A3C60003C045 /* Products */,
D73912EC22F37F3D000D13A0 /* Frameworks */,
1D8694EA61F695DDF92A68F4 /* Pods */,
29DB9CE91787729FA5EF4A32 /* GoogleService-Info.plist */,
);
sourceTree = "<group>";
};
Expand Down Expand Up @@ -317,6 +320,7 @@
files = (
33CC10F32044A3C60003C045 /* Assets.xcassets in Resources */,
33CC10F62044A3C60003C045 /* MainMenu.xib in Resources */,
83C8768E218628E214E0C7F2 /* GoogleService-Info.plist in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down
52 changes: 35 additions & 17 deletions lib/src/providers/implementations/gemini_provider.dart
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,29 @@ class GeminiProvider extends LlmProvider with ChangeNotifier {
///
/// [chatGenerationConfig] is an optional configuration for controlling the
/// model's generation behavior.
///
/// [onFunctionCall] is an optional function that will be called when the LLM
/// needs to call a function.
GeminiProvider({
required GenerativeModel model,
this.onFunctionCalls,
Iterable<ChatMessage>? history,
List<SafetySetting>? chatSafetySettings,
GenerationConfig? chatGenerationConfig,
Future<Map<String, Object?>?> Function(FunctionCall)? onFunctionCall,
}) : _model = model,
_history = history?.toList() ?? [],
_chatSafetySettings = chatSafetySettings,
_chatGenerationConfig = chatGenerationConfig {
_chatGenerationConfig = chatGenerationConfig,
_onFunctionCall = onFunctionCall {
_chat = _startChat(history);
}
final void Function(Iterable<FunctionCall>)? onFunctionCalls;

final GenerativeModel _model;
final List<SafetySetting>? _chatSafetySettings;
final GenerationConfig? _chatGenerationConfig;
final List<ChatMessage> _history;
ChatSession? _chat;
final Future<Map<String, Object?>?> Function(FunctionCall)? _onFunctionCall;

@override
Stream<String> generateStream(
Expand All @@ -54,7 +59,6 @@ class GeminiProvider extends LlmProvider with ChangeNotifier {
prompt: prompt,
attachments: attachments,
contentStreamGenerator: (c) => _model.generateContentStream([c]),
onFunctionCalls: onFunctionCalls,
);

@override
Expand All @@ -70,7 +74,6 @@ class GeminiProvider extends LlmProvider with ChangeNotifier {
prompt: prompt,
attachments: attachments,
contentStreamGenerator: _chat!.sendMessageStream,
onFunctionCalls: onFunctionCalls,
);

// don't write this code if you're targeting the web until this is fixed:
Expand All @@ -93,30 +96,45 @@ class GeminiProvider extends LlmProvider with ChangeNotifier {
required Iterable<Attachment> attachments,
required Stream<GenerateContentResponse> Function(Content)
contentStreamGenerator,
required void Function(Iterable<FunctionCall>)? onFunctionCalls,
}) async* {
final content = Content('user', [
TextPart(prompt),
...attachments.map(_partFrom),
]);

final response = contentStreamGenerator(content);
final contentResponse = contentStreamGenerator(content);

// don't write this code if you're targeting the web until this is fixed:
// https://github.com/dart-lang/sdk/issues/47764
// await for (final chunk in response) {
// final text = chunk.text;
// if (text != null) yield text;
// }
yield* response
.map((chunk) {
if (chunk.candidates.any((e) => e.finishReason != null) &&
chunk.functionCalls.isNotEmpty) {
onFunctionCalls?.call(chunk.functionCalls);
}
return chunk.text;
})
.where((text) => text != null)
.cast<String>();
yield* contentResponse.asyncMap((chunk) async {
if (chunk.functionCalls.isEmpty) return chunk.text ?? '';

final functionResponses = <FunctionResponse>[];
for (final functionCall in chunk.functionCalls) {
try {
functionResponses.add(
FunctionResponse(
functionCall.name,
await _onFunctionCall?.call(functionCall) ?? {},
),
);
} catch (ex) {
functionResponses.add(
FunctionResponse(functionCall.name, {'error': ex.toString()}),
);
}
}

final functionContentResponse = await _chat!.sendMessage(
Content.functionResponses(functionResponses),
);

return '${chunk.text ?? ''}${functionContentResponse.text ?? ''}';
});
}

@override
Expand Down
37 changes: 31 additions & 6 deletions lib/src/providers/implementations/vertex_provider.dart
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,20 @@ class VertexProvider extends LlmProvider with ChangeNotifier {
Iterable<ChatMessage>? history,
List<SafetySetting>? chatSafetySettings,
GenerationConfig? chatGenerationConfig,
Future<Map<String, Object?>?> Function(FunctionCall)? onFunctionCall,
}) : _model = model,
_history = history?.toList() ?? [],
_chatSafetySettings = chatSafetySettings,
_chatGenerationConfig = chatGenerationConfig {
_chatGenerationConfig = chatGenerationConfig,
_onFunctionCall = onFunctionCall {
_chat = _startChat(history);
}
final void Function(Iterable<FunctionCall>)? onFunctionCalls;
final GenerativeModel _model;
final List<SafetySetting>? _chatSafetySettings;
final GenerationConfig? _chatGenerationConfig;
final List<ChatMessage> _history;
final Future<Map<String, Object?>?> Function(FunctionCall)? _onFunctionCall;
ChatSession? _chat;

@override
Expand Down Expand Up @@ -97,17 +100,39 @@ class VertexProvider extends LlmProvider with ChangeNotifier {
...attachments.map(_partFrom),
]);

final response = contentStreamGenerator(content);
final contentResponse = contentStreamGenerator(content);

// don't write this code if you're targeting the web until this is fixed:
// https://github.com/dart-lang/sdk/issues/47764
// await for (final chunk in response) {
// final text = chunk.text;
// if (text != null) yield text;
// }
yield* response
.map((chunk) => chunk.text)
.where((text) => text != null)
.cast<String>();
yield* contentResponse.asyncMap((chunk) async {
if (chunk.functionCalls.isEmpty) return chunk.text ?? '';

final functionResponses = <FunctionResponse>[];
for (final functionCall in chunk.functionCalls) {
try {
functionResponses.add(
FunctionResponse(
functionCall.name,
await _onFunctionCall?.call(functionCall) ?? {},
),
);
} catch (ex) {
functionResponses.add(
FunctionResponse(functionCall.name, {'error': ex.toString()}),
);
}
}

final functionContentResponse = await _chat!.sendMessage(
Content.functionResponses(functionResponses),
);

return '${chunk.text ?? ''}${functionContentResponse.text ?? ''}';
});
}

@override
Expand Down