Skip to content

Commit

Permalink
Make assert.xunit aot-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
agocke committed Mar 26, 2023
1 parent 1733ac1 commit dc6e247
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 178 deletions.
1 change: 1 addition & 0 deletions src/tests/Common/xunit/assert.xunit/Comparers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using Xunit.Sdk;

namespace Xunit
Expand Down
6 changes: 5 additions & 1 deletion src/tests/Common/xunit/assert.xunit/DXUnit.Assert.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
<!-- dotnet/runtime fork of xunit.assert, with modifications for AOT-compatibilty -->
<PropertyGroup>
<OutputType>Library</OutputType>
<TargetFrameworks>$(NetCoreAppMinimum)</TargetFrameworks>
<TargetFrameworks>$(NetCoreAppCurrent);$(NetCoreAppMinimum)</TargetFrameworks>
<DefineConstants>$(DefineConstants);XUNIT_NULLABLE</DefineConstants>

<IsTrimmable>true</IsTrimmable>
<EnableAotAnalyzer>true</EnableAotAnalyzer>
<EnableSingleFileAnalyzer>true</EnableSingleFileAnalyzer>

<!-- Baselining warnings -->
<NoWarn>$(NoWarn);SA1400;CA1852;CA1859;CA2007;SA1121;CA2249;CA1845;CA1822;CA1823;IDE0020;IDE0054;IDE0031;IDE0059;CA1510;CA1805;CA1825;IDE0036;IDE0074</NoWarn>
</PropertyGroup>
Expand Down
6 changes: 6 additions & 0 deletions src/tests/Common/xunit/assert.xunit/Sdk/ArgumentFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Reflection;
Expand Down Expand Up @@ -225,6 +226,11 @@ static string Format(
static string FormatComplexValue(
object value,
int depth,
[DynamicallyAccessedMembers(
DynamicallyAccessedMemberTypes.PublicProperties
| DynamicallyAccessedMemberTypes.NonPublicProperties
| DynamicallyAccessedMemberTypes.PublicFields
| DynamicallyAccessedMemberTypes.NonPublicFields)]
Type type)
{
if (depth == MAX_DEPTH)
Expand Down
173 changes: 2 additions & 171 deletions src/tests/Common/xunit/assert.xunit/Sdk/AssertEqualityComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#if XUNIT_NULLABLE
using System.Diagnostics.CodeAnalysis;
using System.Threading.Tasks;
#endif

namespace Xunit.Sdk
Expand Down Expand Up @@ -116,11 +117,6 @@ public bool Equals(
if (dictionariesEqual.HasValue)
return dictionariesEqual.GetValueOrDefault();

// Sets?
var setsEqual = CheckIfSetsAreEqual(x, y, typeInfo);
if (setsEqual.HasValue)
return setsEqual.GetValueOrDefault();

// Enumerable?
var enumerablesEqual = CheckIfEnumerablesAreEqual(x, y, out mismatchIndex);
if (enumerablesEqual.HasValue)
Expand Down Expand Up @@ -150,48 +146,9 @@ public bool Equals(

// Implements IStructuralEquatable?
var structuralEquatable = x as IStructuralEquatable;
if (structuralEquatable != null && structuralEquatable.Equals(y, new TypeErasedEqualityComparer(innerComparerFactory())))
if (structuralEquatable != null && structuralEquatable.Equals(y, EqualityComparer<T>.Default))
return true;

// Implements IEquatable<typeof(y)>?
var iequatableY = typeof(IEquatable<>).MakeGenericType(y.GetType()).GetTypeInfo();
if (iequatableY.IsAssignableFrom(x.GetType().GetTypeInfo()))
{
var equalsMethod = iequatableY.GetDeclaredMethod(nameof(IEquatable<T>.Equals));
if (equalsMethod == null)
return false;

#if XUNIT_NULLABLE
return equalsMethod.Invoke(x, new object[] { y }) is true;
#else
return (bool)equalsMethod.Invoke(x, new object[] { y });
#endif
}

// Implements IComparable<typeof(y)>?
var icomparableY = typeof(IComparable<>).MakeGenericType(y.GetType()).GetTypeInfo();
if (icomparableY.IsAssignableFrom(x.GetType().GetTypeInfo()))
{
var compareToMethod = icomparableY.GetDeclaredMethod(nameof(IComparable<T>.CompareTo));
if (compareToMethod == null)
return false;

try
{
#if XUNIT_NULLABLE
return compareToMethod.Invoke(x, new object[] { y }) is 0;
#else
return (int)compareToMethod.Invoke(x, new object[] { y }) == 0;
#endif
}
catch
{
// Some implementations of IComparable.CompareTo throw exceptions in
// certain situations, such as if x can't compare against y.
// If this happens, just swallow up the exception and continue comparing.
}
}

// Last case, rely on object.Equals
return object.Equals(x, y);
}
Expand Down Expand Up @@ -296,136 +253,10 @@ public bool Equals(
return dictionaryYKeys.Count == 0;
}

#if XUNIT_NULLABLE
static MethodInfo? s_compareTypedSetsMethod;
#else
static MethodInfo s_compareTypedSetsMethod;
#endif

bool? CheckIfSetsAreEqual(
#if XUNIT_NULLABLE
[AllowNull] T x,
[AllowNull] T y,
#else
T x,
T y,
#endif
TypeInfo typeInfo)
{
if (!IsSet(typeInfo))
return null;

var enumX = x as IEnumerable;
var enumY = y as IEnumerable;
if (enumX == null || enumY == null)
return null;

Type elementType;
if (typeof(T).GenericTypeArguments.Length != 1)
elementType = typeof(object);
else
elementType = typeof(T).GenericTypeArguments[0];

if (s_compareTypedSetsMethod == null)
{
s_compareTypedSetsMethod = GetType().GetTypeInfo().GetDeclaredMethod(nameof(CompareTypedSets));
if (s_compareTypedSetsMethod == null)
return false;
}

var method = s_compareTypedSetsMethod.MakeGenericMethod(new Type[] { elementType });
#if XUNIT_NULLABLE
return method.Invoke(this, new object[] { enumX, enumY }) is true;
#else
return (bool)method.Invoke(this, new object[] { enumX, enumY });
#endif
}

bool CompareTypedSets<R>(
IEnumerable enumX,
IEnumerable enumY)
{
var setX = new HashSet<R>(enumX.Cast<R>());
var setY = new HashSet<R>(enumY.Cast<R>());

return setX.SetEquals(setY);
}

bool IsSet(TypeInfo typeInfo) =>
typeInfo
.ImplementedInterfaces
.Select(i => i.GetTypeInfo())
.Where(ti => ti.IsGenericType)
.Select(ti => ti.GetGenericTypeDefinition())
.Contains(typeof(ISet<>).GetGenericTypeDefinition());

/// <inheritdoc/>
public int GetHashCode(T obj)
{
throw new NotImplementedException();
}

private class TypeErasedEqualityComparer : IEqualityComparer
{
readonly IEqualityComparer innerComparer;

public TypeErasedEqualityComparer(IEqualityComparer innerComparer)
{
this.innerComparer = innerComparer;
}

#if XUNIT_NULLABLE
static MethodInfo? s_equalsMethod;
#else
static MethodInfo s_equalsMethod;
#endif

public new bool Equals(
#if XUNIT_NULLABLE
object? x,
object? y)
#else
object x,
object y)
#endif
{
if (x == null)
return y == null;
if (y == null)
return false;

// Delegate checking of whether two objects are equal to AssertEqualityComparer.
// To get the best result out of AssertEqualityComparer, we attempt to specialize the
// comparer for the objects that we are checking.
// If the objects are the same, great! If not, assume they are objects.
// This is more naive than the C# compiler which tries to see if they share any interfaces
// etc. but that's likely overkill here as AssertEqualityComparer<object> is smart enough.
Type objectType = x.GetType() == y.GetType() ? x.GetType() : typeof(object);

// Lazily initialize and cache the EqualsGeneric<U> method.
if (s_equalsMethod == null)
{
s_equalsMethod = typeof(TypeErasedEqualityComparer).GetTypeInfo().GetDeclaredMethod(nameof(EqualsGeneric));
if (s_equalsMethod == null)
return false;
}

#if XUNIT_NULLABLE
return s_equalsMethod.MakeGenericMethod(objectType).Invoke(this, new object[] { x, y }) is true;
#else
return (bool)s_equalsMethod.MakeGenericMethod(objectType).Invoke(this, new object[] { x, y });
#endif
}

bool EqualsGeneric<U>(
U x,
U y) =>
new AssertEqualityComparer<U>(innerComparer: innerComparer).Equals(x, y);

public int GetHashCode(object obj)
{
throw new NotImplementedException();
}
}
}
}
15 changes: 9 additions & 6 deletions src/tests/Common/xunit/assert.xunit/Sdk/AssertHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ internal static class AssertHelper
static ConcurrentDictionary<Type, Dictionary<string, Func<object, object>>> gettersByType = new ConcurrentDictionary<Type, Dictionary<string, Func<object, object>>>();
#endif

#if XUNIT_NULLABLE
static Dictionary<string, Func<object?, object?>> GetGettersForType(Type type) =>
#else
static Dictionary<string, Func<object, object>> GetGettersForType(Type type) =>
#endif
gettersByType.GetOrAdd(type, _type =>
static Dictionary<string, Func<object?, object?>> GetGettersForType(
Type type) =>
gettersByType.GetOrAdd(type,
([DynamicallyAccessedMembers(
DynamicallyAccessedMemberTypes.PublicProperties
| DynamicallyAccessedMemberTypes.NonPublicProperties
| DynamicallyAccessedMemberTypes.PublicFields
| DynamicallyAccessedMemberTypes.NonPublicFields)]
Type _type) =>
{
var fieldGetters =
_type
Expand Down

0 comments on commit dc6e247

Please sign in to comment.