Skip to content

Commit

Permalink
Merge pull request #24 from mscraftsman/sse
Browse files Browse the repository at this point in the history
improve response in SSE format
  • Loading branch information
jochenkirstaetter committed Apr 15, 2024
2 parents 57828e7 + 9e22759 commit 3281a7e
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 65 deletions.
1 change: 1 addition & 0 deletions src/Mscc.GenerativeAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- improve response in SSE format
- update samples to latest NuGet package

### Fixed
Expand Down
132 changes: 71 additions & 61 deletions src/Mscc.GenerativeAI/GenerativeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ private string Model
/// <remarks>
/// See <a href="https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events">Server-sent Events</a>
/// </remarks>
public bool UseServerSentEvents { get; set; } = false;
public bool UseServerSentEventsFormat { get; set; } = false;

/// <summary>
/// Activate JSON Mode (default = no)
Expand Down Expand Up @@ -611,10 +611,6 @@ public async Task<GenerateContentResponse> GenerateContent(GenerateContentReques
request.SystemInstruction ??= _systemInstruction;

var url = ParseUrl(Url, Method);
if (UseServerSentEvents && _model == GenerativeAI.Model.Gemini10Pro.SanitizeModelName())
{
url = url.AddQueryString(new Dictionary<string, string?>() { ["key"] = "sse" });
}
if (UseJsonMode)
{
request.GenerationConfig ??= new GenerationConfig();
Expand Down Expand Up @@ -713,58 +709,6 @@ public async Task<GenerateContentResponse> GenerateContent(GenerateContentReques
return await GenerateContent(request);
}

/// <summary>
/// Generates a response from the model given an input GenerateContentRequest,
/// reading the HTTP response body line by line (Server-Sent Events style).
/// </summary>
/// <remarks>
/// NOTE(review): each yielded element is the un-awaited <see cref="Task{TResult}"/> returned by
/// <see cref="StreamReader.ReadLineAsync()"/> — callers must await every item themselves, and the
/// raw lines still carry the SSE "data:" prefix. Presumably this is why the method was superseded
/// by a variant that deserializes each event into a typed response.
/// </remarks>
/// <param name="request">Required. The request to send to the API.</param>
/// <param name="cancellationToken">Token used to stop the enumeration between lines.</param>
/// <returns>Response from the model for generated content.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="request"/> is <see langword="null"/>.</exception>
/// <exception cref="NotSupportedException">Thrown when the functionality is not supported by the model.</exception>
/// <exception cref="HttpRequestException">Thrown when the request fails to execute.</exception>
internal async IAsyncEnumerable<Task<string>> GenerateContentSSE(GenerateContentRequest? request,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    if (request == null) throw new ArgumentNullException(nameof(request));
    // Only the Gemini 1.0 Pro model is accepted here; every other model name throws.
    if (_model != GenerativeAI.Model.Gemini10Pro.SanitizeModelName()) throw new NotSupportedException();

    // Fall back to the model-level defaults when the request does not override them.
    request.GenerationConfig ??= _generationConfig;
    request.SafetySettings ??= _safetySettings;
    request.Tools ??= _tools;

    // NOTE(review): the query string is "key=sse"; the Gemini API selects the SSE wire
    // format via "alt=sse" — this looks like a typo to confirm against the API reference.
    var url = ParseUrl(Url, Method).AddQueryString(new Dictionary<string, string?>() { ["key"] = "sse" });
    string json = Serialize(request);
    var payload = new StringContent(json, Encoding.UTF8, MediaType);
    // Todo: How to POST the request?
    var message = new HttpRequestMessage
    {
        Method = HttpMethod.Post,
        Content = payload,
        RequestUri = new Uri(url),
        Version = _httpVersion
    };
    // message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
    message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue(MediaType));

    // ResponseHeadersRead lets enumeration begin before the full body has arrived.
    using (var response = await Client.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken))
    {
        response.EnsureSuccessStatusCode();
        if (response.Content is not null)
        {
            using (var sr = new StreamReader(await response.Content.ReadAsStreamAsync()))
            {
                while (!sr.EndOfStream)
                {
                    // The read is started before cancellation is checked, so one extra
                    // line may be consumed after the token is signalled.
                    var item = sr.ReadLineAsync();
                    if (cancellationToken.IsCancellationRequested)
                        yield break;
                    yield return item;
                }
            }
        }
    }
}

