Skip to content

Commit

Permalink
Merge pull request #1125 from stakx/awaitable-factories
Browse files Browse the repository at this point in the history
Create and deconstruct awaitables using dedicated factories (`IAwaitableFactory`)
  • Loading branch information
stakx authored Dec 31, 2020
2 parents 26de1a8 + f6b3232 commit bab305e
Show file tree
Hide file tree
Showing 15 changed files with 297 additions and 87 deletions.
29 changes: 29 additions & 0 deletions src/Moq/Async/Awaitable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

namespace Moq.Async
{
internal static class Awaitable
{
/// <summary>
/// Recursively gets the result of (i.e. "unwraps") completed awaitables
/// until a value is found that isn't a successfully completed awaitable.
/// </summary>
/// <remarks>
/// As an example, given <paramref name="obj"/> := <c>Task.FromResult(Task.FromResult(42))</c>,
/// this method will return <c>42</c>.
/// </remarks>
/// <param name="obj">The (possibly awaitable) object to be "unwrapped".</param>
public static object TryGetResultRecursive(object obj)
{
if (obj != null
&& AwaitableFactory.TryGet(obj.GetType()) is { } awaitableFactory
&& awaitableFactory.TryGetResult(obj, out var result))
{
return result;
}

return obj;
}
}
}
47 changes: 47 additions & 0 deletions src/Moq/Async/AwaitableFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading.Tasks;

namespace Moq.Async
{
internal static class AwaitableFactory
{
private static readonly Dictionary<Type, Func<Type, IAwaitableFactory>> Providers;

static AwaitableFactory()
{
AwaitableFactory.Providers = new Dictionary<Type, Func<Type, IAwaitableFactory>>
{
[typeof(Task)] = awaitableType => TaskFactory.Instance,
[typeof(ValueTask)] = awaitableType => ValueTaskFactory.Instance,
[typeof(Task<>)] = awaitableType => AwaitableFactory.Create(typeof(TaskFactory<>), awaitableType),
[typeof(ValueTask<>)] = awaitableType => AwaitableFactory.Create(typeof(ValueTaskFactory<>), awaitableType),
};
}

private static IAwaitableFactory Create(Type awaitableFactoryType, Type awaitableType)
{
return (IAwaitableFactory)Activator.CreateInstance(
awaitableFactoryType.MakeGenericType(
awaitableType.GetGenericArguments()));
}

public static IAwaitableFactory TryGet(Type type)
{
Debug.Assert(type != null);

var key = type.IsConstructedGenericType ? type.GetGenericTypeDefinition() : type;

if (AwaitableFactory.Providers.TryGetValue(key, out var provider))
{
return provider.Invoke(type);
}

return null;
}
}
}
34 changes: 34 additions & 0 deletions src/Moq/Async/AwaitableFactory`1.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Diagnostics;

namespace Moq.Async
{
/// <summary>
/// Abstract base class that facilitates type-safe implementation of <see cref="IAwaitableFactory"/>
/// for awaitables that do not produce a result when awaited.
/// </summary>
internal abstract class AwaitableFactory<TAwaitable> : IAwaitableFactory
{
Type IAwaitableFactory.ResultType => typeof(void);

public abstract TAwaitable CreateCompleted();

object IAwaitableFactory.CreateCompleted(object result)
{
Debug.Assert(result == null);

return this.CreateCompleted();
}

bool IAwaitableFactory.TryGetResult(object awaitable, out object result)
{
Debug.Assert(awaitable is TAwaitable);

result = null;
return false;
}
}
}
42 changes: 42 additions & 0 deletions src/Moq/Async/AwaitableFactory`2.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Diagnostics;

namespace Moq.Async
{
/// <summary>
/// Abstract base class that facilitates type-safe implementation of <see cref="IAwaitableFactory"/>
/// for awaitables that produce a result when awaited.
/// </summary>
internal abstract class AwaitableFactory<TAwaitable, TResult> : IAwaitableFactory
{
public Type ResultType => typeof(TResult);

public abstract TAwaitable CreateCompleted(TResult result);

object IAwaitableFactory.CreateCompleted(object result)
{
Debug.Assert(result is TResult || result == null);

return this.CreateCompleted((TResult)result);
}

public abstract bool TryGetResult(TAwaitable awaitable, out TResult result);

bool IAwaitableFactory.TryGetResult(object awaitable, out object result)
{
Debug.Assert(awaitable is TAwaitable);

if (this.TryGetResult((TAwaitable)awaitable, out var r))
{
result = r;
return true;
}

result = null;
return false;
}
}
}
16 changes: 16 additions & 0 deletions src/Moq/Async/IAwaitableFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;

namespace Moq.Async
{
internal interface IAwaitableFactory
{
Type ResultType { get; }

object CreateCompleted(object result = null);

bool TryGetResult(object awaitable, out object result);
}
}
21 changes: 21 additions & 0 deletions src/Moq/Async/TaskFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System.Threading.Tasks;

namespace Moq.Async
{
internal sealed class TaskFactory : AwaitableFactory<Task>
{
public static readonly TaskFactory Instance = new TaskFactory();

private TaskFactory()
{
}

public override Task CreateCompleted()
{
return Task.FromResult<object>(default);
}
}
}
27 changes: 27 additions & 0 deletions src/Moq/Async/TaskFactory`1.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System.Threading.Tasks;

namespace Moq.Async
{
internal sealed class TaskFactory<TResult> : AwaitableFactory<Task<TResult>, TResult>
{
public override Task<TResult> CreateCompleted(TResult result)
{
return Task.FromResult(result);
}

public override bool TryGetResult(Task<TResult> task, out TResult result)
{
if (task.Status == TaskStatus.RanToCompletion)
{
result = task.Result;
return true;
}

result = default;
return false;
}
}
}
21 changes: 21 additions & 0 deletions src/Moq/Async/ValueTaskFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System.Threading.Tasks;

namespace Moq.Async
{
internal sealed class ValueTaskFactory : AwaitableFactory<ValueTask>
{
public static readonly ValueTaskFactory Instance = new ValueTaskFactory();

private ValueTaskFactory()
{
}

public override ValueTask CreateCompleted()
{
return default;
}
}
}
27 changes: 27 additions & 0 deletions src/Moq/Async/ValueTaskFactory`1.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System.Threading.Tasks;

