Skip to content

Commit

Permalink
Allow IAsyncEnumerable for ValueSource/TestCaseSource
Browse files Browse the repository at this point in the history
  • Loading branch information
manfred-brands committed Nov 14, 2023
1 parent 5a5415f commit aeda665
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@ class MyTests
RoslynAssert.Diagnostics(analyzer, expectedDiagnostic, testCode);
}

[Test]
public void AnalyzeWhenTypeOfSourceImplementsIAsyncEnumerable()
{
var testCode = TestUtility.WrapClassInNamespaceAndAddUsing(@"
public class AnalyzeWhenTypeOfSourceImplementsIAsyncEnumerable
{
[TestCaseSource(typeof(MyTests))]
public void Test(int i)
{
}
}
public sealed class MyTests : IAsyncEnumerable<int>
{
public IAsyncEnumerator<int> GetAsyncEnumerator(CancellationToken cancellationToken = default(CancellationToken))
{
throw new NotImplementedException();
}
}",
additionalUsings: "using System.Collections.Generic;using System.Threading;");

RoslynAssert.Valid(analyzer, testCode);
}

[Test]
public void AnalyzeWhenTypeOfSourceNoDefaultConstructor()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ public static IEnumerable<int> TestData(string dummy, int anotherDummy)
[TestCase("private static TestCaseData[] TestCases => new TestCaseData[0];")]
[TestCase("private static TestCaseData[] TestCases() => new TestCaseData[0];")]
[TestCase("private static Task<TestCaseData[]> TestCases() => Task.FromResult(new TestCaseData[0]);")]
[TestCase("private static async IAsyncEnumerable<int> TestCases() { foreach (var value in new[] { 0 }) { yield return value; await Task.Yield(); } }")]
public void AnalyzeWhenSourceDoesProvideIEnumerable(string testCaseMember)
{
var testCode = TestUtility.WrapClassInNamespaceAndAddUsing($@"
Expand All @@ -210,7 +211,7 @@ public class AnalyzeWhenSourceDoesProvideIEnumerable
public void Test([ValueSource(nameof(TestCases))] int number)
{{
}}
}}");
}}", additionalUsings: "using System.Collections.Generic;");

