.Net Enable Usage of Custom Compatible Chat Message API Endpoints with OpenAI Connector + Examples (#4753)

### Motivation and Context

- Allow usage of custom Message API endpoints (compliant with the OpenAI
ChatCompletion standard) with the OpenAI Connector.
- Refactor the structure of the OpenAI models and classes.
- Add examples that run these changes against the `LMStudio`,
`Ollama`, and `LocalAI` Message APIs (a minimal sketch follows).
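
As a quick illustration of the new surface, here is a minimal sketch of pointing the connector at a local OpenAI-compatible server (the URL, model id, and null key mirror the bundled examples; any compatible server works):

```csharp
using System;
using Microsoft.SemanticKernel;

var kernel = Kernel.CreateBuilder()
    .AddOpenAIChatCompletion(
        modelId: "llama2",                          // any model the local server hosts
        apiKey: null,                               // local servers typically need no key
        endpoint: new Uri("http://localhost:1234")) // any OpenAI-compatible endpoint
    .Build();
```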

---------

Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com>
RogerBarreto and markwallace-microsoft committed Apr 16, 2024
1 parent e416946 commit beef63c
Showing 10 changed files with 318 additions and 10 deletions.
Example88_CustomMessageAPIEndpoint.cs (new file)
@@ -0,0 +1,103 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Xunit;
using Xunit.Abstractions;

namespace Examples;

/// <summary>
/// This example shows a way of using the OpenAI connector with other APIs that support the same ChatCompletion Message API standard as OpenAI.
///
/// To run this example, follow one of the three setups below:
/// 1. Install the LM Studio platform in your environment
/// 2. Open LM Studio
/// 3. Search for and download both the Phi2 and Llama2 models (preferably the variants that use 8GB of RAM or more)
/// 4. Start the Message API Server on http://localhost:1234
/// 5. Run the LM Studio examples.
///
/// OR
///
/// 1. Start the Ollama Message API Server on http://localhost:11434 using docker
/// 2. docker run -d --gpus=all -v "d:\temp\ollama:/root/.ollama" -p 11434:11434 --name ollama ollama/ollama <see href="https://ollama.com/blog/ollama-is-now-available-as-an-official-docker-image" />
/// 3. Set Llama2 as the current ollama model: docker exec -it ollama ollama run llama2
/// 4. Run the Ollama examples.
///
/// OR
///
/// 1. Start the LocalAI Message API Server on http://localhost:8080
/// 2. docker run -ti -p 8080:8080 localai/localai:v2.12.3-ffmpeg-core phi-2 <see href="https://localai.io/docs/getting-started/run-other-models/" />
/// 3. Run the LocalAI examples.
/// </summary>
public class Example88_CustomMessageAPIEndpoint : BaseTest
{
[Theory(Skip = "Manual configuration needed")]
[InlineData("LMStudio", "http://localhost:1234", "llama2")] // Setup Llama2 as the model in LM Studio UI and start the Message API Server on http://localhost:1234
[InlineData("Ollama", "http://localhost:11434", "llama2")] // Start the Ollama Message API Server on http://localhost:11434 using docker
[InlineData("LocalAI", "http://localhost:8080", "phi-2")]
public async Task LocalModel_ExampleAsync(string messageAPIPlatform, string url, string modelId)
{
WriteLine($"Example using local {messageAPIPlatform}");
// Setup Llama2 as the model in LM Studio UI.

var kernel = Kernel.CreateBuilder()
.AddOpenAIChatCompletion(
modelId: modelId,
apiKey: null,
endpoint: new Uri(url))
.Build();

var prompt = @"Rewrite the text between triple backticks into a business mail. Use a professional tone, be clear and concise.
Sign the mail as AI Assistant.
Text: ```{{$input}}```";

var mailFunction = kernel.CreateFunctionFromPrompt(prompt, new OpenAIPromptExecutionSettings
{
TopP = 0.5,
MaxTokens = 1000,
});

var response = await kernel.InvokeAsync(mailFunction, new() { ["input"] = "Tell David that I'm going to finish the business plan by the end of the week." });
this.WriteLine(response);
}

[Theory(Skip = "Manual configuration needed")]
[InlineData("LMStudio", "http://localhost:1234", "llama2")] // Setup Llama2 as the model in LM Studio UI and start the Message API Server on http://localhost:1234
[InlineData("Ollama", "http://localhost:11434", "llama2")] // Start the Ollama Message API Server on http://localhost:11434 using docker
[InlineData("LocalAI", "http://localhost:8080", "phi-2")]
public async Task LocalModel_StreamingExampleAsync(string messageAPIPlatform, string url, string modelId)
{
WriteLine($"Example using local {messageAPIPlatform}");

var kernel = Kernel.CreateBuilder()
.AddOpenAIChatCompletion(
modelId: modelId,
apiKey: null,
endpoint: new Uri(url))
.Build();

var prompt = @"Rewrite the text between triple backticks into a business mail. Use a professional tone, be clear and concise.
Sign the mail as AI Assistant.
Text: ```{{$input}}```";

var mailFunction = kernel.CreateFunctionFromPrompt(prompt, new OpenAIPromptExecutionSettings
{
TopP = 0.5,
MaxTokens = 1000,
});

await foreach (var word in kernel.InvokeStreamingAsync(mailFunction, new() { ["input"] = "Tell David that I'm going to finish the business plan by the end of the week." }))
{
this.WriteLine(word);
}
}

public Example88_CustomMessageAPIEndpoint(ITestOutputHelper output) : base(output)
{
}
}
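
For readers who prefer the service-level API over prompt functions, here is a minimal sketch under the same local-server assumptions as above (the model id and URL are placeholders; `GetChatMessageContentAsync` is the standard Semantic Kernel chat entry point):

```csharp
using System;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

var kernel = Kernel.CreateBuilder()
    .AddOpenAIChatCompletion(modelId: "llama2", apiKey: null, endpoint: new Uri("http://localhost:1234"))
    .Build();

var chat = kernel.GetRequiredService<IChatCompletionService>();

var history = new ChatHistory();
history.AddUserMessage("Tell David that I'm going to finish the business plan by the end of the week.");

var reply = await chat.GetChatMessageContentAsync(history);
Console.WriteLine(reply.Content);
```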
(examples project .csproj)
@@ -9,7 +9,7 @@
<IsTestProject>true</IsTestProject>
<IsPackable>false</IsPackable>
<!-- Suppress: "Declare types in namespaces", "Require ConfigureAwait", "Experimental" -->
-<NoWarn>CS8618,IDE0009,CA1051,CA1050,CA1707,CA2007,VSTHRD111,CS1591,RCS1110,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101</NoWarn>
+<NoWarn>CS8618,IDE0009,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101</NoWarn>
<OutputType>Library</OutputType>
</PropertyGroup>
<ItemGroup>
OpenAIAudioToTextService.cs
@@ -39,7 +39,12 @@ public OpenAIAudioToTextService(
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
-this._core = new(modelId, apiKey, organization, httpClient, loggerFactory?.CreateLogger(typeof(OpenAIAudioToTextService)));
+this._core = new(
+    modelId: modelId,
+    apiKey: apiKey,
+    organization: organization,
+    httpClient: httpClient,
+    logger: loggerFactory?.CreateLogger(typeof(OpenAIAudioToTextService)));

this._core.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
this._core.AddAttribute(OpenAIClientCore.OrganizationKey, organization);
CustomHostPipelinePolicy.cs (new file)
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using Azure.Core;
using Azure.Core.Pipeline;

namespace Microsoft.SemanticKernel.Connectors.OpenAI.Core.AzureSdk;

/// <summary>
/// Pipeline policy that replaces the target URI of each outgoing request with the configured endpoint.
/// </summary>
internal class CustomHostPipelinePolicy : HttpPipelineSynchronousPolicy
{
private readonly Uri _endpoint;

internal CustomHostPipelinePolicy(Uri endpoint)
{
this._endpoint = endpoint;
}
public override void OnSendingRequest(HttpMessage message)
{
if (message?.Request == null)
{
return;
}

// Update current host to provided endpoint
message.Request.Uri.Reset(this._endpoint);
}
}
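
For context, `OpenAIClientCore` (next diff) attaches this policy at the per-retry position so every attempt is redirected. A standalone sketch of that wiring, usable inside the connector assembly since the policy is internal (the key and endpoint are placeholders):

```csharp
using System;
using Azure.AI.OpenAI;
using Azure.Core;
using Microsoft.SemanticKernel.Connectors.OpenAI.Core.AzureSdk;

var options = new OpenAIClientOptions();
options.AddPolicy(
    new CustomHostPipelinePolicy(new Uri("http://localhost:1234/v1/chat/completions")),
    HttpPipelinePosition.PerRetry); // re-applied on every attempt

// Local servers usually ignore the key, but the client requires a non-null value.
var client = new OpenAIClient("unused", options);
```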
OpenAIClientCore.cs
@@ -1,10 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Net.Http;
using System.Runtime.CompilerServices;
using Azure.AI.OpenAI;
using Azure.Core;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Connectors.OpenAI.Core.AzureSdk;
using Microsoft.SemanticKernel.Services;

namespace Microsoft.SemanticKernel.Connectors.OpenAI;
@@ -29,18 +31,19 @@ internal sealed class OpenAIClientCore : ClientCore
/// </summary>
/// <param name="modelId">Model name.</param>
/// <param name="apiKey">OpenAI API Key.</param>
/// <param name="endpoint">OpenAI compatible API endpoint.</param>
/// <param name="organization">OpenAI Organization Id (usually optional).</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="logger">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
internal OpenAIClientCore(
string modelId,
-    string apiKey,
+    string? apiKey = null,
+    Uri? endpoint = null,
string? organization = null,
HttpClient? httpClient = null,
ILogger? logger = null) : base(logger)
{
Verify.NotNullOrWhiteSpace(modelId);
-    Verify.NotNullOrWhiteSpace(apiKey);

this.DeploymentOrModelName = modelId;

@@ -51,7 +54,17 @@ internal OpenAIClientCore(
options.AddPolicy(new AddHeaderRequestPolicy("OpenAI-Organization", organization!), HttpPipelinePosition.PerCall);
}

-this.Client = new OpenAIClient(apiKey, options);
+// Use the explicit endpoint if provided, then fall back to the HttpClient.BaseAddress; when neither is set, use the default OpenAI endpoint.
+var providedEndpoint = endpoint ?? httpClient?.BaseAddress;
+if (providedEndpoint is null)
+{
+    Verify.NotNullOrWhiteSpace(apiKey); // For the public OpenAI endpoint a key must be provided.
+}
+else
+{
+    options.AddPolicy(new CustomHostPipelinePolicy(providedEndpoint), Azure.Core.HttpPipelinePosition.PerRetry);
+}
+this.Client = new OpenAIClient(apiKey ?? string.Empty, options);
}

/// <summary>
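
The net effect of the constructor change: an API key is only mandatory when no custom endpoint can be resolved. A compressed sketch of the resolution order (the helper name is ours, not connector API):

```csharp
using System;
using System.Net.Http;

// Explicit endpoint wins, then the HttpClient's BaseAddress; a null result
// means the public OpenAI endpoint, the only case that requires an API key.
static Uri? ResolveEndpoint(Uri? endpoint, HttpClient? httpClient)
    => endpoint ?? httpClient?.BaseAddress;

Console.WriteLine(ResolveEndpoint(null, new HttpClient { BaseAddress = new Uri("http://localhost:11434") }));
// -> http://localhost:11434/
```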
OpenAIChatCompletionService.cs
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
@@ -32,10 +34,61 @@ public OpenAIChatCompletionService(
string apiKey,
string? organization = null,
HttpClient? httpClient = null,
-    ILoggerFactory? loggerFactory = null)
+    ILoggerFactory? loggerFactory = null
+    )
{
+this._core = new(
+    modelId,
+    apiKey,
+    endpoint: null,
+    organization,
+    httpClient,
+    loggerFactory?.CreateLogger(typeof(OpenAIChatCompletionService)));

this._core.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
this._core.AddAttribute(OpenAIClientCore.OrganizationKey, organization);
}

/// <summary>
/// Create an instance of the Custom Message API OpenAI chat completion connector
/// </summary>
/// <param name="modelId">Model name</param>
/// <param name="endpoint">Custom Message API compatible endpoint</param>
/// <param name="apiKey">OpenAI API Key</param>
/// <param name="organization">OpenAI Organization Id (usually optional)</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
[Experimental("SKEXP0010")]
public OpenAIChatCompletionService(
string modelId,
Uri endpoint,
string? apiKey = null,
string? organization = null,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
-this._core = new(modelId, apiKey, organization, httpClient, loggerFactory?.CreateLogger(typeof(OpenAIChatCompletionService)));
+Uri? internalClientEndpoint = null;
+var providedEndpoint = endpoint ?? httpClient?.BaseAddress;
+if (providedEndpoint is not null)
+{
+    // If the provided endpoint has no path, default it to the Message API chat completions path (v1/chat/completions).
+    internalClientEndpoint = providedEndpoint.PathAndQuery == "/" ?
+        new Uri(providedEndpoint, "v1/chat/completions")
+        : providedEndpoint;
+}
+
+this._core = new(
+    modelId,
+    apiKey,
+    internalClientEndpoint,
+    organization,
+    httpClient,
+    loggerFactory?.CreateLogger(typeof(OpenAIChatCompletionService)));
+
+if (providedEndpoint is not null)
+{
+    this._core.AddAttribute(AIServiceExtensions.EndpointKey, providedEndpoint.ToString());
+}
this._core.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
this._core.AddAttribute(OpenAIClientCore.OrganizationKey, organization);
}
@@ -51,7 +104,10 @@ public OpenAIChatCompletionService(
OpenAIClient openAIClient,
ILoggerFactory? loggerFactory = null)
{
-this._core = new(modelId, openAIClient, loggerFactory?.CreateLogger(typeof(OpenAIChatCompletionService)));
+this._core = new(
+    modelId,
+    openAIClient,
+    loggerFactory?.CreateLogger(typeof(OpenAIChatCompletionService)));

this._core.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
}
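
The endpoint normalization above deserves a worked example. A bare host is expanded to the conventional `v1/chat/completions` path, while an endpoint that already carries a path is kept as-is (the `Normalize` helper is illustrative, not part of the connector):

```csharp
using System;

// Illustrative helper mirroring the constructor's ternary above.
Uri Normalize(Uri provided) =>
    provided.PathAndQuery == "/"
        ? new Uri(provided, "v1/chat/completions")
        : provided;

Console.WriteLine(Normalize(new Uri("http://localhost:1234")));
// -> http://localhost:1234/v1/chat/completions
Console.WriteLine(Normalize(new Uri("http://localhost:8080/custom/chat")));
// -> http://localhost:8080/custom/chat (left untouched)
```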
OpenAIServiceCollectionExtensions.cs
@@ -1044,6 +1044,80 @@ public static IServiceCollection AddOpenAIChatCompletion(this IServiceCollection
return services;
}

/// <summary>
/// Adds a custom endpoint OpenAI chat completion service to the list.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="modelId">OpenAI model name, see https://platform.openai.com/docs/models</param>
/// <param name="endpoint">A Custom Message API compatible endpoint.</param>
/// <param name="apiKey">OpenAI API key, see https://platform.openai.com/account/api-keys</param>
/// <param name="orgId">OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
[Experimental("SKEXP0010")]
public static IServiceCollection AddOpenAIChatCompletion(
this IServiceCollection services,
string modelId,
Uri endpoint,
string? apiKey = null,
string? orgId = null,
string? serviceId = null)
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(modelId);

Func<IServiceProvider, object?, OpenAIChatCompletionService> factory = (serviceProvider, _) =>
new(modelId,
endpoint,
apiKey,
orgId,
HttpClientProvider.GetHttpClient(serviceProvider),
serviceProvider.GetService<ILoggerFactory>());

services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);

return services;
}

/// <summary>
/// Adds the Custom Endpoint OpenAI chat completion service to the list.
/// </summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="modelId">OpenAI model name, see https://platform.openai.com/docs/models</param>
/// <param name="endpoint">Custom OpenAI Compatible Message API endpoint</param>
/// <param name="apiKey">OpenAI API key, see https://platform.openai.com/account/api-keys</param>
/// <param name="orgId">OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="httpClient">The HttpClient to use with this service.</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
[Experimental("SKEXP0010")]
public static IKernelBuilder AddOpenAIChatCompletion(
this IKernelBuilder builder,
string modelId,
Uri endpoint,
string? apiKey,
string? orgId = null,
string? serviceId = null,
HttpClient? httpClient = null)
{
Verify.NotNull(builder);
Verify.NotNullOrWhiteSpace(modelId);

Func<IServiceProvider, object?, OpenAIChatCompletionService> factory = (serviceProvider, _) =>
new(modelId: modelId,
apiKey: apiKey,
endpoint: endpoint,
organization: orgId,
httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider),
loggerFactory: serviceProvider.GetService<ILoggerFactory>());

builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
builder.Services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);

return builder;
}

#endregion

#region Images
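
Both new overloads register the service for dependency injection. A hedged usage sketch of the `IServiceCollection` overload against a local server (endpoint and model id are placeholders):

```csharp
using System;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

var services = new ServiceCollection();
services.AddOpenAIChatCompletion(
    modelId: "phi-2",                            // model served by LocalAI in the examples
    endpoint: new Uri("http://localhost:8080")); // apiKey defaults to null for this overload

// With a null serviceId the keyed registration resolves like an ordinary service.
var chat = services.BuildServiceProvider()
    .GetRequiredService<IChatCompletionService>();
```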
OpenAITextEmbeddingGenerationService.cs
@@ -36,7 +36,12 @@ public OpenAITextEmbeddingGenerationService(
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
-this._core = new(modelId, apiKey, organization, httpClient, loggerFactory?.CreateLogger(typeof(OpenAITextEmbeddingGenerationService)));
+this._core = new(
+    modelId: modelId,
+    apiKey: apiKey,
+    organization: organization,
+    httpClient: httpClient,
+    logger: loggerFactory?.CreateLogger(typeof(OpenAITextEmbeddingGenerationService)));

this._core.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
}