feat: Add streaming support to googleai_dart client (#299)
Co-authored-by: David Miguel <me@davidmiguel.com>
luisredondo and davidmigloz committed Feb 12, 2024
1 parent 5df0be8 commit 2cbd538
Showing 10 changed files with 336 additions and 62 deletions.
33 changes: 29 additions & 4 deletions packages/googleai_dart/README.md
@@ -15,12 +15,9 @@ Unofficial Dart client for [Google AI](https://ai.google.dev) for Developers (Ge
- Custom base URL, headers and query params support (e.g. HTTP proxies)
- Custom HTTP client support (e.g. SOCKS5 proxies or advanced use cases)

> [!NOTE]
> Streaming support coming soon.
**Supported endpoints:**

- Generate content
- Generate content (with streaming support)
- Count tokens
- Embed content
- Models info
@@ -35,6 +32,7 @@ Unofficial Dart client for [Google AI](https://ai.google.dev) for Developers (Ge
+ [Text-only input](#text-only-input)
+ [Text-and-image input](#text-and-image-input)
+ [Multi-turn conversations (chat)](#multi-turn-conversations-chat)
+ [Streaming generated content](#streaming-generated-content)
* [Count tokens](#count-tokens)
* [Embedding](#embedding)
* [Model info](#model-info)
@@ -167,6 +165,33 @@ print(res.candidates?.first.content?.parts?.first.text);
// In the heart of a tranquil village nestled amidst the rolling hills of 17th century France...
```

#### Streaming generated content

By default, `generateContent` returns a response only after the model has completed the entire generation process. For faster interactions, use `streamGenerateContent` instead to handle partial results as they become available.

```dart
final stream = client.streamGenerateContent(
modelId: 'gemini-pro',
request: const GenerateContentRequest(
contents: [
Content(
parts: [
Part(text: 'Write a story about a magic backpack.'),
],
),
],
generationConfig: GenerationConfig(
temperature: 0.8,
),
),
);
stream.listen((final res) {
  print(res.candidates?.first.content?.parts?.first.text);
  // In a quaint little town nestled amidst rolling hills, there lived a...
});
```
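Because `streamGenerateContent` returns a plain Dart `Stream`, the usual stream tooling applies. For instance, you can stop handling partial results early by cancelling the subscription (an illustrative sketch reusing the `stream` from the example above, not part of the package docs):

```dart
// Keep the subscription so generation output can be abandoned early.
final subscription = stream.listen((final res) {
  print(res.candidates?.first.content?.parts?.first.text);
});

// Later, e.g. when the user navigates away:
await subscription.cancel();
```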

### Count tokens

When using long prompts, it might be useful to count tokens before sending any content to the model.
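A minimal call might look like the following sketch (the `CountTokensRequest` shape and `totalTokens` field mirror the package's other examples and are assumptions here, not verbatim package docs):

```dart
// Sketch: count tokens for a prompt before generating content with it.
final res = await client.countTokens(
  modelId: 'gemini-pro',
  request: const CountTokensRequest(
    contents: [
      Content(
        parts: [Part(text: 'Write a story about a magic backpack.')],
      ),
    ],
  ),
);
print(res.totalTokens);
```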
29 changes: 29 additions & 0 deletions packages/googleai_dart/example/googleai_dart_example.dart
@@ -1,4 +1,5 @@
// ignore_for_file: avoid_print, avoid_redundant_argument_values
import 'dart:async';
import 'dart:convert';
import 'dart:io';

@@ -14,6 +15,9 @@ Future<void> main() async {
await _generateContentTextAndImageInput(client);
await _generateContentMultiTurnConversations(client);

// Stream generate content
await _streamGenerateContentTextInput(client);

// Count tokens
await _countTokens(client);

@@ -114,6 +118,31 @@ Future<void> _generateContentMultiTurnConversations(
// In the heart of a tranquil village nestled amidst the rolling hills of 17th century France...
}

Future<void> _streamGenerateContentTextInput(
final GoogleAIClient client,
) async {
final stream = client.streamGenerateContent(
modelId: 'gemini-pro',
request: const GenerateContentRequest(
contents: [
Content(
parts: [
Part(text: 'Write a story about a magic backpack.'),
],
),
],
generationConfig: GenerationConfig(
temperature: 0.8,
),
),
);

await for (final res in stream) {
print(res.candidates?.first.content?.parts?.first.text);
// In a quaint little town nestled amidst rolling hills, there lived a...
}
}

Future<void> _countTokens(final GoogleAIClient client) async {
final res = await client.countTokens(
modelId: 'gemini-pro',
50 changes: 50 additions & 0 deletions packages/googleai_dart/lib/src/client.dart
@@ -1,7 +1,11 @@
// ignore_for_file: use_super_parameters
import 'dart:async';
import 'dart:convert';

import 'package:http/http.dart' as http;

import 'generated/client.dart' as g;
import 'generated/schema/schema.dart';
import 'http_client/http_client.dart';

/// Client for Google AI API (Gemini API).
@@ -39,8 +43,54 @@ class GoogleAIClient extends g.GoogleAIClient {
client: client ?? createDefaultHttpClient(),
);

// ------------------------------------------
// METHOD: streamGenerateContent
// ------------------------------------------

/// Generates a streamed response from the model given an input `GenerateContentRequest`.
///
/// `modelId`: The id of the model to use.
///
/// `request`: Request to generate a completion from the model.
///
/// `POST` `https://generativelanguage.googleapis.com/v1/models/{modelId}:streamGenerateContent`
Stream<GenerateContentResponse> streamGenerateContent({
final String modelId = 'gemini-pro',
final GenerateContentRequest? request,
}) async* {
final streamedResponse = await makeRequestStream(
baseUrl: 'https://generativelanguage.googleapis.com/v1',
queryParams: {
'alt': 'sse',
},
path: '/models/$modelId:streamGenerateContent',
method: g.HttpMethod.post,
requestType: 'application/json',
responseType: 'application/json',
body: request,
);

yield* streamedResponse.stream
.transform(const _GoogleAIStreamTransformer())
.map((final d) => GenerateContentResponse.fromJson(json.decode(d)));
}

@override
Future<http.BaseRequest> onRequest(final http.BaseRequest request) {
return onRequestHandler(request);
}
}

class _GoogleAIStreamTransformer
extends StreamTransformerBase<List<int>, String> {
const _GoogleAIStreamTransformer();

@override
Stream<String> bind(final Stream<List<int>> stream) {
return stream
.transform(utf8.decoder)
.transform(const LineSplitter())
.where((final s) => s.isNotEmpty)
.map((final s) => s.substring(6));
}
}
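With `alt=sse` in the query params, the API responds with Server-Sent Events: each payload arrives on a line of the form `data: {...}`, and the transformer above splits lines, drops the blank event separators, and strips the 6-character `data: ` prefix. Since `_GoogleAIStreamTransformer` is private, here is a self-contained sketch of the same pipeline run on canned input (illustrative only):

```dart
import 'dart:convert';

Future<void> main() async {
  // Two SSE events as raw bytes, as they might arrive on the wire.
  final wire = Stream<List<int>>.fromIterable([
    utf8.encode('data: {"chunk":1}\n\n'),
    utf8.encode('data: {"chunk":2}\n\n'),
  ]);

  final payloads = await wire
      .transform(utf8.decoder)
      .transform(const LineSplitter())
      .where((final s) => s.isNotEmpty) // drop blank event separators
      .map((final s) => s.substring(6)) // strip the 'data: ' prefix
      .toList();

  print(payloads); // [{"chunk":1}, {"chunk":2}]
}
```

Each resulting string is a complete JSON document, which is why the client can feed it straight into `GenerateContentResponse.fromJson(json.decode(d))`.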
27 changes: 0 additions & 27 deletions packages/googleai_dart/lib/src/generated/client.dart
@@ -583,33 +583,6 @@ class GoogleAIClient {
return BatchEmbedContentsResponse.fromJson(_jsonDecode(r));
}

// ------------------------------------------
// METHOD: streamGenerateContent
// ------------------------------------------

/// Generates a streamed response from the model given an input `GenerateContentRequest`.
///
/// `modelId`: The id of the model to use.
///
/// `request`: Request to generate a completion from the model.
///
/// `POST` `https://generativelanguage.googleapis.com/v1/models/{modelId}:streamGenerateContent`
Future<GenerateContentResponse> streamGenerateContent({
String modelId = 'gemini-pro',
GenerateContentRequest? request,
}) async {
final r = await makeRequest(
baseUrl: 'https://generativelanguage.googleapis.com/v1',
path: '/models/$modelId:streamGenerateContent',
method: HttpMethod.post,
isMultipart: false,
requestType: 'application/json',
responseType: 'application/json',
body: request,
);
return GenerateContentResponse.fromJson(_jsonDecode(r));
}

// ------------------------------------------
// METHOD: embedContent
// ------------------------------------------
27 changes: 0 additions & 27 deletions packages/googleai_dart/oas/googleai_openapi_curated.yaml
@@ -1020,33 +1020,6 @@ paths:
schema:
type: string
default: embedding-001
/models/{modelId}:streamGenerateContent:
post:
description: >-
Generates a streamed response from the model given an input
`GenerateContentRequest`.
operationId: streamGenerateContent
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/GenerateContentRequest'
security: [ ]
responses:
'200':
description: Successful response
content:
application/json:
schema:
$ref: '#/components/schemas/GenerateContentResponse'
parameters:
- in: path
name: modelId
description: The id of the model to use.
required: false
schema:
type: string
default: gemini-pro
/models/{modelId}:embedContent:
post:
description: Generates an embedding from the model given an input `Content`.
1 change: 0 additions & 1 deletion packages/googleai_dart/test/count_tokens_test.dart
@@ -1,4 +1,3 @@
// ignore_for_file: avoid_redundant_argument_values
@TestOn('vm')
library; // Uses dart:io

1 change: 0 additions & 1 deletion packages/googleai_dart/test/embed_content_test.dart
@@ -1,4 +1,3 @@
// ignore_for_file: avoid_redundant_argument_values
@TestOn('vm')
library; // Uses dart:io

1 change: 0 additions & 1 deletion packages/googleai_dart/test/generate_content_test.dart
@@ -1,4 +1,3 @@
// ignore_for_file: avoid_redundant_argument_values
@TestOn('vm')
library; // Uses dart:io

1 change: 0 additions & 1 deletion packages/googleai_dart/test/model_info_test.dart
@@ -1,4 +1,3 @@
// ignore_for_file: avoid_redundant_argument_values
@TestOn('vm')
library; // Uses dart:io