RoslynAssert.Valid(analyzer, testCode);
}
Expand All @@ -230,7 +231,7 @@ public void Test([ValueSource(nameof(TestCases))] int number)
public void AnalyzeWhenSourceDoesNotProvideIEnumerable(string testCaseMember, string returnType)
{
var testCode = TestUtility.WrapClassInNamespaceAndAddUsing($@"
public class AnalyzeWhenSourceDoesProvideIEnumerable
public class AnalyzeWhenSourceDoesNotProvideIEnumerable
{{
#pragma warning disable CS0414 // Consider changing the field to a 'const'
{testCaseMember}
Expand Down
34 changes: 34 additions & 0 deletions src/nunit.analyzers/Extensions/ITypeSymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,40 @@ internal static bool IsIEnumerable(this ITypeSymbol @this, out ITypeSymbol? elem
return false;
}

/// <summary>
/// Return value indicates whether type implements I(Async)Enumerable{T} interface.
/// </summary>
/// <param name="elementType">Contains I(Async)Enumerable generic argument, or null, if type implements
/// only non-generic IEnumerable interface, or no I(Async)Enumerable interface at all.</param>
internal static bool IsIEnumerableOrIAsyncEnumerable(this ITypeSymbol @this, out ITypeSymbol? elementType)
{
elementType = null;

var allInterfaces = @this.AllInterfaces;

if (@this is INamedTypeSymbol namedType && namedType.TypeKind == TypeKind.Interface)
allInterfaces = allInterfaces.Add(namedType);

var genericIEnumerableInterface = allInterfaces.FirstOrDefault(i =>
i.GetFullMetadataName() is "System.Collections.Generic.IEnumerable`1" or "System.Collections.Generic.IAsyncEnumerable`1");

if (genericIEnumerableInterface is not null)
{
elementType = genericIEnumerableInterface.TypeArguments.FirstOrDefault();
return true;
}

var nonGenericIEnumerableInterface = allInterfaces.FirstOrDefault(i =>
i.GetFullMetadataName() == "System.Collections.IEnumerable");

if (nonGenericIEnumerableInterface is not null)
{
return true;
}

return false;
}

/// <summary>
/// Return value indicates whether type implements IDisposable.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ internal static class TestCaseSourceUsageConstants
internal const string ConsiderNameOfInsteadOfStringConstantMessage = "Consider using nameof({0}) instead of \"{1}\"";
internal const string ConsiderNameOfInsteadOfStringConstantDescription = "The TestCaseSource should use nameof operator to specify target.";

internal const string SourceTypeNotIEnumerableTitle = "The source type does not implement IEnumerable";
internal const string SourceTypeNotIEnumerableMessage = "The source type '{0}' does not implement IEnumerable";
internal const string SourceTypeNotIEnumerableDescription = "The source type must implement IEnumerable in order to provide test cases.";
internal const string SourceTypeNotIEnumerableTitle = "The source type does not implement I(Async)Enumerable";
internal const string SourceTypeNotIEnumerableMessage = "The source type '{0}' does not implement I(Async)Enumerable";
internal const string SourceTypeNotIEnumerableDescription = "The source type must implement I(Async)Enumerable in order to provide test cases.";

internal const string SourceTypeNoDefaultConstructorTitle = "The source type does not have a default constructor";
internal const string SourceTypeNoDefaultConstructorMessage = "The source type '{0}' does not have a default constructor";
Expand All @@ -22,9 +22,9 @@ internal static class TestCaseSourceUsageConstants
internal const string MismatchInNumberOfParametersMessage = "The TestCaseSource provides '{0}' parameter(s), but the target method expects '{1}' parameter(s)";
internal const string MismatchInNumberOfParametersDescription = "The number of parameters provided by the TestCaseSource must match the number of parameters in the target method.";

internal const string SourceDoesNotReturnIEnumerableTitle = "The source specified by the TestCaseSource does not return an IEnumerable or a type that implements IEnumerable";
internal const string SourceDoesNotReturnIEnumerableMessage = "The TestCaseSource does not return an IEnumerable or a type that implements IEnumerable. Instead it returns a '{0}'.";
internal const string SourceDoesNotReturnIEnumerableDescription = "The source specified by the TestCaseSource must return an IEnumerable or a type that implements IEnumerable.";
internal const string SourceDoesNotReturnIEnumerableTitle = "The source specified by the TestCaseSource does not return an I(Async)Enumerable or a type that implements I(Async)Enumerable";
internal const string SourceDoesNotReturnIEnumerableMessage = "The TestCaseSource does not return an I(Async)Enumerable or a type that implements I(Async)Enumerable. Instead it returns a '{0}'.";
internal const string SourceDoesNotReturnIEnumerableDescription = "The source specified by the TestCaseSource must return an I(Async)Enumerable or a type that implements I(Async)Enumerable.";

internal const string TestCaseSourceSuppliesParametersTitle = "The TestCaseSource provides parameters to a source - field or property - that expects no parameters";
internal const string TestCaseSourceSuppliesParametersMessage = "The TestCaseSource provides '{0}' parameter(s), but {1} cannot take parameters";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ private static void AnalyzeAttribute(SyntaxNodeAnalysisContext context, INamedTy
// The Type argument in this form represents the class that provides test cases.
// It must have a default constructor and implement IEnumerable.
var sourceType = attributeInfo.SourceType;
bool typeImplementsIEnumerable = sourceType.IsIEnumerable(out _);
bool typeImplementsIEnumerable = sourceType.IsIEnumerableOrIAsyncEnumerable(out _);
bool typeHasDefaultConstructor = sourceType.Constructors.Any(c => c.Parameters.IsEmpty);
if (!typeImplementsIEnumerable)
{
Expand Down Expand Up @@ -292,7 +292,7 @@ private static void AnalyzeAttribute(SyntaxNodeAnalysisContext context, INamedTy
SyntaxNode syntaxNode,
ITypeSymbol typeSymbol)
{
if (!typeSymbol.IsIEnumerable(out ITypeSymbol? elementType))
if (!typeSymbol.IsIEnumerableOrIAsyncEnumerable(out ITypeSymbol? elementType))
{
context.ReportDiagnostic(Diagnostic.Create(
sourceDoesNotReturnIEnumerable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ private static void AnalyzeAttribute(SyntaxNodeAnalysisContext context, INamedTy
if (symbol is IMethodSymbol && memberType.IsAwaitable(out ITypeSymbol? returnType))
memberType = returnType;

if (!memberType.IsIEnumerable(out var _))
if (!memberType.IsIEnumerableOrIAsyncEnumerable(out var _))
{
context.ReportDiagnostic(Diagnostic.Create(
sourceDoesNotReturnIEnumerable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ internal static class ValueSourceUsageConstants
internal const string MethodExpectParametersMessage = "The ValueSource cannot supply parameters, but the target method expects '{0}' parameter(s)";
internal const string MethodExpectParametersDescription = "The target method expects parameters which cannot be supplied by the ValueSource.";

internal const string SourceDoesNotReturnIEnumerableTitle = "The source specified by the ValueSource does not return an IEnumerable or a type that implements IEnumerable";
internal const string SourceDoesNotReturnIEnumerableMessage = "The ValueSource does not return an IEnumerable or a type that implements IEnumerable. Instead it returns a '{0}'.";
internal const string SourceDoesNotReturnIEnumerableDescription = "The source specified by the ValueSource must return an IEnumerable or a type that implements IEnumerable.";
internal const string SourceDoesNotReturnIEnumerableTitle = "The source specified by the ValueSource does not return an I(Async)Enumerable or a type that implements I(Async)Enumerable";
internal const string SourceDoesNotReturnIEnumerableMessage = "The ValueSource does not return an I(Async)Enumerable or a type that implements I(Async)Enumerable. Instead it returns a '{0}'.";
internal const string SourceDoesNotReturnIEnumerableDescription = "The source specified by the ValueSource must return an I(Async)Enumerable or a type that implements I(Async)Enumerable.";
}
}

0 comments on commit aeda665

Please sign in to comment.