From a8c5cdcfb752bae490d0b1f0d0dc6df3b721fc94 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jochen=20Kirst=C3=A4tter?= <7329802+jochenkirstaetter@users.noreply.github.com>
Date: Sat, 13 Apr 2024 15:42:44 +0400
Subject: [PATCH 1/2] improve response in SSE format

---
 src/Mscc.GenerativeAI/CHANGELOG.md       |   3 +
 src/Mscc.GenerativeAI/GenerativeModel.cs | 133 ++++++++++--------
 .../GoogleAi_GeminiPro_Should.cs         |   9 +-
 3 files changed, 80 insertions(+), 65 deletions(-)

diff --git a/src/Mscc.GenerativeAI/CHANGELOG.md b/src/Mscc.GenerativeAI/CHANGELOG.md
index 603dbd8..b0935c0 100644
--- a/src/Mscc.GenerativeAI/CHANGELOG.md
+++ b/src/Mscc.GenerativeAI/CHANGELOG.md
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - implement Server-Sent Events (SSE)
 
 ### Changed
+
+- improve response in SSE format
+
 ### Fixed
 
 ## 1.1.3
diff --git a/src/Mscc.GenerativeAI/GenerativeModel.cs b/src/Mscc.GenerativeAI/GenerativeModel.cs
index 72ea920..e7f93dc 100644
--- a/src/Mscc.GenerativeAI/GenerativeModel.cs
+++ b/src/Mscc.GenerativeAI/GenerativeModel.cs
@@ -192,7 +192,7 @@ private string Model
         /// <summary>
         /// See Server-sent Events
         /// </summary>
-        public bool UseServerSentEvents { get; set; } = false;
+        public bool UseServerSentEventsFormat { get; set; } = false;
 
         /// <summary>
         /// Activate JSON Mode (default = no)
@@ -611,10 +611,6 @@ public async Task<GenerateContentResponse> GenerateContent(GenerateContentReques
             request.SystemInstruction ??= _systemInstruction;
 
             var url = ParseUrl(Url, Method);
-            if (UseServerSentEvents && _model == GenerativeAI.Model.Gemini10Pro.SanitizeModelName())
-            {
-                url = url.AddQueryString(new Dictionary<string, string>() { ["key"] = "sse" });
-            }
             if (UseJsonMode)
             {
                 request.GenerationConfig ??= new GenerationConfig();
@@ -713,58 +709,6 @@ public async Task<GenerateContentResponse> GenerateContent(GenerateContentReques
             return await GenerateContent(request);
         }
 
-        /// <summary>
-        /// Generates a response from the model given an input GenerateContentRequest.
-        /// </summary>
-        /// <param name="request">Required. The request to send to the API.</param>
-        /// <param name="cancellationToken"></param>
-        /// <returns>Response from the model for generated content.</returns>
-        /// <exception cref="ArgumentNullException">Thrown when the <paramref name="request"/> is <see langword="null"/>.</exception>
-        /// <exception cref="NotSupportedException">Thrown when the functionality is not supported by the model.</exception>
-        /// <exception cref="HttpRequestException">Thrown when the request fails to execute.</exception>
-        internal async IAsyncEnumerable<Task<string?>> GenerateContentSSE(GenerateContentRequest? request,
-            [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            if (request == null) throw new ArgumentNullException(nameof(request));
-            if (_model != GenerativeAI.Model.Gemini10Pro.SanitizeModelName()) throw new NotSupportedException();
-
-            request.GenerationConfig ??= _generationConfig;
-            request.SafetySettings ??= _safetySettings;
-            request.Tools ??= _tools;
-
-            var url = ParseUrl(Url, Method).AddQueryString(new Dictionary<string, string>() { ["key"] = "sse" });
-            string json = Serialize(request);
-            var payload = new StringContent(json, Encoding.UTF8, MediaType);
-            // Todo: How to POST the request?
-            var message = new HttpRequestMessage
-            {
-                Method = HttpMethod.Post,
-                Content = payload,
-                RequestUri = new Uri(url),
-                Version = _httpVersion
-            };
-            // message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
-            message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue(MediaType));
-
-            using (var response = await Client.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken))
-            {
-                response.EnsureSuccessStatusCode();
-                if (response.Content is not null)
-                {
-                    using (var sr = new StreamReader(await response.Content.ReadAsStreamAsync()))
-                    {
-                        while (!sr.EndOfStream)
-                        {
-                            var item = sr.ReadLineAsync();
-                            if (cancellationToken.IsCancellationRequested)
-                                yield break;
-                            yield return item;
-                        }
-                    }
-                }
-            }
-        }
-
         /// <summary>
         /// Generates a streamed response from the model given an input GenerateContentRequest.
         /// This method uses a MemoryStream and StreamContent to send a streaming request to the API.
@@ -774,9 +718,22 @@ public async Task<GenerateContentResponse> GenerateContent(GenerateContentReques
         /// <param name="cancellationToken"></param>
         /// <returns>Stream of GenerateContentResponse with chunks asynchronously.</returns>
         /// <exception cref="ArgumentNullException">Thrown when the <paramref name="request"/> is <see langword="null"/>.</exception>
+        /// <exception cref="HttpRequestException">Thrown when the request fails to execute.</exception>
+        /// <exception cref="IOException">Thrown when the response ended prematurely.</exception>
         public async IAsyncEnumerable<GenerateContentResponse> GenerateContentStream(GenerateContentRequest? request,
             [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
+            if (UseServerSentEventsFormat)
+            {
+                await foreach (var item in GenerateContentStreamSSE(request, cancellationToken))
+                {
+                    if (cancellationToken.IsCancellationRequested)
+                        yield break;
+                    yield return item;
+                }
+                yield break;
+            }
+
             if (request == null) throw new ArgumentNullException(nameof(request));
 
             request.GenerationConfig ??= _generationConfig;
@@ -786,10 +743,6 @@ public async Task<GenerateContentResponse> GenerateContent(GenerateContentReques
 
             var method = "streamGenerateContent";
             var url = ParseUrl(Url, method);
-            if (UseServerSentEvents && _model == GenerativeAI.Model.Gemini10Pro.SanitizeModelName())
-            {
-                url = url.AddQueryString(new Dictionary<string, string>() { ["key"] = "sse" });
-            }
             if (UseJsonMode)
             {
                 request.GenerationConfig ??= new GenerationConfig();
@@ -864,6 +817,64 @@ public async Task<GenerateContentResponse> GenerateContent(GenerateContentReques
             return GenerateContentStream(request);
         }
 
+        /// <summary>
+        /// Generates a response from the model given an input GenerateContentRequest.
+        /// </summary>
+        /// <param name="request">Required. The request to send to the API.</param>
+        /// <param name="cancellationToken"></param>
+        /// <returns>Response from the model for generated content.</returns>
+        /// <exception cref="ArgumentNullException">Thrown when the <paramref name="request"/> is <see langword="null"/>.</exception>
+        /// <exception cref="NotSupportedException">Thrown when the functionality is not supported by the model.</exception>
+        /// <exception cref="HttpRequestException">Thrown when the request fails to execute.</exception>
+        internal async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamSSE(GenerateContentRequest? request,
+            [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            if (request == null) throw new ArgumentNullException(nameof(request));
+            if (_model != GenerativeAI.Model.Gemini10Pro.SanitizeModelName()) throw new NotSupportedException();
+
+            request.GenerationConfig ??= _generationConfig;
+            request.SafetySettings ??= _safetySettings;
+            request.Tools ??= _tools;
+
+            var method = "streamGenerateContent";
+            var url = ParseUrl(Url, method).AddQueryString(new Dictionary<string, string>() { ["alt"] = "sse" });
+            string json = Serialize(request);
+            var payload = new StringContent(json, Encoding.UTF8, MediaType);
+            // Todo: How to POST the request?
+            var message = new HttpRequestMessage
+            {
+                Method = HttpMethod.Post,
+                Content = payload,
+                RequestUri = new Uri(url),
+                Version = _httpVersion
+            };
+            // message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
+            message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue(MediaType));
+
+            using (var response = await Client.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken))
+            {
+                response.EnsureSuccessStatusCode();
+                if (response.Content is not null)
+                {
+                    using (var sr = new StreamReader(await response.Content.ReadAsStreamAsync()))
+                    {
+                        while (!sr.EndOfStream)
+                        {
+                            var data = await sr.ReadLineAsync();
+                            if (string.IsNullOrWhiteSpace(data))
+                                continue;
+
+                            var item = JsonSerializer.Deserialize<GenerateContentResponse>(
+                                data.Substring("data:".Length).Trim(), _options);
+                            if (cancellationToken.IsCancellationRequested)
+                                yield break;
+                            yield return item;
+                        }
+                    }
+                }
+            }
+        }
+
         /// <summary>
         /// Generates a grounded answer from the model given an input GenerateAnswerRequest.
         /// </summary>
diff --git a/tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs b/tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs
index 4c6a1b4..c3c7121 100644
--- a/tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs
+++ b/tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs
@@ -432,7 +432,7 @@ public async void GenerateContent_WithRequest_UseServerSentEvents()
             var prompt = "Write a story about a magic backpack.";
             var googleAi = new GoogleAI(apiKey: _fixture.ApiKey);
             var model = googleAi.GenerativeModel(model: _model);
-            model.UseServerSentEvents = true;
+            model.UseServerSentEventsFormat = true;
             var request = new GenerateContentRequest(prompt);
 
             // Act
@@ -446,22 +446,23 @@ public async void GenerateContent_WithRequest_UseServerSentEvents()
         }
 
         [Fact]
-        public async void GenerateContent_WithRequest_ServerSentEvents()
+        public async void GenerateContent_Stream_WithRequest_ServerSentEvents()
         {
             // Arrange
             var prompt = "Write a story about a magic backpack.";
             var googleAi = new GoogleAI(apiKey: _fixture.ApiKey);
             var model = googleAi.GenerativeModel(model: _model);
+            model.UseServerSentEventsFormat = true;
             var request = new GenerateContentRequest(prompt);
 
             // Act
-            var responseEvents = model.GenerateContentSSE(request);
+            var responseEvents = model.GenerateContentStream(request);
 
             // Assert
             responseEvents.Should().NotBeNull();
             await foreach (var response in responseEvents)
             {
-                _output.WriteLine($"{response}");
+                _output.WriteLine($"{response.Text}");
             }
         }
 

From 8717f6014aff612a646d4953aa66ec789cac9f93 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jochen=20Kirst=C3=A4tter?= <7329802+jochenkirstaetter@users.noreply.github.com>
Date: Sat, 13 Apr 2024 16:26:37 +0400
Subject: [PATCH 2/2] remove obsolete ToDo

---
 src/Mscc.GenerativeAI/GenerativeModel.cs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/Mscc.GenerativeAI/GenerativeModel.cs b/src/Mscc.GenerativeAI/GenerativeModel.cs
index e7f93dc..7431c48 100644
--- a/src/Mscc.GenerativeAI/GenerativeModel.cs
+++ b/src/Mscc.GenerativeAI/GenerativeModel.cs
@@ -840,7 +840,6 @@ await foreach (var item in GenerateContentStreamSSE(request, cancellationToken))
             var url = ParseUrl(Url, method).AddQueryString(new Dictionary<string, string>() { ["alt"] = "sse" });
             string json = Serialize(request);
             var payload = new StringContent(json, Encoding.UTF8, MediaType);
-            // Todo: How to POST the request?
             var message = new HttpRequestMessage
             {
                 Method = HttpMethod.Post,
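Note: a minimal consumer-side sketch of what the two patches above enable, mirroring the updated GenerateContent_Stream_WithRequest_ServerSentEvents test rather than documenting a settled public surface. The API key is a placeholder, and Model.Gemini10Pro stands in for the test fixture's model; GenerateContentStreamSSE throws NotSupportedException for any other model.

    // Opt in to the SSE wire format ("alt=sse") and consume the streamed chunks.
    var googleAi = new GoogleAI(apiKey: "your-api-key");             // placeholder API key
    var model = googleAi.GenerativeModel(model: Model.Gemini10Pro);  // SSE path is limited to Gemini 1.0 Pro
    model.UseServerSentEventsFormat = true;                          // property renamed in PATCH 1/2

    var request = new GenerateContentRequest("Write a story about a magic backpack.");
    await foreach (var response in model.GenerateContentStream(request))
    {
        // Each "data:" line of the event stream is deserialized into one GenerateContentResponse chunk.
        Console.WriteLine(response.Text);
    }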