namespace Moq.Async
{
internal sealed class ValueTaskFactory<TResult> : AwaitableFactory<ValueTask<TResult>, TResult>
{
public override ValueTask<TResult> CreateCompleted(TResult result)
{
return new ValueTask<TResult>(result);
}

public override bool TryGetResult(ValueTask<TResult> valueTask, out TResult result)
{
if (valueTask.IsCompletedSuccessfully)
{
result = valueTask.Result;
return true;
}

result = default;
return false;
}
}
}
4 changes: 3 additions & 1 deletion src/Moq/InnerMockSetup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using System.Diagnostics;
using System.Linq.Expressions;

using Moq.Async;

namespace Moq
{
internal sealed class InnerMockSetup : SetupWithOutParameterSupport
Expand All @@ -13,7 +15,7 @@ internal sealed class InnerMockSetup : SetupWithOutParameterSupport
public InnerMockSetup(Expression originalExpression, Mock mock, InvocationShape expectation, object returnValue)
: base(originalExpression, mock, expectation)
{
Debug.Assert(Unwrap.ResultIfCompletedTask(returnValue) is IMocked);
Debug.Assert(Awaitable.TryGetResultRecursive(returnValue) is IMocked);

this.returnValue = returnValue;

Expand Down
57 changes: 22 additions & 35 deletions src/Moq/LookupOrFallbackDefaultValueProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
using System.Reflection;
using System.Threading.Tasks;

using Moq.Async;

namespace Moq
{
/// <summary>
Expand Down Expand Up @@ -41,9 +43,6 @@ protected LookupOrFallbackDefaultValueProvider()
{
this.factories = new Dictionary<object, Func<Type, Mock, object>>()
{
[typeof(Task)] = CreateTask,
[typeof(Task<>)] = CreateTaskOf,
[typeof(ValueTask<>)] = CreateValueTaskOf,
["System.ValueTuple`1"] = CreateValueTupleOf,
["System.ValueTuple`2"] = CreateValueTupleOf,
["System.ValueTuple`3"] = CreateValueTupleOf,
Expand All @@ -64,8 +63,11 @@ protected void Deregister(Type factoryKey)
{
Debug.Assert(factoryKey != null);

this.factories.Remove(factoryKey);
this.factories.Remove(factoryKey.FullName);
// NOTE: In order to be able to unregister the default logic for awaitable types,
// we need a way (below) to know when to delegate to an `IAwaitableFactory`, and when not to.
// This is why we only reset the dictionary entry instead of removing it.
this.factories[factoryKey] = null;
this.factories[factoryKey.FullName] = null;
}

/// <summary>
Expand Down Expand Up @@ -122,9 +124,21 @@ protected internal sealed override object GetDefaultValue(Type type, Mock mock)
: type;

Func<Type, Mock, object> factory;
return this.factories.TryGetValue(handlerKey , out factory) ? factory.Invoke(type, mock)
: this.factories.TryGetValue(handlerKey.FullName, out factory) ? factory.Invoke(type, mock)
: this.GetFallbackDefaultValue(type, mock);
if (this.factories.TryGetValue(handlerKey, out factory) || this.factories.TryGetValue(handlerKey.FullName, out factory))
{
if (factory != null) // This prevents delegation to an `IAwaitableFactory` for deregistered awaitable types; see note above.
{
return factory.Invoke(type, mock);
}
}
else if (AwaitableFactory.TryGet(type) is { } awaitableFactory)
{
var resultType = awaitableFactory.ResultType;
var result = resultType != typeof(void) ? this.GetDefaultValue(resultType, mock) : null;
return awaitableFactory.CreateCompleted(result);
}

return this.GetFallbackDefaultValue(type, mock);
}

/// <summary>
Expand All @@ -142,33 +156,6 @@ protected virtual object GetFallbackDefaultValue(Type type, Mock mock)
return type.GetDefaultValue();
}

private static object CreateTask(Type type, Mock mock)
{
return Task.FromResult(false);
}

private object CreateTaskOf(Type type, Mock mock)
{
var resultType = type.GetGenericArguments()[0];
var result = this.GetDefaultValue(resultType, mock);

var tcsType = typeof(TaskCompletionSource<>).MakeGenericType(resultType);
var tcs = Activator.CreateInstance(tcsType);
tcsType.GetMethod("SetResult").Invoke(tcs, new[] { result });
return tcsType.GetProperty("Task").GetValue(tcs, null);
}

private object CreateValueTaskOf(Type type, Mock mock)
{
var resultType = type.GetGenericArguments()[0];
var result = this.GetDefaultValue(resultType, mock);

// `Activator.CreateInstance` could throw an `AmbiguousMatchException` in this use case,
// so we're explicitly selecting and calling the constructor we want to use:
var valueTaskCtor = type.GetConstructor(new[] { resultType });
return valueTaskCtor.Invoke(new object[] { result });
}

private object CreateValueTupleOf(Type type, Mock mock)
{
var itemTypes = type.GetGenericArguments();
Expand Down
Loading

0 comments on commit bab305e

Please sign in to comment.