diff --git a/.vscode/launch.json b/.vscode/launch.json index 059bd32..ca411a2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -78,6 +78,13 @@ "type": "dart", "program": "lib/logging/logging.dart", }, + { + "name": "function calls", + "cwd": "example", + "request": "launch", + "type": "dart", + "program": "lib/function_calls/function_calls.dart", + }, { "name": "recipes", "cwd": "example", @@ -86,4 +93,4 @@ "program": "lib/recipes/recipes.dart", }, ] -} +} \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index d4c75b6..67e78b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,8 @@ ## 0.8.1 +* added support for tool calls to the Gemini and Vertex providers. Check out the + new `function_calls` example to see it in action. Thanks to @toshiossada for + [the inspiration](https://github.com/flutter/ai/pull/99). Fixes + [#98](https://github.com/flutter/ai/issues/98): How Can I get functionCalls? * fixed [#95](https://github.com/flutter/ai/issues/95): Image Attachment Disappears After Audio Recording diff --git a/example/android/app/build.gradle.kts b/example/android/app/build.gradle.kts index ea4d931..f1076c1 100644 --- a/example/android/app/build.gradle.kts +++ b/example/android/app/build.gradle.kts @@ -1,5 +1,8 @@ plugins { id("com.android.application") + // START: FlutterFire Configuration + id("com.google.gms.google-services") + // END: FlutterFire Configuration id("kotlin-android") // The Flutter Gradle Plugin must be applied after the Android and Kotlin Gradle plugins. id("dev.flutter.flutter-gradle-plugin") diff --git a/example/android/settings.gradle.kts b/example/android/settings.gradle.kts index a439442..9e2d35c 100644 --- a/example/android/settings.gradle.kts +++ b/example/android/settings.gradle.kts @@ -19,6 +19,9 @@ pluginManagement { plugins { id("dev.flutter.flutter-plugin-loader") version "1.0.0" id("com.android.application") version "8.7.0" apply false + // START: FlutterFire Configuration + id("com.google.gms.google-services") version("4.3.15") apply false + // END: FlutterFire Configuration id("org.jetbrains.kotlin.android") version "1.8.22" apply false } diff --git a/example/ios/Runner.xcodeproj/project.pbxproj b/example/ios/Runner.xcodeproj/project.pbxproj index 55f92ff..78b3fe8 100644 --- a/example/ios/Runner.xcodeproj/project.pbxproj +++ b/example/ios/Runner.xcodeproj/project.pbxproj @@ -12,6 +12,7 @@ 331C808B294A63AB00263BE5 /* RunnerTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 331C807B294A618700263BE5 /* RunnerTests.swift */; }; 3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */ = {isa = PBXBuildFile; fileRef = 3B3967151E833CAA004F5970 /* AppFrameworkInfo.plist */; }; 74858FAF1ED2DC5600515810 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 74858FAE1ED2DC5600515810 /* AppDelegate.swift */; }; + 96B60D02F354EA628523F83C /* GoogleService-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 041275B288B850F18666D4D3 /* GoogleService-Info.plist */; }; 97C146FC1CF9000F007C117D /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FA1CF9000F007C117D /* Main.storyboard */; }; 97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FD1CF9000F007C117D /* Assets.xcassets */; }; 97C147011CF9000F007C117D /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FF1CF9000F007C117D /* LaunchScreen.storyboard */; }; @@ -43,6 +44,7 @@ /* Begin PBXFileReference section */ 00907B8AC1679C03DF3ECF3B /* Pods_RunnerTests.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_RunnerTests.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + 041275B288B850F18666D4D3 /* GoogleService-Info.plist */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.plist.xml; name = "GoogleService-Info.plist"; path = "Runner/GoogleService-Info.plist"; sourceTree = ""; }; 1498D2321E8E86230040F4C2 /* GeneratedPluginRegistrant.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = GeneratedPluginRegistrant.h; sourceTree = ""; }; 1498D2331E8E89220040F4C2 /* GeneratedPluginRegistrant.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = GeneratedPluginRegistrant.m; sourceTree = ""; }; 1D6994B85815BA0993EDD5A4 /* Pods-RunnerTests.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-RunnerTests.debug.xcconfig"; path = "Target Support Files/Pods-RunnerTests/Pods-RunnerTests.debug.xcconfig"; sourceTree = ""; }; @@ -129,6 +131,7 @@ 331C8082294A63A400263BE5 /* RunnerTests */, 67F7FFA188394B87CE9E2A89 /* Pods */, F15FFF544ED27837A1711D1A /* Frameworks */, + 041275B288B850F18666D4D3 /* GoogleService-Info.plist */, ); sourceTree = ""; }; @@ -264,6 +267,7 @@ 3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */, 97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */, 97C146FC1CF9000F007C117D /* Main.storyboard in Resources */, + 96B60D02F354EA628523F83C /* GoogleService-Info.plist in Resources */, ); runOnlyForDeploymentPostprocessing = 0; }; diff --git a/example/lib/function_calls/function_calls.dart b/example/lib/function_calls/function_calls.dart new file mode 100644 index 0000000..7e8b4d3 --- /dev/null +++ b/example/lib/function_calls/function_calls.dart @@ -0,0 +1,66 @@ +// Copyright 2024 The Flutter Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import 'package:flutter/material.dart'; +import 'package:flutter_ai_toolkit/flutter_ai_toolkit.dart'; +import 'package:google_generative_ai/google_generative_ai.dart'; + +import '../gemini_api_key.dart'; + +void main() => runApp(const App()); + +class App extends StatelessWidget { + static const title = 'Example: Function Calls'; + + const App({super.key}); + + @override + Widget build(BuildContext context) => + const MaterialApp(title: title, home: ChatPage()); +} + +class ChatPage extends StatelessWidget { + const ChatPage({super.key}); + + @override + Widget build(BuildContext context) => Scaffold( + appBar: AppBar(title: const Text(App.title)), + body: LlmChatView( + provider: GeminiProvider( + model: GenerativeModel( + model: 'gemini-2.0-flash', + apiKey: geminiApiKey, + tools: [ + Tool( + functionDeclarations: [ + FunctionDeclaration( + 'get_temperature', + 'Get the current local temperature', + Schema.object(properties: {}), + ), + FunctionDeclaration( + 'get_time', + 'Get the current local time', + Schema.object(properties: {}), + ), + ], + ), + ], + ), + onFunctionCall: _onFunctionCall, + ), + ), + ); + + Future?> _onFunctionCall( + FunctionCall functionCall, + ) async { + // note: just as an example, we're not actually calling any external APIs + return switch (functionCall.name) { + 'get_temperature' => {'temperature': 60, 'unit': 'F'}, + 'get_time' => {'time': DateTime(1970, 1, 1).toIso8601String()}, + _ => throw Exception('Unknown function call: ${functionCall.name}'), + }; + } +} diff --git a/example/macos/Runner.xcodeproj/project.pbxproj b/example/macos/Runner.xcodeproj/project.pbxproj index 22d8658..3198455 100644 --- a/example/macos/Runner.xcodeproj/project.pbxproj +++ b/example/macos/Runner.xcodeproj/project.pbxproj @@ -28,6 +28,7 @@ 33CC10F32044A3C60003C045 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 33CC10F22044A3C60003C045 /* Assets.xcassets */; }; 33CC10F62044A3C60003C045 /* MainMenu.xib in Resources */ = {isa = PBXBuildFile; fileRef = 33CC10F42044A3C60003C045 /* MainMenu.xib */; }; 33CC11132044BFA00003C045 /* MainFlutterWindow.swift in Sources */ = {isa = PBXBuildFile; fileRef = 33CC11122044BFA00003C045 /* MainFlutterWindow.swift */; }; + 83C8768E218628E214E0C7F2 /* GoogleService-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 29DB9CE91787729FA5EF4A32 /* GoogleService-Info.plist */; }; D9168A3AC46A7BD217B5C7C1 /* Pods_RunnerTests.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = B40A05188DFAAEDDB9FB89BA /* Pods_RunnerTests.framework */; }; /* End PBXBuildFile section */ @@ -63,6 +64,7 @@ /* Begin PBXFileReference section */ 214AF9134787478DD9FCBE9B /* Pods-Runner.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Runner.release.xcconfig"; path = "Target Support Files/Pods-Runner/Pods-Runner.release.xcconfig"; sourceTree = ""; }; + 29DB9CE91787729FA5EF4A32 /* GoogleService-Info.plist */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.plist.xml; name = "GoogleService-Info.plist"; path = "Runner/GoogleService-Info.plist"; sourceTree = ""; }; 2D628886CA8B87FB4043E6B3 /* Pods-RunnerTests.profile.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-RunnerTests.profile.xcconfig"; path = "Target Support Files/Pods-RunnerTests/Pods-RunnerTests.profile.xcconfig"; sourceTree = ""; }; 331C80D5294CF71000263BE5 /* RunnerTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = RunnerTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; 331C80D7294CF71000263BE5 /* RunnerTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = RunnerTests.swift; sourceTree = ""; }; @@ -152,6 +154,7 @@ 33CC10EE2044A3C60003C045 /* Products */, D73912EC22F37F3D000D13A0 /* Frameworks */, 1D8694EA61F695DDF92A68F4 /* Pods */, + 29DB9CE91787729FA5EF4A32 /* GoogleService-Info.plist */, ); sourceTree = ""; }; @@ -317,6 +320,7 @@ files = ( 33CC10F32044A3C60003C045 /* Assets.xcassets in Resources */, 33CC10F62044A3C60003C045 /* MainMenu.xib in Resources */, + 83C8768E218628E214E0C7F2 /* GoogleService-Info.plist in Resources */, ); runOnlyForDeploymentPostprocessing = 0; }; diff --git a/lib/src/providers/implementations/gemini_provider.dart b/lib/src/providers/implementations/gemini_provider.dart index 12fdb3c..539ceab 100644 --- a/lib/src/providers/implementations/gemini_provider.dart +++ b/lib/src/providers/implementations/gemini_provider.dart @@ -27,24 +27,29 @@ class GeminiProvider extends LlmProvider with ChangeNotifier { /// /// [chatGenerationConfig] is an optional configuration for controlling the /// model's generation behavior. + /// + /// [onFunctionCall] is an optional function that will be called when the LLM + /// needs to call a function. GeminiProvider({ required GenerativeModel model, - this.onFunctionCalls, Iterable? history, List? chatSafetySettings, GenerationConfig? chatGenerationConfig, + Future?> Function(FunctionCall)? onFunctionCall, }) : _model = model, _history = history?.toList() ?? [], _chatSafetySettings = chatSafetySettings, - _chatGenerationConfig = chatGenerationConfig { + _chatGenerationConfig = chatGenerationConfig, + _onFunctionCall = onFunctionCall { _chat = _startChat(history); } - final void Function(Iterable)? onFunctionCalls; + final GenerativeModel _model; final List? _chatSafetySettings; final GenerationConfig? _chatGenerationConfig; final List _history; ChatSession? _chat; + final Future?> Function(FunctionCall)? _onFunctionCall; @override Stream generateStream( @@ -54,7 +59,6 @@ class GeminiProvider extends LlmProvider with ChangeNotifier { prompt: prompt, attachments: attachments, contentStreamGenerator: (c) => _model.generateContentStream([c]), - onFunctionCalls: onFunctionCalls, ); @override @@ -70,7 +74,6 @@ class GeminiProvider extends LlmProvider with ChangeNotifier { prompt: prompt, attachments: attachments, contentStreamGenerator: _chat!.sendMessageStream, - onFunctionCalls: onFunctionCalls, ); // don't write this code if you're targeting the web until this is fixed: @@ -93,30 +96,45 @@ class GeminiProvider extends LlmProvider with ChangeNotifier { required Iterable attachments, required Stream Function(Content) contentStreamGenerator, - required void Function(Iterable)? onFunctionCalls, }) async* { final content = Content('user', [ TextPart(prompt), ...attachments.map(_partFrom), ]); - final response = contentStreamGenerator(content); + final contentResponse = contentStreamGenerator(content); + // don't write this code if you're targeting the web until this is fixed: // https://github.com/dart-lang/sdk/issues/47764 // await for (final chunk in response) { // final text = chunk.text; // if (text != null) yield text; // } - yield* response - .map((chunk) { - if (chunk.candidates.any((e) => e.finishReason != null) && - chunk.functionCalls.isNotEmpty) { - onFunctionCalls?.call(chunk.functionCalls); - } - return chunk.text; - }) - .where((text) => text != null) - .cast(); + yield* contentResponse.asyncMap((chunk) async { + if (chunk.functionCalls.isEmpty) return chunk.text ?? ''; + + final functionResponses = []; + for (final functionCall in chunk.functionCalls) { + try { + functionResponses.add( + FunctionResponse( + functionCall.name, + await _onFunctionCall?.call(functionCall) ?? {}, + ), + ); + } catch (ex) { + functionResponses.add( + FunctionResponse(functionCall.name, {'error': ex.toString()}), + ); + } + } + + final functionContentResponse = await _chat!.sendMessage( + Content.functionResponses(functionResponses), + ); + + return '${chunk.text ?? ''}${functionContentResponse.text ?? ''}'; + }); } @override diff --git a/lib/src/providers/implementations/vertex_provider.dart b/lib/src/providers/implementations/vertex_provider.dart index 293260a..ce2b85a 100644 --- a/lib/src/providers/implementations/vertex_provider.dart +++ b/lib/src/providers/implementations/vertex_provider.dart @@ -33,10 +33,12 @@ class VertexProvider extends LlmProvider with ChangeNotifier { Iterable? history, List? chatSafetySettings, GenerationConfig? chatGenerationConfig, + Future?> Function(FunctionCall)? onFunctionCall, }) : _model = model, _history = history?.toList() ?? [], _chatSafetySettings = chatSafetySettings, - _chatGenerationConfig = chatGenerationConfig { + _chatGenerationConfig = chatGenerationConfig, + _onFunctionCall = onFunctionCall { _chat = _startChat(history); } final void Function(Iterable)? onFunctionCalls; @@ -44,6 +46,7 @@ class VertexProvider extends LlmProvider with ChangeNotifier { final List? _chatSafetySettings; final GenerationConfig? _chatGenerationConfig; final List _history; + final Future?> Function(FunctionCall)? _onFunctionCall; ChatSession? _chat; @override @@ -97,17 +100,39 @@ class VertexProvider extends LlmProvider with ChangeNotifier { ...attachments.map(_partFrom), ]); - final response = contentStreamGenerator(content); + final contentResponse = contentStreamGenerator(content); + // don't write this code if you're targeting the web until this is fixed: // https://github.com/dart-lang/sdk/issues/47764 // await for (final chunk in response) { // final text = chunk.text; // if (text != null) yield text; // } - yield* response - .map((chunk) => chunk.text) - .where((text) => text != null) - .cast(); + yield* contentResponse.asyncMap((chunk) async { + if (chunk.functionCalls.isEmpty) return chunk.text ?? ''; + + final functionResponses = []; + for (final functionCall in chunk.functionCalls) { + try { + functionResponses.add( + FunctionResponse( + functionCall.name, + await _onFunctionCall?.call(functionCall) ?? {}, + ), + ); + } catch (ex) { + functionResponses.add( + FunctionResponse(functionCall.name, {'error': ex.toString()}), + ); + } + } + + final functionContentResponse = await _chat!.sendMessage( + Content.functionResponses(functionResponses), + ); + + return '${chunk.text ?? ''}${functionContentResponse.text ?? ''}'; + }); } @override