Skip to content

Commit

Permalink
Retry functionality for OpenAPI skills (#779)
Browse files Browse the repository at this point in the history
### Motivation and Context

This PR adds two improvements:
- It extends all ImportOpenAPISkill* and ImportChatGptPluginSkill*
methods to accept HTTP retry configuration.
- It improves HttpClient usage in OpenAPI skills by reusing the same
instance for the same set of imported skills.

The first improvement will allow OpenAPI skill to be more resilient
against unreliable REST API it tries to access.

The second improvement will decrease number of HttpClient instances
OpenAPI skills create. Instead of creating one instance of HttpClient
per skill, one instance will be created per set of skills that are
imported together by the same Import* method. This is the first step to
decrease number of HttpClient created for OpenAPI skills and it's not
the final one. Later, when agreed on a proper solution, the HttpClient
could be externalized so that client code/hosting app could provide its
own instance, if required.

### Description
1. All the ImportOpenAPISkill* and ImportChatGptPluginSkill* methods are
extended to accept and respect HttpRetryConfig.
2. Functionality that creates HttpClient is moved one level up so that
it's called once per import rather than for each skill in the import.
  • Loading branch information
SergeyMenshykh committed May 3, 2023
1 parent 06c57b6 commit 3013b5d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Connectors.WebApi.Rest;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Reliability;
using Microsoft.SemanticKernel.SkillDefinition;
using Microsoft.SemanticKernel.Skills.OpenAPI.Skills;

Expand All @@ -32,6 +33,7 @@ public static class KernelChatGptPluginExtensions
/// <param name="url">Url to in which to retrieve the ChatGPT plugin.</param>
/// <param name="httpClient">Optional HttpClient to use for the request.</param>
/// <param name="authCallback">Optional callback for adding auth data to the API requests.</param>
/// <param name="retryConfiguration">Optional retry configuration.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A list of all the semantic functions representing the skill.</returns>
public static async Task<IDictionary<string, ISKFunction>> ImportChatGptPluginSkillFromUrlAsync(
Expand All @@ -40,6 +42,7 @@ public static class KernelChatGptPluginExtensions
Uri url,
HttpClient? httpClient = null,
AuthenticateRequestAsyncCallback? authCallback = null,
HttpRetryConfig? retryConfiguration = null,
CancellationToken cancellationToken = default)
{
Verify.ValidSkillName(skillName);
Expand Down Expand Up @@ -69,7 +72,7 @@ public static class KernelChatGptPluginExtensions
string gptPluginJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
string? openApiUrl = ParseOpenApiUrl(gptPluginJson);

return await kernel.ImportOpenApiSkillFromUrlAsync(skillName, new Uri(openApiUrl), httpClient, authCallback, cancellationToken).ConfigureAwait(false);
return await kernel.ImportOpenApiSkillFromUrlAsync(skillName, new Uri(openApiUrl), httpClient, authCallback, retryConfiguration, cancellationToken).ConfigureAwait(false);
}
finally
{
Expand All @@ -87,13 +90,15 @@ public static class KernelChatGptPluginExtensions
/// <param name="skillName">Skill name.</param>
/// <param name="httpClient">Optional HttpClient to use for the request.</param>
/// <param name="authCallback">Optional callback for adding auth data to the API requests.</param>
/// <param name="retryConfiguration">Optional retry configuration.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A list of all the semantic functions representing the skill.</returns>
public static async Task<IDictionary<string, ISKFunction>> ImportChatGptPluginSkillFromResourceAsync(
this IKernel kernel,
string skillName,
HttpClient? httpClient = null,
AuthenticateRequestAsyncCallback? authCallback = null,
HttpRetryConfig? retryConfiguration = null,
CancellationToken cancellationToken = default)
{
Verify.ValidSkillName(skillName);
Expand All @@ -113,7 +118,7 @@ public static class KernelChatGptPluginExtensions

string? openApiUrl = ParseOpenApiUrl(gptPluginJson);

return await kernel.ImportOpenApiSkillFromUrlAsync(skillName, new Uri(openApiUrl), httpClient, authCallback, cancellationToken).ConfigureAwait(false);
return await kernel.ImportOpenApiSkillFromUrlAsync(skillName, new Uri(openApiUrl), httpClient, authCallback, retryConfiguration, cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand All @@ -124,6 +129,7 @@ public static class KernelChatGptPluginExtensions
/// <param name="skillDirectoryName">Name of the directory containing the selected skill.</param>
/// <param name="httpClient">Optional HttpClient to use for the request.</param>
/// <param name="authCallback">Optional callback for adding auth data to the API requests.</param>
/// <param name="retryConfiguration">Optional retry configuration.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A list of all the semantic functions representing the skill.</returns>
public async static Task<IDictionary<string, ISKFunction>> ImportChatGptPluginSkillSkillFromDirectoryAsync(
Expand All @@ -132,6 +138,7 @@ public static class KernelChatGptPluginExtensions
string skillDirectoryName,
HttpClient? httpClient = null,
AuthenticateRequestAsyncCallback? authCallback = null,
HttpRetryConfig? retryConfiguration = null,
CancellationToken cancellationToken = default)
{
const string ChatGptPluginFile = "ai-plugin.json";
Expand All @@ -151,7 +158,7 @@ public static class KernelChatGptPluginExtensions

using var stream = File.OpenRead(chatGptPluginPath);

return await kernel.RegisterOpenApiSkillAsync(stream, skillDirectoryName, authCallback, cancellationToken).ConfigureAwait(false);
return await kernel.RegisterOpenApiSkillAsync(stream, skillDirectoryName, authCallback, retryConfiguration, cancellationToken: cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand All @@ -161,13 +168,15 @@ public static class KernelChatGptPluginExtensions
/// <param name="skillName">Name of the skill to register.</param>
/// <param name="filePath">File path to the ChatGPT plugin definition.</param>
/// <param name="authCallback">Optional callback for adding auth data to the API requests.</param>
/// <param name="retryConfiguration">Optional retry configuration.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A list of all the semantic functions representing the skill.</returns>
public async static Task<IDictionary<string, ISKFunction>> ImportChatGptPluginSkillSkillFromFileAsync(
this IKernel kernel,
string skillName,
string filePath,
AuthenticateRequestAsyncCallback? authCallback = null,
HttpRetryConfig? retryConfiguration = null,
CancellationToken cancellationToken = default)
{
if (!File.Exists(filePath))
Expand All @@ -179,7 +188,7 @@ public static class KernelChatGptPluginExtensions

using var stream = File.OpenRead(filePath);

return await kernel.RegisterOpenApiSkillAsync(stream, skillName, authCallback, cancellationToken).ConfigureAwait(false);
return await kernel.RegisterOpenApiSkillAsync(stream, skillName, authCallback, retryConfiguration, cancellationToken: cancellationToken).ConfigureAwait(false);
}

private static string ParseOpenApiUrl(string gptPluginJson)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.SemanticKernel.Connectors.WebApi.Rest.Model;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Reliability;
using Microsoft.SemanticKernel.SkillDefinition;
using Microsoft.SemanticKernel.Skills.OpenAPI.Model;
using Microsoft.SemanticKernel.Skills.OpenAPI.OpenApi;
Expand All @@ -36,6 +37,7 @@ public static class KernelOpenApiExtensions
/// <param name="url">Url to in which to retrieve the OpenAPI definition.</param>
/// <param name="httpClient">Optional HttpClient to use for the request.</param>
/// <param name="authCallback">Optional callback for adding auth data to the API requests.</param>
/// <param name="retryConfiguration">Optional retry configuration.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A list of all the semantic functions representing the skill.</returns>
public static async Task<IDictionary<string, ISKFunction>> ImportOpenApiSkillFromUrlAsync(
Expand All @@ -44,6 +46,7 @@ public static class KernelOpenApiExtensions
Uri url,
HttpClient? httpClient = null,
AuthenticateRequestAsyncCallback? authCallback = null,
HttpRetryConfig? retryConfiguration = null,
CancellationToken cancellationToken = default)
{
Verify.ValidSkillName(skillName);
Expand Down Expand Up @@ -76,7 +79,7 @@ public static class KernelOpenApiExtensions
throw new MissingManifestResourceException($"Unable to load OpenApi skill from url '{url}'.");
}

return await kernel.RegisterOpenApiSkillAsync(stream, skillName, authCallback, cancellationToken).ConfigureAwait(false);
return await kernel.RegisterOpenApiSkillAsync(stream, skillName, authCallback, retryConfiguration, cancellationToken: cancellationToken).ConfigureAwait(false);
}
finally
{
Expand All @@ -90,12 +93,14 @@ public static class KernelOpenApiExtensions
/// <param name="kernel">Semantic Kernel instance.</param>
/// <param name="skillName">Skill name.</param>
/// <param name="authCallback">Optional callback for adding auth data to the API requests.</param>
/// <param name="retryConfiguration">Optional retry configuration.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A list of all the semantic functions representing the skill.</returns>
public static Task<IDictionary<string, ISKFunction>> ImportOpenApiSkillFromResourceAsync(
this IKernel kernel,
string skillName,
AuthenticateRequestAsyncCallback? authCallback = null,
HttpRetryConfig? retryConfiguration = null,
CancellationToken cancellationToken = default)
{
Verify.ValidSkillName(skillName);
Expand All @@ -110,7 +115,7 @@ public static class KernelOpenApiExtensions
throw new MissingManifestResourceException($"Unable to load OpenApi skill from assembly resource '{resourceName}'.");
}

return kernel.RegisterOpenApiSkillAsync(stream, skillName, authCallback, cancellationToken);
return kernel.RegisterOpenApiSkillAsync(stream, skillName, authCallback, retryConfiguration, cancellationToken: cancellationToken);
}

/// <summary>
Expand All @@ -120,13 +125,15 @@ public static class KernelOpenApiExtensions
/// <param name="parentDirectory">Directory containing the skill directory.</param>
/// <param name="skillDirectoryName">Name of the directory containing the selected skill.</param>
/// <param name="authCallback">Optional callback for adding auth data to the API requests.</param>
/// <param name="retryConfiguration">Optional retry configuration.</param>
/// <param name="cancellationToken"></param>
/// <returns>A list of all the semantic functions representing the skill.</returns>
public static async Task<IDictionary<string, ISKFunction>> ImportOpenApiSkillFromDirectoryAsync(
this IKernel kernel,
string parentDirectory,
string skillDirectoryName,
AuthenticateRequestAsyncCallback? authCallback = null,
HttpRetryConfig? retryConfiguration = null,
CancellationToken cancellationToken = default)
{
const string OpenApiFile = "openapi.json";
Expand All @@ -144,12 +151,11 @@ public static class KernelOpenApiExtensions

kernel.Log.LogTrace("Registering Rest functions from {0} OpenApi document", openApiDocumentPath);

// TODO: never used, why?
var skill = new Dictionary<string, ISKFunction>();

using var stream = File.OpenRead(openApiDocumentPath);

return await kernel.RegisterOpenApiSkillAsync(stream, skillDirectoryName, authCallback, cancellationToken).ConfigureAwait(false);
return await kernel.RegisterOpenApiSkillAsync(stream, skillDirectoryName, authCallback, retryConfiguration, cancellationToken: cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand All @@ -159,13 +165,15 @@ public static class KernelOpenApiExtensions
/// <param name="skillName">Name of the skill to register.</param>
/// <param name="filePath">File path to the OpenAPI document.</param>
/// <param name="authCallback">Optional callback for adding auth data to the API requests.</param>
/// <param name="retryConfiguration">Optional retry configuration.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A list of all the semantic functions representing the skill.</returns>
public static async Task<IDictionary<string, ISKFunction>> ImportOpenApiSkillFromFileAsync(
this IKernel kernel,
string skillName,
string filePath,
AuthenticateRequestAsyncCallback? authCallback = null,
HttpRetryConfig? retryConfiguration = null,
CancellationToken cancellationToken = default)
{
if (!File.Exists(filePath))
Expand All @@ -175,12 +183,9 @@ public static class KernelOpenApiExtensions

kernel.Log.LogTrace("Registering Rest functions from {0} OpenApi document", filePath);

// TODO: never used, why?
var skill = new Dictionary<string, ISKFunction>();

using var stream = File.OpenRead(filePath);

return await kernel.RegisterOpenApiSkillAsync(stream, skillName, authCallback, cancellationToken).ConfigureAwait(false);
return await kernel.RegisterOpenApiSkillAsync(stream, skillName, authCallback, retryConfiguration, cancellationToken: cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand All @@ -190,13 +195,17 @@ public static class KernelOpenApiExtensions
/// <param name="documentStream">OpenApi document stream.</param>
/// <param name="skillName">Skill name.</param>
/// <param name="authCallback">Optional callback for adding auth data to the API requests.</param>
/// <param name="retryConfiguration">Optional retry configuration.</param>
/// <param name="userAgent">Optional override for request-header field containing information about the user agent originating the request</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A list of all the semantic functions representing the skill.</returns>
public static async Task<IDictionary<string, ISKFunction>> RegisterOpenApiSkillAsync(
this IKernel kernel,
Stream documentStream,
string skillName,
AuthenticateRequestAsyncCallback? authCallback = null,
HttpRetryConfig? retryConfiguration = null,
string? userAgent = "Microsoft-Semantic-Kernel",
CancellationToken cancellationToken = default)
{
Verify.NotNull(kernel);
Expand All @@ -207,14 +216,25 @@ public static class KernelOpenApiExtensions

var operations = await parser.ParseAsync(documentStream, cancellationToken).ConfigureAwait(false);

#pragma warning disable CA2000 // Dispose objects before losing scope
//Creating HttpClient here until a proper solution allowing client code to provide its own instance is put in place.
var retryHandler = new DefaultHttpRetryHandler(retryConfiguration ?? new HttpRetryConfig(), kernel.Log) { InnerHandler = new HttpClientHandler() { CheckCertificateRevocationList = true } };
var httpClient = new HttpClient(retryHandler, true);
#pragma warning restore CA2000 // Dispose objects before losing scope

// User Agent may be a required request header fields for some Rest APIs,
// but this detail isn't specified in OpenAPI specs, so defaulting for all Rest APIs imported.
// Other applications can override this value by passing it as a parameter on execution.
var runner = new RestApiOperationRunner(httpClient, authCallback, userAgent);

var skill = new Dictionary<string, ISKFunction>();

foreach (var operation in operations)
{
try
{
kernel.Log.LogTrace("Registering Rest function {0}.{1}", skillName, operation.Id);
var function = kernel.RegisterRestApiFunction(skillName, operation, authCallback, cancellationToken: cancellationToken);
var function = kernel.RegisterRestApiFunction(skillName, runner, operation, cancellationToken);
skill[function.Name] = function;
}
catch (Exception ex) when (!ex.IsCriticalException())
Expand All @@ -235,30 +255,23 @@ public static class KernelOpenApiExtensions
/// </summary>
/// <param name="kernel">Semantic Kernel instance.</param>
/// <param name="skillName">Skill name.</param>
/// <param name="runner">The REST API operation runner.</param>
/// <param name="operation">The REST API operation.</param>
/// <param name="authCallback">Optional callback for adding auth data to the API requests.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <param name="userAgent">Optional override for request-header field containing information about the user agent originating the request</param>
/// <returns>An instance of <see cref="SKFunction"/> class.</returns>
private static ISKFunction RegisterRestApiFunction(
this IKernel kernel,
string skillName,
IRestApiOperationRunner runner,
RestApiOperation operation,
AuthenticateRequestAsyncCallback? authCallback = null,
string? userAgent = "Microsoft-Semantic-Kernel",
CancellationToken cancellationToken = default)
{
var restOperationParameters = operation.GetParameters();

// User Agent may be a required request header fields for some Rest APIs,
// but this detail isn't specified in OpenAPI specs, so defaulting for all Rest APIs imported.
// Other applications can override this value by passing it as a parameter on execution.
async Task<SKContext> ExecuteAsync(SKContext context)
{
try
{
var runner = new RestApiOperationRunner(new HttpClient(), authCallback, userAgent);

// Extract function arguments from context
var arguments = new Dictionary<string, string>();
foreach (var parameter in restOperationParameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Reliability;
using Microsoft.SemanticKernel.Skills.OpenAPI.Authentication;
using Microsoft.SemanticKernel.Skills.OpenAPI.Skills;
using RepoUtils;
Expand All @@ -22,23 +23,26 @@ public static async Task RunAsync()
new[] { "https://vault.azure.net/.default" },
new Uri("http://localhost"));

await GetSecretFromAzureKeyVaultAsync(authenticationProvider);
await GetSecretFromAzureKeyVaultWithRetryAsync(authenticationProvider);

await AddSecretToAzureKeyVaultAsync(authenticationProvider);
}

public static async Task GetSecretFromAzureKeyVaultAsync(InteractiveMsalAuthenticationProvider authenticationProvider)
public static async Task GetSecretFromAzureKeyVaultWithRetryAsync(InteractiveMsalAuthenticationProvider authenticationProvider)
{
var kernel = new KernelBuilder().WithLogger(ConsoleLogger.Log).Build();

var retryConfig = new HttpRetryConfig() { MaxRetryCount = 3, UseExponentialBackoff = true };

// Import a OpenApi skill using one of the following Kernel extension methods
// kernel.ImportOpenApiSkillFromResource
// kernel.ImportOpenApiSkillFromDirectory
// kernel.ImportOpenApiSkillFromFile
// kernel.ImportOpenApiSkillFromUrlAsync
// kernel.RegisterOpenApiSkill
var skill = await kernel.ImportOpenApiSkillFromResourceAsync(SkillResourceNames.AzureKeyVault,
authenticationProvider.AuthenticateRequestAsync);
authenticationProvider.AuthenticateRequestAsync,
retryConfig);

// Add arguments for required parameters, arguments for optional ones can be skipped.
var contextVariables = new ContextVariables();
Expand Down

0 comments on commit 3013b5d

Please sign in to comment.