/// <summary>
/// Generates a streamed response from the model given an input GenerateContentRequest.
/// This method uses a MemoryStream and StreamContent to send a streaming request to the API.
Expand All @@ -774,9 +718,22 @@ public async Task<GenerateContentResponse> GenerateContent(GenerateContentReques
/// <param name="cancellationToken"></param>
/// <returns>Stream of GenerateContentResponse with chunks asynchronously.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="request"/> is <see langword="null"/>.</exception>
/// <exception cref="HttpRequestException">Thrown when the request fails to execute.</exception>
/// <exception cref="HttpIOException">Thrown when the response ended prematurely.</exception>
public async IAsyncEnumerable<GenerateContentResponse> GenerateContentStream(GenerateContentRequest? request,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (UseServerSentEventsFormat)
{
await foreach (var item in GenerateContentStreamSSE(request, cancellationToken))
{
if (cancellationToken.IsCancellationRequested)
yield break;
yield return item;
}
yield break;
}

if (request == null) throw new ArgumentNullException(nameof(request));

request.GenerationConfig ??= _generationConfig;
Expand All @@ -786,10 +743,6 @@ public async Task<GenerateContentResponse> GenerateContent(GenerateContentReques

var method = "streamGenerateContent";
var url = ParseUrl(Url, method);
if (UseServerSentEvents && _model == GenerativeAI.Model.Gemini10Pro.SanitizeModelName())
{
url = url.AddQueryString(new Dictionary<string, string?>() { ["key"] = "sse" });
}
if (UseJsonMode)
{
request.GenerationConfig ??= new GenerationConfig();
Expand Down Expand Up @@ -864,6 +817,63 @@ public async Task<GenerateContentResponse> GenerateContent(GenerateContentReques
return GenerateContentStream(request);
}

/// <summary>
/// Generates a streamed response from the model given an input GenerateContentRequest.
/// The API is asked to reply in Server-Sent Events (SSE) format and each "data:" event
/// is deserialized into a <see cref="GenerateContentResponse"/> chunk.
/// </summary>
/// <remarks>
/// See <a href="https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events">Server-sent Events</a>
/// </remarks>
/// <param name="request">Required. The request to send to the API.</param>
/// <param name="cancellationToken">Token used to stop the enumeration between events.</param>
/// <returns>Stream of GenerateContentResponse chunks asynchronously.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="request"/> is <see langword="null"/>.</exception>
/// <exception cref="NotSupportedException">Thrown when the functionality is not supported by the model.</exception>
/// <exception cref="HttpRequestException">Thrown when the request fails to execute.</exception>
internal async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamSSE(GenerateContentRequest? request,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    if (request == null) throw new ArgumentNullException(nameof(request));
    // SSE streaming is only wired up for the Gemini 1.0 Pro model here.
    if (_model != GenerativeAI.Model.Gemini10Pro.SanitizeModelName()) throw new NotSupportedException();

    // Fall back to the model-level defaults when the request does not override them.
    request.GenerationConfig ??= _generationConfig;
    request.SafetySettings ??= _safetySettings;
    request.Tools ??= _tools;

    var method = "streamGenerateContent";
    // "alt=sse" instructs the API to emit Server-Sent Events instead of a JSON array.
    var url = ParseUrl(Url, method).AddQueryString(new Dictionary<string, string?>() { ["alt"] = "sse" });
    string json = Serialize(request);
    var payload = new StringContent(json, Encoding.UTF8, MediaType);
    var message = new HttpRequestMessage
    {
        Method = HttpMethod.Post,
        Content = payload,
        RequestUri = new Uri(url),
        Version = _httpVersion
    };
    // message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
    message.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue(MediaType));

    // ResponseHeadersRead lets enumeration begin before the full body has arrived.
    using (var response = await Client.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken))
    {
        response.EnsureSuccessStatusCode();
        if (response.Content is not null)
        {
            using (var sr = new StreamReader(await response.Content.ReadAsStreamAsync()))
            {
                while (!sr.EndOfStream)
                {
                    // Check cancellation before reading so we do not consume an extra line.
                    if (cancellationToken.IsCancellationRequested)
                        yield break;

                    var data = await sr.ReadLineAsync();
                    // Skip keep-alive blank lines and any non-"data:" SSE fields
                    // (e.g. "event:", "id:", ":" comments) instead of mangling them;
                    // the unguarded Substring would also throw on lines shorter than "data:".
                    if (string.IsNullOrWhiteSpace(data) ||
                        !data.StartsWith("data:", StringComparison.Ordinal))
                        continue;

                    var item = JsonSerializer.Deserialize<GenerateContentResponse>(
                        data.Substring("data:".Length).Trim(), _options);
                    // Deserialize can legally return null; never yield a null chunk.
                    if (item is null)
                        continue;
                    yield return item;
                }
            }
        }
    }
}

/// <summary>
/// Generates a grounded answer from the model given an input GenerateAnswerRequest.
/// </summary>
Expand Down
9 changes: 5 additions & 4 deletions tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ public async void GenerateContent_WithRequest_UseServerSentEvents()
var prompt = "Write a story about a magic backpack.";
var googleAi = new GoogleAI(apiKey: _fixture.ApiKey);
var model = googleAi.GenerativeModel(model: _model);
model.UseServerSentEvents = true;
model.UseServerSentEventsFormat = true;
var request = new GenerateContentRequest(prompt);

// Act
Expand All @@ -446,22 +446,23 @@ public async void GenerateContent_WithRequest_UseServerSentEvents()
}

[Fact]
// async Task (not async void): xUnit awaits the test and surfaces any exception;
// an async void test's failures are unobservable by the runner.
public async Task GenerateContent_Stream_WithRequest_ServerSentEvents()
{
    // Arrange
    var prompt = "Write a story about a magic backpack.";
    var googleAi = new GoogleAI(apiKey: _fixture.ApiKey);
    var model = googleAi.GenerativeModel(model: _model);
    // Opt in to the SSE wire format so GenerateContentStream uses the SSE code path.
    model.UseServerSentEventsFormat = true;
    var request = new GenerateContentRequest(prompt);

    // Act
    var responseEvents = model.GenerateContentStream(request);

    // Assert
    responseEvents.Should().NotBeNull();
    await foreach (var response in responseEvents)
    {
        _output.WriteLine($"{response.Text}");
    }
}

Expand Down

0 comments on commit 3281a7e

Please sign in to comment.