Skip to content

Commit

Permalink
Some behavioral / perf fixes for Embedding<T> (#567)
Browse files Browse the repository at this point in the history
### Motivation and Context

Clean up some `Embedding<T>` behaviors / perf issues.

### Description

- IsSupported was more expensive than it needs to be: if it's changed to
just compare the types directly rather than searching a list, the JIT
can turn the entire operation into a JIT-time constant.
- The explicit cast to `ReadOnlySpan<T>` was cloning the array. It
doesn't need to do that.
- `default(Embedding<T>)` results in a struct that fails on a bunch of
operations with NullReferenceExceptions if you try to use it. We've
learned from experience with `ImmutableArray<T>` this is suboptimal.
I've made it functional and behave equivalently to Empty... in fact, Empty
is now just `default`.
  • Loading branch information
stephentoub committed Apr 26, 2023
1 parent a0aa0c5 commit 15dc234
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ internal OpenAIClientBase(ILogger? log = null, IDelegatingHandlerFactory? handle
"Embeddings not found");
}

return result.Embeddings.Select(e => new Embedding<float>(e.Values.ToArray())).ToList();
return result.Embeddings.Select(e => new Embedding<float>(e.Values)).ToList();
}
catch (Exception e) when (e is not AIException)
{
Expand Down
56 changes: 33 additions & 23 deletions dotnet/src/SemanticKernel.Abstractions/AI/Embeddings/Embedding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text.Json.Serialization;
Expand All @@ -21,16 +20,17 @@ namespace Microsoft.SemanticKernel.AI.Embeddings;
/// An empty <see cref="Embedding{TEmbedding}"/> instance.
/// </summary>
[SuppressMessage("Design", "CA1000:Do not declare static members on generic types", Justification = "Static empty struct instance.")]
public static Embedding<TEmbedding> Empty { get; } = new Embedding<TEmbedding>(Array.Empty<TEmbedding>());

/// <summary>
/// Initializes a new instance of the <see cref="Embedding{TEmbedding}"/> class that contains numeric elements copied from the specified collection.
/// </summary>
/// <exception cref="ArgumentException">Type <typeparamref name="TEmbedding"/> is unsupported.</exception>
/// <exception cref="ArgumentNullException">A <c>null</c> vector is passed in.</exception>
public Embedding()
: this(Array.Empty<TEmbedding>())
public static Embedding<TEmbedding> Empty
{
    get
    {
        // Supported element types take the fast path: the empty embedding is simply the default struct value.
        if (IsSupported)
        {
            return default;
        }

        // Unsupported element types are rejected through the shared throw helper.
        ThrowNotSupportedEmbedding();
        return default;
    }
}

/// <summary>
Expand All @@ -46,38 +46,47 @@ public Embedding(IEnumerable<TEmbedding> vector)

if (!IsSupported)
{
throw new NotSupportedException($"Embeddings do not support type '{typeof(TEmbedding).Name}'. "
+ $"Supported types include: [ {string.Join(", ", Embedding.SupportedTypes.Select(t => t.Name).ToList())} ]");
ThrowNotSupportedEmbedding();
}

// Create a local, protected copy
this._vector = vector.ToArray();
}

/// <summary>
/// Throws a <see cref="NotSupportedException"/> whose message names the unsupported <typeparamref name="TEmbedding"/>.
/// </summary>
/// <exception cref="NotSupportedException">Always thrown.</exception>
private static void ThrowNotSupportedEmbedding()
{
    throw new NotSupportedException($"Embeddings do not support type '{typeof(TEmbedding).Name}'. Supported types include: [ Single, Double ]");
}

/// <summary>
/// Gets the vector as a <see cref="ReadOnlyCollection{T}"/>
/// Gets the vector as an <see cref="IEnumerable{TEmbedding}"/>
/// </summary>
[JsonPropertyName("vector")]
public IEnumerable<TEmbedding> Vector => this._vector.AsEnumerable();
public IEnumerable<TEmbedding> Vector => this._vector ?? Array.Empty<TEmbedding>();

/// <summary>
/// Gets a value that indicates whether <typeparamref name="TEmbedding"/> is supported.
/// </summary>
[JsonIgnore]
[SuppressMessage("Design", "CA1000:Do not declare static members on generic types", Justification = "Following 'IsSupported' pattern of System.Numerics.")]
public static bool IsSupported => Embedding.SupportedTypes.Contains(typeof(TEmbedding));
public static bool IsSupported => typeof(TEmbedding) == typeof(float) || typeof(TEmbedding) == typeof(double);

/// <summary>
/// <c>true</c> if the vector is empty.
/// </summary>
[JsonIgnore]
public bool IsEmpty => this._vector.Length == 0;
// A missing backing array (default struct value) and a zero-length array both count as empty.
public bool IsEmpty => this._vector is null or { Length: 0 };

