Skip to content

Commit

Permalink
Avoid boxing when comparing for equality
Browse files Browse the repository at this point in the history
  • Loading branch information
jnyrup committed May 1, 2021
1 parent 821e5e1 commit 2d865aa
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 89 deletions.
48 changes: 20 additions & 28 deletions Src/FluentAssertions/Collections/GenericCollectionAssertions.cs
Expand Up @@ -869,10 +869,11 @@ public AndConstraint<TAssertions> ContainInOrder(params T[] expected)
IList<T> expectedItems = expected.ConvertOrCastToList();
IList<T> actualItems = Subject.ConvertOrCastToList();

Func<T, T, bool> areSameOrEqual = ObjectExtensions.GetComparer<T>();
for (int index = 0; index < expectedItems.Count; index++)
{
T expectedItem = expectedItems[index];
actualItems = actualItems.SkipWhile(actualItem => !actualItem.IsSameOrEqualTo(expectedItem)).ToArray();
actualItems = actualItems.SkipWhile(actualItem => !areSameOrEqual(actualItem, expectedItem)).ToArray();
if (actualItems.Any())
{
actualItems = actualItems.Skip(1).ToArray();
Expand Down Expand Up @@ -1037,7 +1038,7 @@ public AndConstraint<TAssertions> EndWith(IEnumerable<T> expectation, string bec
return EndWith(null, because, becauseArgs);
}

AssertCollectionEndsWith(Subject, expectation.ConvertOrCastToCollection(), EqualityComparer<T>.Default.Equals, because, becauseArgs);
AssertCollectionEndsWith(Subject, expectation.ConvertOrCastToCollection(), (a, b) => EqualityComparer<T>.Default.Equals(a, b), because, becauseArgs);
return new AndConstraint<TAssertions>((TAssertions)this);
}

Expand Down Expand Up @@ -1083,7 +1084,7 @@ public AndConstraint<TAssertions> EndWith(IEnumerable<T> expectation, string bec
/// </param>
public AndConstraint<TAssertions> EndWith(T element, string because = "", params object[] becauseArgs)
{
AssertCollectionEndsWith(Subject, new[] { element }, (a, b) => a.IsSameOrEqualTo(b), because, becauseArgs);
AssertCollectionEndsWith(Subject, new[] { element }, ObjectExtensions.GetComparer<T>(), because, becauseArgs);
return new AndConstraint<TAssertions>((TAssertions)this);
}

Expand All @@ -1094,9 +1095,7 @@ public AndConstraint<TAssertions> EndWith(T element, string because = "", params
/// <param name="elements">A params array with the expected elements.</param>
public AndConstraint<TAssertions> Equal(params T[] elements)
{
Func<T, T, bool> comparer = GetComparer();

AssertSubjectEquality(elements, comparer, string.Empty);
AssertSubjectEquality(elements, ObjectExtensions.GetComparer<T>(), string.Empty);

return new AndConstraint<TAssertions>((TAssertions)this);
}
Expand Down Expand Up @@ -1140,7 +1139,7 @@ public AndConstraint<TAssertions> Equal(params T[] elements)
/// </param>
public AndConstraint<TAssertions> Equal(IEnumerable<T> expected, string because = "", params object[] becauseArgs)
{
AssertSubjectEquality(expected, (s, e) => s.IsSameOrEqualTo(e), because, becauseArgs);
AssertSubjectEquality(expected, ObjectExtensions.GetComparer<T>(), because, becauseArgs);

return new AndConstraint<TAssertions>((TAssertions)this);
}
Expand Down Expand Up @@ -1362,7 +1361,7 @@ public AndConstraint<TAssertions> HaveCountLessThan(int expected, string because
actual = Subject.ElementAt(index);

Execute.Assertion
.ForCondition(actual.IsSameOrEqualTo(element))
.ForCondition(ObjectExtensions.GetComparer<T>()(actual, element))
.BecauseOf(because, becauseArgs)
.FailWith("Expected {0} at index {1}{reason}, but found {2}.", element, index, actual);
}
Expand Down Expand Up @@ -1400,7 +1399,7 @@ public AndConstraint<TAssertions> HaveElementPreceding(T successor, T expectatio
.FailWith("but found nothing.")
.Then
.Given(() => PredecessorOf(successor, Subject))
.ForCondition(predecessor => predecessor.IsSameOrEqualTo(expectation))
.ForCondition(predecessor => ObjectExtensions.GetComparer<T>()(predecessor, expectation))
.FailWith("but found {0}.", predecessor => predecessor)
.Then
.ClearExpectation();
Expand Down Expand Up @@ -1432,7 +1431,7 @@ public AndConstraint<TAssertions> HaveElementSucceeding(T predecessor, T expecta
.FailWith("but found nothing.")
.Then
.Given(() => SuccessorOf(predecessor, Subject))
.ForCondition(successor => successor.IsSameOrEqualTo(expectation))
.ForCondition(successor => ObjectExtensions.GetComparer<T>()(successor, expectation))
.FailWith("but found {0}.", successor => successor)
.Then
.ClearExpectation();
Expand Down Expand Up @@ -2091,7 +2090,7 @@ public AndConstraint<TAssertions> NotContain(IEnumerable<T> unexpected, string b
int index = 0;
foreach (T actualItem in Subject)
{
var context = new EquivalencyValidationContext(Node.From<TExpectation>(CallerIdentifier.DetermineCallerIdentity))
var context = new EquivalencyValidationContext(Node.From<TExpectation>(() => CallerIdentifier.DetermineCallerIdentity()))
{
Subject = actualItem,
Expectation = unexpected,
Expand Down Expand Up @@ -2188,14 +2187,15 @@ public AndConstraint<TAssertions> NotContainInOrder(params T[] unexpected)
}

var actualItemsSkipped = 0;
Func<T, T, bool> areSameOrEqual = ObjectExtensions.GetComparer<T>();
for (int index = 0; index < unexpectedItems.Count; index++)
{
T unexpectedItem = unexpectedItems[index];

actualItems = actualItems.SkipWhile(actualItem =>
{
actualItemsSkipped++;
return !actualItem.IsSameOrEqualTo(unexpectedItem);
return !areSameOrEqual(actualItem, unexpectedItem);
}).ToArray();

if (actualItems.Any())
Expand Down Expand Up @@ -2282,7 +2282,7 @@ public AndConstraint<TAssertions> NotContainNulls(string because = "", params ob
}

int[] indices = Subject
.Select((item, index) => new { Item = item, Index = index })
.Select((item, index) => (Item: item, Index: index))
.Where(e => e.Item is null)
.Select(e => e.Index)
.ToArray();
Expand Down Expand Up @@ -2789,7 +2789,7 @@ public AndConstraint<TAssertions> StartWith(IEnumerable<T> expectation, string b
return StartWith(null, because, becauseArgs);
}

AssertCollectionStartsWith(Subject, expectation.ConvertOrCastToCollection(), EqualityComparer<T>.Default.Equals, because, becauseArgs);
AssertCollectionStartsWith(Subject, expectation.ConvertOrCastToCollection(), (a, b) => EqualityComparer<T>.Default.Equals(a, b), because, becauseArgs);
return new AndConstraint<TAssertions>((TAssertions)this);
}

Expand Down Expand Up @@ -2835,7 +2835,7 @@ public AndConstraint<TAssertions> StartWith(IEnumerable<T> expectation, string b
/// </param>
public AndConstraint<TAssertions> StartWith(T element, string because = "", params object[] becauseArgs)
{
AssertCollectionStartsWith(Subject, new[] { element }, (a, b) => a.IsSameOrEqualTo(b), because, becauseArgs);
AssertCollectionStartsWith(Subject, new[] { element }, ObjectExtensions.GetComparer<T>(), because, becauseArgs);
return new AndConstraint<TAssertions>((TAssertions)this);
}

Expand Down Expand Up @@ -3009,16 +3009,6 @@ protected static IEnumerable<TExpectation> RepeatAsManyAs<TExpectation>(TExpecta
.ClearExpectation();
}

private static Func<T, T, bool> GetComparer()
{
if (typeof(T).IsValueType)
{
return (T s, T e) => s.Equals(e);
}

return (T s, T e) => Equals(s, e);
}

private static string GetExpressionOrderString<TSelector>(Expression<Func<T, TSelector>> propertyExpression)
{
string orderString = propertyExpression.GetMemberPath().ToString();
Expand All @@ -3028,7 +3018,7 @@ private static string GetExpressionOrderString<TSelector>(Expression<Func<T, TSe
return orderString;
}

private static Type GetType(object o)
private static Type GetType<TType>(TType o)
{
return o is Type t ? t : o.GetType();
}
Expand Down Expand Up @@ -3165,10 +3155,11 @@ private bool IsValidProperty<TSelector>(Expression<Func<T, TSelector>> propertyE
: actualItems.OrderByDescending(item => item, comparer);

T[] orderedItems = ordering.ToArray();
Func<T, T, bool> areSameOrEqual = ObjectExtensions.GetComparer<T>();

for (int index = 0; index < orderedItems.Length; index++)
{
if (!actualItems[index].IsSameOrEqualTo(orderedItems[index]))
if (!areSameOrEqual(actualItems[index], orderedItems[index]))
{
Execute.Assertion
.BecauseOf(because, becauseArgs)
Expand Down Expand Up @@ -3209,8 +3200,9 @@ private AndConstraint<TAssertions> NotBeInOrder(IComparer<T> comparer, SortOrder
? actualItems.OrderBy(item => item, comparer).ToArray()
: actualItems.OrderByDescending(item => item, comparer).ToArray();

Func<T, T, bool> areSameOrEqual = ObjectExtensions.GetComparer<T>();
bool itemsAreUnordered = actualItems
.Where((actualItem, index) => !actualItem.IsSameOrEqualTo(orderedItems[index]))
.Where((actualItem, index) => !areSameOrEqual(actualItem, orderedItems[index]))
.Any();

if (!itemsAreUnordered)
Expand Down
21 changes: 13 additions & 8 deletions Src/FluentAssertions/Collections/GenericDictionaryAssertions.cs
Expand Up @@ -85,10 +85,11 @@ public GenericDictionaryAssertions(TCollection keyValuePairs)
additionalKeys);
}

Func<TValue, TValue, bool> areSameOrEqual = ObjectExtensions.GetComparer<TValue>();
foreach (var key in expectedKeys)
{
Execute.Assertion
.ForCondition(GetValue(Subject, key).IsSameOrEqualTo(GetValue(expected, key)))
.ForCondition(areSameOrEqual(GetValue(Subject, key), GetValue(expected, key)))
.BecauseOf(because, becauseArgs)
.FailWith("Expected {context:dictionary} to be equal to {0}{reason}, but {1} differs at key {2}.",
expected, Subject, key);
Expand Down Expand Up @@ -135,9 +136,10 @@ public GenericDictionaryAssertions(TCollection keyValuePairs)
IEnumerable<TKey> missingKeys = unexpectedKeys.Except(subjectKeys);
IEnumerable<TKey> additionalKeys = subjectKeys.Except(unexpectedKeys);

Func<TValue, TValue, bool> areSameOrEqual = ObjectExtensions.GetComparer<TValue>();
bool foundDifference = missingKeys.Any()
|| additionalKeys.Any()
|| subjectKeys.Any(key => !GetValue(Subject, key).IsSameOrEqualTo(GetValue(unexpected, key)));
|| subjectKeys.Any(key => !areSameOrEqual(GetValue(Subject, key), GetValue(unexpected, key)));

if (!foundDifference)
{
Expand Down Expand Up @@ -209,7 +211,7 @@ public GenericDictionaryAssertions(TCollection keyValuePairs)

EquivalencyAssertionOptions<TExpectation> options = config(AssertionOptions.CloneDefaults<TExpectation>());

var context = new EquivalencyValidationContext(Node.From<TExpectation>(CallerIdentifier.DetermineCallerIdentity))
var context = new EquivalencyValidationContext(Node.From<TExpectation>(() => CallerIdentifier.DetermineCallerIdentity()))
{
Subject = Subject,
Expectation = expectation,
Expand Down Expand Up @@ -521,7 +523,7 @@ public AndConstraint<TAssertions> ContainValues(params TValue[] expected)
IEnumerable<TValue> first, IEnumerable<TValue> second)
{
var secondSet = new HashSet<TValue>(second);
return first.Where(secondSet.Contains);
return first.Where(e => secondSet.Contains(e));
}

#endregion
Expand Down Expand Up @@ -693,7 +695,8 @@ public AndConstraint<TAssertions> Contain(params KeyValuePair<TKey, TValue>[] ex
}
}

KeyValuePair<TKey, TValue>[] keyValuePairsNotSameOrEqualInSubject = expectedKeyValuePairs.Where(keyValuePair => !GetValue(Subject, keyValuePair.Key).IsSameOrEqualTo(keyValuePair.Value)).ToArray();
Func<TValue, TValue, bool> areSameOrEqual = ObjectExtensions.GetComparer<TValue>();
KeyValuePair<TKey, TValue>[] keyValuePairsNotSameOrEqualInSubject = expectedKeyValuePairs.Where(keyValuePair => !areSameOrEqual(GetValue(Subject, keyValuePair.Key), keyValuePair.Value)).ToArray();

if (keyValuePairsNotSameOrEqualInSubject.Any())
{
Expand Down Expand Up @@ -765,8 +768,9 @@ public AndConstraint<TAssertions> Contain(params KeyValuePair<TKey, TValue>[] ex

if (TryGetValue(Subject, key, out TValue actual))
{
Func<TValue, TValue, bool> areSameOrEqual = ObjectExtensions.GetComparer<TValue>();
Execute.Assertion
.ForCondition(actual.IsSameOrEqualTo(value))
.ForCondition(areSameOrEqual(actual, value))
.BecauseOf(because, becauseArgs)
.FailWith("Expected {context:dictionary} to contain value {0} at key {1}{reason}, but found {2}.", value, key, actual);
}
Expand Down Expand Up @@ -833,8 +837,9 @@ public AndConstraint<TAssertions> NotContain(params KeyValuePair<TKey, TValue>[]

if (keyValuePairsFound.Any())
{
Func<TValue, TValue, bool> areSameOrEqual = ObjectExtensions.GetComparer<TValue>();
KeyValuePair<TKey, TValue>[] keyValuePairsSameOrEqualInSubject = keyValuePairsFound
.Where(keyValuePair => GetValue(Subject, keyValuePair.Key).IsSameOrEqualTo(keyValuePair.Value)).ToArray();
.Where(keyValuePair => areSameOrEqual(GetValue(Subject, keyValuePair.Key), keyValuePair.Value)).ToArray();

if (keyValuePairsSameOrEqualInSubject.Any())
{
Expand Down Expand Up @@ -906,7 +911,7 @@ public AndConstraint<TAssertions> NotContain(params KeyValuePair<TKey, TValue>[]
if (TryGetValue(Subject, key, out TValue actual))
{
Execute.Assertion
.ForCondition(!actual.IsSameOrEqualTo(value))
.ForCondition(!ObjectExtensions.GetComparer<TValue>()(actual, value))
.BecauseOf(because, becauseArgs)
.FailWith("Expected {context:dictionary} not to contain value {0} at key {1}{reason}, but found it anyhow.", value, key);
}
Expand Down
35 changes: 24 additions & 11 deletions Src/FluentAssertions/Common/DictionaryHelpers.cs
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Linq;

namespace FluentAssertions.Common
Expand Down Expand Up @@ -34,8 +35,14 @@ internal static class DictionaryHelpers
{
IDictionary<TKey, TValue> dictionary => dictionary.ContainsKey(key),
IReadOnlyDictionary<TKey, TValue> readOnlyDictionary => readOnlyDictionary.ContainsKey(key),
_ => collection.Any(kvp => kvp.Key.IsSameOrEqualTo(key)),
_ => ContainsKey(collection, key),
};

static bool ContainsKey(TCollection collection, TKey key)
{
Func<TKey, TKey, bool> areSameOrEqual = ObjectExtensions.GetComparer<TKey>();
return collection.Any(kvp => areSameOrEqual(kvp.Key, key));
}
}

public static bool TryGetValue<TCollection, TKey, TValue>(this TCollection collection, TKey key, out TValue value)
Expand All @@ -45,16 +52,16 @@ internal static class DictionaryHelpers
{
IDictionary<TKey, TValue> dictionary => dictionary.TryGetValue(key, out value),
IReadOnlyDictionary<TKey, TValue> readOnlyDictionary => readOnlyDictionary.TryGetValue(key, out value),
_ => TryGetValueInternal(collection, key, out value),
_ => TryGetValue(collection, key, out value),
};
}

private static bool TryGetValueInternal<TCollection, TKey, TValue>(this TCollection collection, TKey key, out TValue value)
where TCollection : IEnumerable<KeyValuePair<TKey, TValue>>
{
KeyValuePair<TKey, TValue> matchingPair = collection.FirstOrDefault(kvp => kvp.Key.IsSameOrEqualTo(key));
value = matchingPair.Value;
return matchingPair.Equals(default(KeyValuePair<TKey, TValue>));
static bool TryGetValue(TCollection collection, TKey key, out TValue value)
{
Func<TKey, TKey, bool> areSameOrEqual = ObjectExtensions.GetComparer<TKey>();
KeyValuePair<TKey, TValue> matchingPair = collection.FirstOrDefault(kvp => areSameOrEqual(kvp.Key, key));
value = matchingPair.Value;
return matchingPair.Equals(default(KeyValuePair<TKey, TValue>));
}
}

public static TValue GetValue<TCollection, TKey, TValue>(this TCollection collection, TKey key)
Expand All @@ -64,8 +71,14 @@ internal static class DictionaryHelpers
{
IDictionary<TKey, TValue> dictionary => dictionary[key],
IReadOnlyDictionary<TKey, TValue> readOnlyDictionary => readOnlyDictionary[key],
_ => collection.First(kvp => kvp.Key.IsSameOrEqualTo(key)).Value,
_ => GetValue(collection, key),
};

static TValue GetValue(TCollection collection, TKey key)
{
Func<TKey, TKey, bool> areSameOrEqual = ObjectExtensions.GetComparer<TKey>();
return collection.First(kvp => areSameOrEqual(kvp.Key, key)).Value;
}
}
}
}

0 comments on commit 2d865aa

Please sign in to comment.