From a8c5cdcfb752bae490d0b1f0d0dc6df3b721fc94 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jochen=20Kirst=C3=A4tter?=
<7329802+jochenkirstaetter@users.noreply.github.com>
Date: Sat, 13 Apr 2024 15:42:44 +0400
Subject: [PATCH 1/2] improve response in SSE format
---
src/Mscc.GenerativeAI/CHANGELOG.md | 3 +
src/Mscc.GenerativeAI/GenerativeModel.cs | 133 ++++++++++--------
.../GoogleAi_GeminiPro_Should.cs | 9 +-
3 files changed, 80 insertions(+), 65 deletions(-)
diff --git a/src/Mscc.GenerativeAI/CHANGELOG.md b/src/Mscc.GenerativeAI/CHANGELOG.md
index 603dbd8..b0935c0 100644
--- a/src/Mscc.GenerativeAI/CHANGELOG.md
+++ b/src/Mscc.GenerativeAI/CHANGELOG.md
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- implement Server-Sent Events (SSE)
### Changed
+
+- improve response in SSE format
+
### Fixed
## 1.1.3
diff --git a/src/Mscc.GenerativeAI/GenerativeModel.cs b/src/Mscc.GenerativeAI/GenerativeModel.cs
index 72ea920..e7f93dc 100644
--- a/src/Mscc.GenerativeAI/GenerativeModel.cs
+++ b/src/Mscc.GenerativeAI/GenerativeModel.cs
@@ -192,7 +192,7 @@ private string Model
///
/// See Server-sent Events
///
- public bool UseServerSentEvents { get; set; } = false;
+ public bool UseServerSentEventsFormat { get; set; } = false;
///
/// Activate JSON Mode (default = no)
@@ -611,10 +611,6 @@ public async Task GenerateContent(GenerateContentReques
request.SystemInstruction ??= _systemInstruction;
var url = ParseUrl(Url, Method);
- if (UseServerSentEvents && _model == GenerativeAI.Model.Gemini10Pro.SanitizeModelName())
- {
- url = url.AddQueryString(new Dictionary<string, string>() { ["key"] = "sse" });
- }
if (UseJsonMode)
{
request.GenerationConfig ??= new GenerationConfig();
@@ -713,58 +709,6 @@ public async Task GenerateContent(GenerateContentReques
return await GenerateContent(request);
}
- ///
- /// Generates a response from the model given an input GenerateContentRequest.
- ///
- /// Required. The request to send to the API.
- ///
- /// Response from the model for generated content.
- /// Thrown when the is .
- /// Thrown when the functionality is not supported by the model.
- /// Thrown when the request fails to execute.
- internal async IAsyncEnumerable> GenerateContentSSE(GenerateContentRequest? request,
- [EnumeratorCancellation] CancellationToken cancellationToken = default)
- {
- if (request == null) throw new ArgumentNullException(nameof(request));
- if (_model != GenerativeAI.Model.Gemini10Pro.SanitizeModelName()) throw new NotSupportedException();
-
- request.GenerationConfig ??= _generationConfig;
- request.SafetySettings ??= _safetySettings;
- request.Tools ??= _tools;
-
- var url = ParseUrl(Url, Method).AddQueryString(new Dictionary<string, string>() { ["key"] = "sse" });
- string json = Serialize(request);
- var payload = new StringContent(json, Encoding.UTF8, MediaType);
- // Todo: How to POST the request?
- var message = new HttpRequestMessage
- {
- Method = HttpMethod.Post,
- Content = payload,
- RequestUri = new Uri(url),
- Version = _httpVersion
- };
- // message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
- message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue(MediaType));
-
- using (var response = await Client.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken))
- {
- response.EnsureSuccessStatusCode();
- if (response.Content is not null)
- {
- using (var sr = new StreamReader(await response.Content.ReadAsStreamAsync()))
- {
- while (!sr.EndOfStream)
- {
- var item = sr.ReadLineAsync();
- if (cancellationToken.IsCancellationRequested)
- yield break;
- yield return item;
- }
- }
- }
- }
- }
-
///
/// Generates a streamed response from the model given an input GenerateContentRequest.
/// This method uses a MemoryStream and StreamContent to send a streaming request to the API.
@@ -774,9 +718,22 @@ public async Task GenerateContent(GenerateContentReques
///
/// Stream of GenerateContentResponse with chunks asynchronously.
/// Thrown when the is .
+ /// Thrown when the request fails to execute.
+ /// Thrown when the response ended prematurely.
public async IAsyncEnumerable<GenerateContentResponse> GenerateContentStream(GenerateContentRequest? request,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
+ if (UseServerSentEventsFormat)
+ {
+ await foreach (var item in GenerateContentStreamSSE(request, cancellationToken))
+ {
+ if (cancellationToken.IsCancellationRequested)
+ yield break;
+ yield return item;
+ }
+ yield break;
+ }
+
if (request == null) throw new ArgumentNullException(nameof(request));
request.GenerationConfig ??= _generationConfig;
@@ -786,10 +743,6 @@ public async Task GenerateContent(GenerateContentReques
var method = "streamGenerateContent";
var url = ParseUrl(Url, method);
- if (UseServerSentEvents && _model == GenerativeAI.Model.Gemini10Pro.SanitizeModelName())
- {
- url = url.AddQueryString(new Dictionary<string, string>() { ["key"] = "sse" });
- }
if (UseJsonMode)
{
request.GenerationConfig ??= new GenerationConfig();
@@ -864,6 +817,64 @@ public async Task GenerateContent(GenerateContentReques
return GenerateContentStream(request);
}
+ /// <summary>
+ /// Generates a response from the model given an input GenerateContentRequest.
+ /// </summary>
+ /// <param name="request">Required. The request to send to the API.</param>
+ /// <param name="cancellationToken"></param>
+ /// <returns>Response from the model for generated content.</returns>
+ /// <exception cref="ArgumentNullException">Thrown when the <paramref name="request"/> is <see langword="null"/>.</exception>
+ /// <exception cref="NotSupportedException">Thrown when the functionality is not supported by the model.</exception>
+ /// <exception cref="HttpRequestException">Thrown when the request fails to execute.</exception>
+ internal async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamSSE(GenerateContentRequest? request,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ if (request == null) throw new ArgumentNullException(nameof(request));
+ if (_model != GenerativeAI.Model.Gemini10Pro.SanitizeModelName()) throw new NotSupportedException();
+
+ request.GenerationConfig ??= _generationConfig;
+ request.SafetySettings ??= _safetySettings;
+ request.Tools ??= _tools;
+
+ var method = "streamGenerateContent";
+ var url = ParseUrl(Url, method).AddQueryString(new Dictionary<string, string>() { ["alt"] = "sse" });
+ string json = Serialize(request);
+ var payload = new StringContent(json, Encoding.UTF8, MediaType);
+ // Todo: How to POST the request?
+ var message = new HttpRequestMessage
+ {
+ Method = HttpMethod.Post,
+ Content = payload,
+ RequestUri = new Uri(url),
+ Version = _httpVersion
+ };
+ // message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
+ message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue(MediaType));
+
+ using (var response = await Client.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken))
+ {
+ response.EnsureSuccessStatusCode();
+ if (response.Content is not null)
+ {
+ using (var sr = new StreamReader(await response.Content.ReadAsStreamAsync()))
+ {
+ while (!sr.EndOfStream)
+ {
+ var data = await sr.ReadLineAsync();
+ if (string.IsNullOrWhiteSpace(data))
+ continue;
+
+ var item = JsonSerializer.Deserialize<GenerateContentResponse>(
+ data.Substring("data:".Length).Trim(), _options);
+ if (cancellationToken.IsCancellationRequested)
+ yield break;
+ yield return item;
+ }
+ }
+ }
+ }
+ }
+
///
/// Generates a grounded answer from the model given an input GenerateAnswerRequest.
///
diff --git a/tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs b/tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs
index 4c6a1b4..c3c7121 100644
--- a/tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs
+++ b/tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs
@@ -432,7 +432,7 @@ public async void GenerateContent_WithRequest_UseServerSentEvents()
var prompt = "Write a story about a magic backpack.";
var googleAi = new GoogleAI(apiKey: _fixture.ApiKey);
var model = googleAi.GenerativeModel(model: _model);
- model.UseServerSentEvents = true;
+ model.UseServerSentEventsFormat = true;
var request = new GenerateContentRequest(prompt);
// Act
@@ -446,22 +446,23 @@ public async void GenerateContent_WithRequest_UseServerSentEvents()
}
[Fact]
- public async void GenerateContent_WithRequest_ServerSentEvents()
+ public async void GenerateContent_Stream_WithRequest_ServerSentEvents()
{
// Arrange
var prompt = "Write a story about a magic backpack.";
var googleAi = new GoogleAI(apiKey: _fixture.ApiKey);
var model = googleAi.GenerativeModel(model: _model);
+ model.UseServerSentEventsFormat = true;
var request = new GenerateContentRequest(prompt);
// Act
- var responseEvents = model.GenerateContentSSE(request);
+ var responseEvents = model.GenerateContentStream(request);
// Assert
responseEvents.Should().NotBeNull();
await foreach (var response in responseEvents)
{
- _output.WriteLine($"{response}");
+ _output.WriteLine($"{response.Text}");
}
}
From 8717f6014aff612a646d4953aa66ec789cac9f93 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jochen=20Kirst=C3=A4tter?=
<7329802+jochenkirstaetter@users.noreply.github.com>
Date: Sat, 13 Apr 2024 16:26:37 +0400
Subject: [PATCH 2/2] remove obsolete ToDo
---
src/Mscc.GenerativeAI/GenerativeModel.cs | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/Mscc.GenerativeAI/GenerativeModel.cs b/src/Mscc.GenerativeAI/GenerativeModel.cs
index e7f93dc..7431c48 100644
--- a/src/Mscc.GenerativeAI/GenerativeModel.cs
+++ b/src/Mscc.GenerativeAI/GenerativeModel.cs
@@ -840,7 +840,6 @@ await foreach (var item in GenerateContentStreamSSE(request, cancellationToken))
var url = ParseUrl(Url, method).AddQueryString(new Dictionary<string, string>() { ["alt"] = "sse" });
string json = Serialize(request);
var payload = new StringContent(json, Encoding.UTF8, MediaType);
- // Todo: How to POST the request?
var message = new HttpRequestMessage
{
Method = HttpMethod.Post,