/// <summary>
/// The number of elements in the vector.
/// </summary>
[JsonIgnore]
public int Count => this._vector.Length;
public int Count => this._vector?.Length ?? 0;

/// <summary>
/// Gets the vector as a read-only span.
Expand All @@ -93,7 +102,7 @@ public ReadOnlySpan<TEmbedding> AsReadOnlySpan()
/// <returns>A hash code for the current object.</returns>
public override int GetHashCode()
{
return this._vector.GetHashCode();
return this._vector?.GetHashCode() ?? 0;
}

/// <summary>
Expand All @@ -103,7 +112,7 @@ public override int GetHashCode()
/// <returns><c>true</c> if the specified object is equal to the current object; otherwise, <c>false</c>.</returns>
public override bool Equals(object obj)
{
return (obj is Embedding<TEmbedding> other) && this.Equals(other);
return obj is Embedding<TEmbedding> other && this.Equals(other);
}

/// <summary>
Expand All @@ -113,7 +122,8 @@ public override bool Equals(object obj)
/// <returns><c>true</c> if the specified object is equal to the current object; otherwise, <c>false</c>.</returns>
public bool Equals(Embedding<TEmbedding> other)
{
return this._vector.Equals(other._vector);
TEmbedding[]? vector = this._vector;
return vector is null ? other._vector is null : vector.Equals(other._vector);
}

/// <summary>
Expand Down Expand Up @@ -154,7 +164,7 @@ public bool Equals(Embedding<TEmbedding> other)
/// <remarks>A clone of the underlying data.</remarks>
public static explicit operator TEmbedding[](Embedding<TEmbedding> embedding)
{
return (TEmbedding[])embedding._vector.Clone();
return embedding._vector is null ? Array.Empty<TEmbedding>() : (TEmbedding[])embedding._vector.Clone();
}

/// <summary>
Expand All @@ -164,12 +174,12 @@ public bool Equals(Embedding<TEmbedding> other)
/// <remarks>A read-only view over the underlying data; the data is not cloned.</remarks>
public static explicit operator ReadOnlySpan<TEmbedding>(Embedding<TEmbedding> embedding)
{
return (TEmbedding[])embedding._vector.Clone();
return embedding.AsReadOnlySpan();
}

#region private ================================================================================

private readonly TEmbedding[] _vector;
private readonly TEmbedding[]? _vector;

#endregion
}
Expand Down
9 changes: 5 additions & 4 deletions dotnet/src/SemanticKernel.Abstractions/Memory/IMemoryStore.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -100,10 +101,10 @@ public interface IMemoryStore
Task RemoveBatchAsync(string collectionName, IEnumerable<string> keys, CancellationToken cancel = default);

/// <summary>
/// Gets the nearest matches to the <see cref="Embedding"/> of type <see cref="float"/>. Does not guarantee that the collection exists.
/// Gets the nearest matches to the <see cref="Embedding{Single}"/> of type <see cref="float"/>. Does not guarantee that the collection exists.
/// </summary>
/// <param name="collectionName">The name associated with a collection of embeddings.</param>
/// <param name="embedding">The <see cref="Embedding"/> to compare the collection's embeddings with.</param>
/// <param name="embedding">The <see cref="Embedding{Single}"/> to compare the collection's embeddings with.</param>
/// <param name="limit">The maximum number of similarity results to return.</param>
/// <param name="minRelevanceScore">The minimum relevance threshold for returned results.</param>
/// <param name="withEmbeddings">If true, the embeddings will be returned in the memory records.</param>
Expand All @@ -118,10 +119,10 @@ public interface IMemoryStore
CancellationToken cancel = default);

/// <summary>
/// Gets the nearest match to the <see cref="Embedding"/> of type <see cref="float"/>. Does not guarantee that the collection exists.
/// Gets the nearest match to the <see cref="Embedding{Single}"/> of type <see cref="float"/>. Does not guarantee that the collection exists.
/// </summary>
/// <param name="collectionName">The name associated with a collection of embeddings.</param>
/// <param name="embedding">The <see cref="Embedding"/> to compare the collection's embeddings with.</param>
/// <param name="embedding">The <see cref="Embedding{Single}"/> to compare the collection's embeddings with.</param>
/// <param name="minRelevanceScore">The minimum relevance threshold for returned results.</param>
/// <param name="withEmbedding">If true, the embedding will be returned in the memory record.</param>
/// <param name="cancel">Cancellation token</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

using System;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text.Json;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Diagnostics;
Expand All @@ -15,6 +17,54 @@ public class EmbeddingTests
private readonly float[] _vector = new float[] { 0, 3, -4 };
private readonly float[] _empty = Array.Empty<float>();

