Skip to content

Commit

Permalink
Add support for creating embeddings with OpenAI API. (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcominerva committed Oct 9, 2023
2 parents 26a1f4c + 4c3ebf6 commit fc3ea4f
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 0 deletions.
48 changes: 48 additions & 0 deletions src/ChatGptNet/ChatGptClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ public async Task<Guid> SetupAsync(Guid conversationId, string message, Cancella
return conversationId;
}

public async Task<ChatGptEmbeddingResponse> CreateEmbeddingAsync(string message, string[]? messages = null, string? model = null, CancellationToken cancellationToken = default)
{
if (message == null && (messages == null || messages.Length < 1))
throw new ArgumentException($"Either ${nameof(message)} or {nameof(messages)} must be supplied");
var request = CreateEmbeddingsRequest(message, messages, model);
var requestUri = options.ServiceConfiguration.GetEmbeddingsEndpoint(model ?? options.DefaultModel);
using var httpResponse = await httpClient.PostAsJsonAsync(requestUri, request, jsonSerializerOptions, cancellationToken);

var response = await httpResponse.Content.ReadFromJsonAsync<ChatGptEmbeddingResponse>(jsonSerializerOptions, cancellationToken: cancellationToken);
NormalizeEmbeddingResponse(httpResponse, response!, model ?? options.DefaultModel);

if (!response!.IsSuccessful && options.ThrowExceptionOnError)
{
throw new ChatGptException(response.Error, httpResponse.StatusCode);
}

return response;
}

public async Task<ChatGptResponse> AskAsync(Guid conversationId, string message, ChatGptFunctionParameters? functionParameters = null, ChatGptParameters? parameters = null, string? model = null, bool addToConversationHistory = true, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(message);
Expand Down Expand Up @@ -301,6 +320,13 @@ private async Task<IList<ChatGptMessage>> CreateMessageListAsync(Guid conversati
return messages;
}

private ChatGptEmbeddingsRequest CreateEmbeddingsRequest(string? message, string[]? messages, string? model = null)
=> new()
{
Model = model ?? options.DefaultModel,
Input = message != null ? new string[] { message } : messages
};

private ChatGptRequest CreateRequest(IEnumerable<ChatGptMessage> messages, ChatGptFunctionParameters? functionParameters, bool stream, ChatGptParameters? parameters, string? model)
=> new()
{
Expand Down Expand Up @@ -379,4 +405,26 @@ private static void NormalizeResponse(HttpResponseMessage httpResponse, ChatGptR
response.Error.StatusCode = (int)httpResponse.StatusCode;
}
}

private static void NormalizeEmbeddingResponse(HttpResponseMessage httpResponse, ChatGptEmbeddingResponse response, string? model)
{
if (string.IsNullOrWhiteSpace(response.Model) && model is not null)
{
response.Model = model;
}

if (!httpResponse.IsSuccessStatusCode && response.Error is null)
{
response.Error = new ChatGptError
{
Message = httpResponse.ReasonPhrase ?? httpResponse.StatusCode.ToString(),
Code = ((int)httpResponse.StatusCode).ToString()
};
}

if (response.Error is not null)
{
response.Error.StatusCode = (int)httpResponse.StatusCode;
}
}
}
22 changes: 22 additions & 0 deletions src/ChatGptNet/Models/ChatGptEmbedding.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
namespace ChatGptNet.Models;

/// <summary>
/// Represents an embedding.
/// </summary>
public class ChatGptEmbedding
{
/// <summary>
/// Gets or sets the index of the embedding.
/// </summary>
public int Index { get; set; } = 0;

/// <summary>
/// Gets or sets the source object for this response.
/// </summary>
public string Object { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the embedding data
/// </summary>
public float[] Embedding { get; set; } = Array.Empty<float>();
}
48 changes: 48 additions & 0 deletions src/ChatGptNet/Models/ChatGptEmbeddingResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using System.Diagnostics.CodeAnalysis;

namespace ChatGptNet.Models;

/// <summary>
/// Represents an embedding response.
/// </summary>
public class ChatGptEmbeddingResponse
{
/// <summary>
/// Gets or sets the Id of the response.
/// </summary>
public string Id { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the source object for this response.
/// </summary>
public string Object { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the model name that has been used to generate the response.
/// </summary>
public string Model { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the error occurred during the chat completion execution, if any.
/// </summary>
public ChatGptError? Error { get; set; }

/// <summary>
/// Gets or sets information about token usage.
/// </summary>
/// <remarks>
/// The <see cref="Usage"/> property is always <see langword="null"/> when requesting response streaming with <see cref="ChatGptClient.AskStreamAsync(Guid, string, ChatGptParameters?, string?, bool, CancellationToken)"/>.
/// </remarks>
public ChatGptUsage? Usage { get; set; }

/// <summary>
/// Gets a value that determines if the response was successful.
/// </summary>
[MemberNotNullWhen(false, nameof(Error))]
public bool IsSuccessful => Error is null;

/// <summary>
/// Array of Embedding objects created by ChatGpt
/// </summary>
public ChatGptEmbedding[] Data { get; set; } = Array.Empty<ChatGptEmbedding>();
}
23 changes: 23 additions & 0 deletions src/ChatGptNet/Models/ChatGptEmbeddingsRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
namespace ChatGptNet.Models;

/// <summary>
/// Represents a request for create embeddings request.
/// </summary>
/// <remarks>
/// See <see href="https://platform.openai.com/docs/api-reference/embeddings/create">Create embeddings (OpenAI)</see> or <see href="https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=console">Embeddings basics (Azure)</see> for more information.
/// </remarks>
internal class ChatGptEmbeddingsRequest
{
/// <summary>
/// Gets or sets the ID of the model to use.
/// </summary>
public string? Model { get; set; }

/// <summary>
/// Gets or sets the messages array to generate embeddings for.
/// </summary>
/// <seealso cref="Input"/>
public string[]? Input { get; set; }

public string? User { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ public override Uri GetServiceEndpoint(string? modelName)
return endpoint;
}

/// <inheritdoc />
public override Uri GetEmbeddingsEndpoint(string? modelName = null) => throw new NotImplementedException();

/// <inheritdoc />
public override IDictionary<string, string?> GetRequestHeaders()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ internal abstract class ChatGptServiceConfiguration
/// <returns>The <see cref="Uri"/> of the service.</returns>
public abstract Uri GetServiceEndpoint(string? modelName = null);

/// <summary>
/// Returns the <see cref="Uri"/> that creates embeddings.
/// </summary>
/// <param name="modelName">The name of the model for embeddings.</param>
/// <returns>The <see cref="Uri"/> of the service.</returns>
public abstract Uri GetEmbeddingsEndpoint(string? modelName = null);

/// <summary>
/// Returns the headers that are required by the service to complete the request.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ public OpenAIChatGptServiceConfiguration(IConfiguration configuration)
/// <inheritdoc />
public override Uri GetServiceEndpoint(string? _) => new("https://api.openai.com/v1/chat/completions");

/// <inheritdoc />
public override Uri GetEmbeddingsEndpoint(string? _) => new("https://api.openai.com/v1/embeddings");

/// <inheritdoc />
public override IDictionary<string, string?> GetRequestHeaders()
{
Expand Down

0 comments on commit fc3ea4f

Please sign in to comment.