/
OpenAITextEmbeddingGenerationService.cs
71 lines (63 loc) · 2.74 KB
/
OpenAITextEmbeddingGenerationService.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Services;
namespace Microsoft.SemanticKernel.Connectors.OpenAI;
/// <summary>
/// Text embedding generation service that talks to OpenAI.
/// </summary>
/// <remarks>
/// Every operation is forwarded to an internal <c>OpenAIClientCore</c>; this type only handles
/// construction, exposes the service attributes, and relays embedding requests.
/// </remarks>
[Experimental("SKEXP0011")]
public sealed class OpenAITextEmbeddingGenerationService : ITextEmbeddingGenerationService
{
    /// <summary>Shared OpenAI client logic that performs the embedding calls and holds the attributes.</summary>
    private readonly OpenAIClientCore _client;

    /// <summary>
    /// Initializes the OpenAI text embedding connector using an API key.
    /// </summary>
    /// <param name="modelId">Name of the embedding model; also recorded under the model-id attribute.</param>
    /// <param name="apiKey">OpenAI API key.</param>
    /// <param name="organization">OpenAI organization id (usually optional).</param>
    /// <param name="httpClient">Custom <see cref="HttpClient"/> used for HTTP requests.</param>
    /// <param name="loggerFactory">The <see cref="ILoggerFactory"/> used for logging. When null, no logger is passed to the core.</param>
    public OpenAITextEmbeddingGenerationService(
        string modelId,
        string apiKey,
        string? organization = null,
        HttpClient? httpClient = null,
        ILoggerFactory? loggerFactory = null)
    {
        ILogger? logger = BuildLogger(loggerFactory);
        this._client = new OpenAIClientCore(modelId, apiKey, organization, httpClient, logger);
        this._client.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
    }

    /// <summary>
    /// Initializes the OpenAI text embedding connector on top of an existing <see cref="OpenAIClient"/>.
    /// </summary>
    /// <param name="modelId">Name of the embedding model; also recorded under the model-id attribute.</param>
    /// <param name="openAIClient">Custom <see cref="OpenAIClient"/> used for HTTP requests.</param>
    /// <param name="loggerFactory">The <see cref="ILoggerFactory"/> used for logging. When null, no logger is passed to the core.</param>
    public OpenAITextEmbeddingGenerationService(
        string modelId,
        OpenAIClient openAIClient,
        ILoggerFactory? loggerFactory = null)
    {
        ILogger? logger = BuildLogger(loggerFactory);
        this._client = new OpenAIClientCore(modelId, openAIClient, logger);
        this._client.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
    }

    /// <inheritdoc/>
    public IReadOnlyDictionary<string, object?> Attributes
    {
        get { return this._client.Attributes; }
    }

    /// <inheritdoc/>
    public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
        IList<string> data,
        Kernel? kernel = null,
        CancellationToken cancellationToken = default)
    {
        // Kept as a direct call from this method: the core's logging helper may capture
        // caller information (e.g. via [CallerMemberName]) — TODO confirm against OpenAIClientCore.
        this._client.LogActionDetails();

        // Not awaited here: the core's task is handed back unchanged to the caller.
        var embeddingsTask = this._client.GetEmbeddingsAsync(data, kernel, cancellationToken);
        return embeddingsTask;
    }

    /// <summary>
    /// Creates a logger categorized by this service type, or returns null when no factory was supplied.
    /// </summary>
    /// <param name="loggerFactory">Optional factory supplied by the caller.</param>
    /// <returns>A logger for <see cref="OpenAITextEmbeddingGenerationService"/>, or null.</returns>
    private static ILogger? BuildLogger(ILoggerFactory? loggerFactory)
    {
        return loggerFactory?.CreateLogger(typeof(OpenAITextEmbeddingGenerationService));
    }
}