/// <summary>
/// Verifies that a <c>default</c> <c>Embedding&lt;float&gt;</c> behaves as an empty embedding
/// for every member exercised below, without throwing.
/// </summary>
[Fact]
public void ItTreatsDefaultEmbeddingAsEmpty()
{
// Arrange
Embedding<float> target = default;

// Assert
// Size-related members report "no elements".
Assert.True(target.IsEmpty);
Assert.Equal(0, target.Count);
Assert.Empty(target.Vector);
// The enumerable and the array cast are expected to hand back the shared empty array instance.
Assert.Same(Array.Empty<float>(), target.Vector);
Assert.Same(Array.Empty<float>(), (float[])target);
// Both span access paths yield an empty span.
Assert.True(target.AsReadOnlySpan().IsEmpty);
Assert.True(((ReadOnlySpan<float>)target).IsEmpty);
// Equality, via Equals and via the operators, treats default, Empty and a parameterless-constructed value alike.
Assert.True(target.Equals(Embedding<float>.Empty));
Assert.True(target.Equals(new Embedding<float>()));
Assert.True(target == Embedding<float>.Empty);
Assert.True(target == new Embedding<float>());
Assert.False(target != Embedding<float>.Empty);
// The hash code of a default instance is expected to be 0.
Assert.Equal(0, target.GetHashCode());
}

/// <summary>
/// Verifies that constructing an embedding over an unsupported element type (<see cref="int"/>)
/// throws <see cref="NotSupportedException"/>, for both non-empty and empty input.
/// </summary>
[Fact]
public void ItThrowsFromCtorWithUnsupportedType()
{
// Assert
Assert.Throws<NotSupportedException>(() => new Embedding<int>(new int[] { 1, 2, 3 }));
Assert.Throws<NotSupportedException>(() => new Embedding<int>(Array.Empty<int>()));
}

/// <summary>
/// Verifies that reading <c>Empty</c> for an unsupported element type (<see cref="int"/>)
/// throws <see cref="NotSupportedException"/>.
/// </summary>
[Fact]
public void ItThrowsFromEmptyWithUnsupportedType()
{
// Assert
Assert.Throws<NotSupportedException>(() => Embedding<int>.Empty);
}

/// <summary>
/// Verifies that a <c>default</c> embedding over an unsupported element type (<see cref="int"/>)
/// can still be queried through <c>IsEmpty</c> and <c>Count</c> without throwing.
/// </summary>
[Fact]
public void ItAllowsUnsupportedTypesOnEachOperation()
{
// Arrange
Embedding<int> target = default;

// Assert
Assert.True(target.IsEmpty);
Assert.Equal(0, target.Count);
}

[Fact]
public void ItThrowsWithNullVector()
{
Expand All @@ -31,6 +81,7 @@ public void ItCreatesEmptyEmbedding()
// Assert
Assert.Empty(target.Vector);
Assert.Equal(0, target.Count);
Assert.False(Embedding<int>.IsSupported);
}

[Fact]
Expand All @@ -39,10 +90,6 @@ public void ItCreatesExpectedEmbedding()
// Arrange
var target = new Embedding<float>(this._vector);

// Act
// TODO: copy is never used - bug?
var copy = target;

// Assert
Assert.True(target.Vector.SequenceEqual(this._vector));
}
Expand All @@ -60,4 +107,19 @@ public void ItSerializesEmbedding()
// Assert
Assert.True(copy.Vector.SequenceEqual(this._vector));
}

/// <summary>
/// Verifies that <c>AsReadOnlySpan()</c> and the explicit <see cref="ReadOnlySpan{T}"/> cast
/// return spans over the same memory, i.e. the cast does not produce a separate copy.
/// </summary>
[Fact]
public void ItDoesntCopyVectorWhenCastingToSpan()
{
// Arrange
var target = new Embedding<float>(this._vector);

// Act
ReadOnlySpan<float> span1 = target.AsReadOnlySpan();
ReadOnlySpan<float> span2 = (ReadOnlySpan<float>)target;

// Assert
// The embedding is expected to hold its own storage, so its span must not alias the array passed to the constructor.
Assert.False(Unsafe.AreSame(ref MemoryMarshal.GetReference(span1), ref MemoryMarshal.GetArrayDataReference(this._vector)));
// Both access paths must start at the same element, which rules out a clone on the cast path.
Assert.True(Unsafe.AreSame(ref MemoryMarshal.GetReference(span1), ref MemoryMarshal.GetReference(span2)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ private async Task<IList<Embedding<float>>> ExecuteEmbeddingRequestAsync(IList<s

var embeddingResponse = JsonSerializer.Deserialize<TextEmbeddingResponse>(body);

return embeddingResponse?.Embeddings?.Select(l => new Embedding<float>(l.Embedding.ToArray())).ToList()!;
return embeddingResponse?.Embeddings?.Select(l => new Embedding<float>(l.Embedding!)).ToList()!;
}
catch (Exception e) when (e is not AIException && !e.IsCriticalException())
{
Expand Down

0 comments on commit 15dc234

Please sign in to comment.