.Net: Google connector API version selection (#5750)
### Motivation and Context


Closes #5659

### Description


The update introduces support for selecting the Google API version in the affected services and clients. Two new enums, `GoogleAIVersion` and `VertexAIVersion`, represent the available versions of the Google AI and Vertex AI APIs (stable and beta for Google AI, stable for Vertex AI). The affected classes now accept an API version parameter and use its value when constructing API endpoints.

GoogleAI endpoints currently support only the beta API version, so `GoogleAIVersion.V1_Beta` remains the default.
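
As an illustration (not part of this diff), a caller could select the API version through the new kernel-builder extension. This is a minimal sketch; the model id and API key are placeholders, and the connector may require opting into its experimental (SKEXP) attributes depending on the release:

```csharp
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.Google;

// Placeholder model id and API key; GoogleAIVersion.V1_Beta mirrors the default used in this PR.
var builder = Kernel.CreateBuilder();

builder.AddGoogleAIGeminiChatCompletion(
    modelId: "gemini-pro",
    apiKey: "YOUR_API_KEY",
    apiVersion: GoogleAIVersion.V1_Beta);

Kernel kernel = builder.Build();
```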

cc: @RogerBarreto 

### Contribution Checklist


- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Krzysztof318 and RogerBarreto committed Apr 23, 2024
1 parent 47c5d92 commit 875477e
Showing 24 changed files with 160 additions and 10 deletions.
@@ -392,6 +392,7 @@ private GeminiChatCompletionClient CreateChatCompletionClient(
return new GeminiChatCompletionClient(
httpClient: httpClient ?? this._httpClient,
modelId: modelId,
apiVersion: GoogleAIVersion.V1,
apiKey: "fake-key");
}

@@ -435,6 +435,7 @@ private GeminiChatCompletionClient CreateChatCompletionClient(
return new GeminiChatCompletionClient(
httpClient: httpClient ?? this._httpClient,
modelId: modelId,
apiVersion: VertexAIVersion.V1,
bearerTokenProvider: () => Task.FromResult(bearerKey),
location: "fake-location",
projectId: "fake-project-id");
@@ -443,6 +444,7 @@ private GeminiChatCompletionClient CreateChatCompletionClient(
return new GeminiChatCompletionClient(
httpClient: httpClient ?? this._httpClient,
modelId: modelId,
apiVersion: GoogleAIVersion.V1,
apiKey: "fake-key");
}

@@ -404,6 +404,7 @@ private GeminiChatCompletionClient CreateChatCompletionClient(
return new GeminiChatCompletionClient(
httpClient: httpClient ?? this._httpClient,
modelId: modelId,
apiVersion: GoogleAIVersion.V1,
apiKey: "fake-key");
}

@@ -369,13 +369,15 @@ private GeminiChatCompletionClient CreateChatCompletionClient(
httpClient: httpClient ?? this._httpClient,
modelId: modelId,
bearerTokenProvider: () => Task.FromResult(bearerKey),
apiVersion: VertexAIVersion.V1,
location: "fake-location",
projectId: "fake-project-id");
}

return new GeminiChatCompletionClient(
httpClient: httpClient ?? this._httpClient,
modelId: modelId,
apiVersion: GoogleAIVersion.V1,
apiKey: "fake-key");
}

@@ -5,6 +5,7 @@
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Connectors.Google;
using Microsoft.SemanticKernel.Connectors.Google.Core;
using Microsoft.SemanticKernel.Http;
using Xunit;
@@ -124,13 +125,15 @@ private GeminiTokenCounterClient CreateTokenCounterClient(
httpClient: this._httpClient,
modelId: modelId,
bearerTokenProvider: () => Task.FromResult(bearerKey),
apiVersion: VertexAIVersion.V1,
location: "fake-location",
projectId: "fake-project-id");
}

return new GeminiTokenCounterClient(
httpClient: this._httpClient,
modelId: modelId,
apiVersion: GoogleAIVersion.V1,
apiKey: "fake-key");
}

@@ -7,6 +7,7 @@
using System.Net.Http;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Connectors.Google;
using Microsoft.SemanticKernel.Connectors.Google.Core;
using Microsoft.SemanticKernel.Http;
using Xunit;
@@ -147,6 +148,7 @@ private GoogleAIEmbeddingClient CreateEmbeddingsClient(
var client = new GoogleAIEmbeddingClient(
httpClient: this._httpClient,
modelId: modelId,
apiVersion: GoogleAIVersion.V1,
apiKey: "fake-key");
return client;
}
@@ -7,6 +7,7 @@
using System.Net.Http;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Connectors.Google;
using Microsoft.SemanticKernel.Connectors.Google.Core;
using Microsoft.SemanticKernel.Http;
using Xunit;
@@ -143,6 +144,7 @@ private VertexAIEmbeddingClient CreateEmbeddingsClient(
httpClient: this._httpClient,
modelId: modelId,
bearerTokenProvider: () => Task.FromResult(bearerKey ?? "fake-key"),
apiVersion: VertexAIVersion.V1,
location: "us-central1",
projectId: "fake-project-id");
return client;
15 changes: 15 additions & 0 deletions dotnet/src/Connectors/Connectors.Google/Core/ClientBase.cs
@@ -109,4 +109,19 @@ protected void Log(LogLevel logLevel, string? message, params object[] args)
#pragma warning restore CA2254
}
}

protected static string GetApiVersionSubLink(GoogleAIVersion apiVersion)
=> apiVersion switch
{
GoogleAIVersion.V1 => "v1",
GoogleAIVersion.V1_Beta => "v1beta",
_ => throw new NotSupportedException($"Google API version {apiVersion} is not supported.")
};

protected static string GetApiVersionSubLink(VertexAIVersion apiVersion)
=> apiVersion switch
{
VertexAIVersion.V1 => "v1",
_ => throw new NotSupportedException($"Vertex API version {apiVersion} is not supported.")
};
}
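
The `GoogleAIVersion` and `VertexAIVersion` enum definitions are added elsewhere in this PR and are not part of the excerpt above. Based on the switch expressions in `ClientBase`, a plausible sketch (member names taken from usage; XML docs are assumptions) looks like this:

```csharp
namespace Microsoft.SemanticKernel.Connectors.Google;

/// <summary>Represents the version of the Google AI (Gemini) API.</summary>
public enum GoogleAIVersion
{
    /// <summary>Stable version, mapped to the "v1" endpoint segment.</summary>
    V1,

    /// <summary>Beta version, mapped to the "v1beta" endpoint segment.</summary>
    V1_Beta
}

/// <summary>Represents the version of the Vertex AI API.</summary>
public enum VertexAIVersion
{
    /// <summary>Stable version, mapped to the "v1" endpoint segment.</summary>
    V1
}
```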
@@ -87,11 +87,13 @@ internal sealed class GeminiChatCompletionClient : ClientBase
/// <param name="httpClient">HttpClient instance used to send HTTP requests</param>
/// <param name="modelId">Id of the model supporting chat completion</param>
/// <param name="apiKey">Api key for GoogleAI endpoint</param>
/// <param name="apiVersion">Version of the Google API</param>
/// <param name="logger">Logger instance used for logging (optional)</param>
public GeminiChatCompletionClient(
HttpClient httpClient,
string modelId,
string apiKey,
GoogleAIVersion apiVersion,
ILogger? logger = null)
: base(
httpClient: httpClient,
@@ -100,9 +102,11 @@ public GeminiChatCompletionClient(
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(apiKey);

string versionSubLink = GetApiVersionSubLink(apiVersion);

this._modelId = modelId;
this._chatGenerationEndpoint = new Uri($"https://generativelanguage.googleapis.com/v1beta/models/{this._modelId}:generateContent?key={apiKey}");
this._chatStreamingEndpoint = new Uri($"https://generativelanguage.googleapis.com/v1beta/models/{this._modelId}:streamGenerateContent?key={apiKey}&alt=sse");
this._chatGenerationEndpoint = new Uri($"https://generativelanguage.googleapis.com/{versionSubLink}/models/{this._modelId}:generateContent?key={apiKey}");
this._chatStreamingEndpoint = new Uri($"https://generativelanguage.googleapis.com/{versionSubLink}/models/{this._modelId}:streamGenerateContent?key={apiKey}&alt=sse");
}

/// <summary>
@@ -113,13 +117,15 @@ public GeminiChatCompletionClient(
/// <param name="bearerTokenProvider">Bearer key provider used for authentication</param>
/// <param name="location">The region to process the request</param>
/// <param name="projectId">Project ID from google cloud</param>
/// <param name="apiVersion">Version of the Vertex API</param>
/// <param name="logger">Logger instance used for logging (optional)</param>
public GeminiChatCompletionClient(
HttpClient httpClient,
string modelId,
Func<Task<string>> bearerTokenProvider,
string location,
string projectId,
VertexAIVersion apiVersion,
ILogger? logger = null)
: base(
httpClient: httpClient,
@@ -130,9 +136,11 @@ public GeminiChatCompletionClient(
Verify.NotNullOrWhiteSpace(location);
Verify.NotNullOrWhiteSpace(projectId);

string versionSubLink = GetApiVersionSubLink(apiVersion);

this._modelId = modelId;
this._chatGenerationEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/v1/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:generateContent");
this._chatStreamingEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/v1/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:streamGenerateContent?alt=sse");
this._chatGenerationEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/{versionSubLink}/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:generateContent");
this._chatStreamingEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/{versionSubLink}/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:streamGenerateContent?alt=sse");
}

/// <summary>
@@ -23,11 +23,13 @@ internal sealed class GeminiTokenCounterClient : ClientBase
/// <param name="httpClient">HttpClient instance used to send HTTP requests</param>
/// <param name="modelId">Id of the model to use to counting tokens</param>
/// <param name="apiKey">Api key for GoogleAI endpoint</param>
/// <param name="apiVersion">Version of the Google API</param>
/// <param name="logger">Logger instance used for logging (optional)</param>
public GeminiTokenCounterClient(
HttpClient httpClient,
string modelId,
string apiKey,
GoogleAIVersion apiVersion,
ILogger? logger = null)
: base(
httpClient: httpClient,
@@ -36,8 +38,10 @@ public GeminiTokenCounterClient(
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(apiKey);

string versionSubLink = GetApiVersionSubLink(apiVersion);

this._modelId = modelId;
this._tokenCountingEndpoint = new Uri($"https://generativelanguage.googleapis.com/v1beta/models/{this._modelId}:countTokens?key={apiKey}");
this._tokenCountingEndpoint = new Uri($"https://generativelanguage.googleapis.com/{versionSubLink}/models/{this._modelId}:countTokens?key={apiKey}");
}

/// <summary>
@@ -48,13 +52,15 @@ public GeminiTokenCounterClient(
/// <param name="bearerTokenProvider">Bearer key provider used for authentication</param>
/// <param name="location">The region to process the request</param>
/// <param name="projectId">Project ID from google cloud</param>
/// <param name="apiVersion">Version of the Vertex API</param>
/// <param name="logger">Logger instance used for logging (optional)</param>
public GeminiTokenCounterClient(
HttpClient httpClient,
string modelId,
Func<Task<string>> bearerTokenProvider,
string location,
string projectId,
VertexAIVersion apiVersion,
ILogger? logger = null)
: base(
httpClient: httpClient,
@@ -65,8 +71,10 @@ public GeminiTokenCounterClient(
Verify.NotNullOrWhiteSpace(location);
Verify.NotNullOrWhiteSpace(projectId);

string versionSubLink = GetApiVersionSubLink(apiVersion);

this._modelId = modelId;
this._tokenCountingEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/v1/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:countTokens");
this._tokenCountingEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/{versionSubLink}/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:countTokens");
}

/// <summary>
@@ -24,11 +24,13 @@ internal sealed class GoogleAIEmbeddingClient : ClientBase
/// <param name="httpClient">HttpClient instance used to send HTTP requests</param>
/// <param name="modelId">Embeddings generation model id</param>
/// <param name="apiKey">Api key for GoogleAI endpoint</param>
/// <param name="apiVersion">Version of the Google API</param>
/// <param name="logger">Logger instance used for logging (optional)</param>
public GoogleAIEmbeddingClient(
HttpClient httpClient,
string modelId,
string apiKey,
GoogleAIVersion apiVersion,
ILogger? logger = null)
: base(
httpClient: httpClient,
@@ -37,8 +39,10 @@ public GoogleAIEmbeddingClient(
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(apiKey);

string versionSubLink = GetApiVersionSubLink(apiVersion);

this._embeddingModelId = modelId;
this._embeddingEndpoint = new Uri($"https://generativelanguage.googleapis.com/v1beta/models/{this._embeddingModelId}:batchEmbedContents?key={apiKey}");
this._embeddingEndpoint = new Uri($"https://generativelanguage.googleapis.com/{versionSubLink}/models/{this._embeddingModelId}:batchEmbedContents?key={apiKey}");
}

/// <summary>
@@ -26,13 +26,15 @@ internal sealed class VertexAIEmbeddingClient : ClientBase
/// <param name="bearerTokenProvider">Bearer key provider used for authentication</param>
/// <param name="location">The region to process the request</param>
/// <param name="projectId">Project ID from google cloud</param>
/// <param name="apiVersion">Version of the Vertex API</param>
/// <param name="logger">Logger instance used for logging (optional)</param>
public VertexAIEmbeddingClient(
HttpClient httpClient,
string modelId,
Func<Task<string>> bearerTokenProvider,
string location,
string projectId,
VertexAIVersion apiVersion,
ILogger? logger = null)
: base(
httpClient: httpClient,
@@ -43,8 +45,10 @@ public VertexAIEmbeddingClient(
Verify.NotNullOrWhiteSpace(location);
Verify.NotNullOrWhiteSpace(projectId);

string versionSubLink = GetApiVersionSubLink(apiVersion);

this._embeddingModelId = modelId;
this._embeddingEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/v1/projects/{projectId}/locations/{location}/publishers/google/models/{this._embeddingModelId}:predict");
this._embeddingEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/{versionSubLink}/projects/{projectId}/locations/{location}/publishers/google/models/{this._embeddingModelId}:predict");
}

/// <summary>
@@ -21,13 +21,15 @@ public static class GoogleAIKernelBuilderExtensions
/// <param name="builder">The kernel builder.</param>
/// <param name="modelId">The model for text generation.</param>
/// <param name="apiKey">The API key for authentication Gemini API.</param>
/// <param name="apiVersion">The version of the Google API.</param>
/// <param name="serviceId">The optional service ID.</param>
/// <param name="httpClient">The optional custom HttpClient.</param>
/// <returns>The updated kernel builder.</returns>
public static IKernelBuilder AddGoogleAIGeminiChatCompletion(
this IKernelBuilder builder,
string modelId,
string apiKey,
GoogleAIVersion apiVersion = GoogleAIVersion.V1_Beta, // todo: change beta to stable when stable version will be available
string? serviceId = null,
HttpClient? httpClient = null)
{
@@ -39,6 +41,7 @@ public static IKernelBuilder AddGoogleAIGeminiChatCompletion(
new GoogleAIGeminiChatCompletionService(
modelId: modelId,
apiKey: apiKey,
apiVersion: apiVersion,
httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider),
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
return builder;
@@ -50,13 +53,15 @@ public static IKernelBuilder AddGoogleAIGeminiChatCompletion(
/// <param name="builder">The kernel builder.</param>
/// <param name="modelId">The model for text generation.</param>
/// <param name="apiKey">The API key for authentication Gemini API.</param>
/// <param name="apiVersion">The version of the Google API.</param>
/// <param name="serviceId">The optional service ID.</param>
/// <param name="httpClient">The optional custom HttpClient.</param>
/// <returns>The updated kernel builder.</returns>
public static IKernelBuilder AddGoogleAIEmbeddingGeneration(
this IKernelBuilder builder,
string modelId,
string apiKey,
GoogleAIVersion apiVersion = GoogleAIVersion.V1_Beta, // todo: change beta to stable when stable version will be available
string? serviceId = null,
HttpClient? httpClient = null)
{
@@ -68,6 +73,7 @@ public static IKernelBuilder AddGoogleAIEmbeddingGeneration(
new GoogleAITextEmbeddingGenerationService(
modelId: modelId,
apiKey: apiKey,
apiVersion: apiVersion,
httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider),
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
return builder;
@@ -18,12 +18,14 @@ public static class GoogleAIMemoryBuilderExtensions
/// <param name="builder">The <see cref="MemoryBuilder"/> instance</param>
/// <param name="modelId">The model for text generation.</param>
/// <param name="apiKey">The API key for authentication Gemini API.</param>
/// <param name="apiVersion">The version of the Google API.</param>
/// <param name="httpClient">The optional custom HttpClient.</param>
/// <returns>The updated memory builder.</returns>
public static MemoryBuilder WithGoogleAITextEmbeddingGeneration(
this MemoryBuilder builder,
string modelId,
string apiKey,
GoogleAIVersion apiVersion = GoogleAIVersion.V1_Beta,
HttpClient? httpClient = null)
{
Verify.NotNull(builder);
@@ -34,6 +36,7 @@ public static MemoryBuilder WithGoogleAITextEmbeddingGeneration(
new GoogleAITextEmbeddingGenerationService(
modelId: modelId,
apiKey: apiKey,
apiVersion: apiVersion,
httpClient: HttpClientProvider.GetHttpClient(httpClient ?? builderHttpClient),
loggerFactory: loggerFactory));
}
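
For completeness, a hedged sketch of how the memory-builder extension above might be called; the model id and API key are placeholders, and a memory store would still need to be attached before `Build()` is invoked:

```csharp
using Microsoft.SemanticKernel.Connectors.Google;
using Microsoft.SemanticKernel.Memory;

// Placeholder values; apiVersion defaults to GoogleAIVersion.V1_Beta if omitted.
var memoryBuilder = new MemoryBuilder()
    .WithGoogleAITextEmbeddingGeneration(
        modelId: "embedding-001",
        apiKey: "YOUR_API_KEY",
        apiVersion: GoogleAIVersion.V1_Beta);
```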