-
Notifications
You must be signed in to change notification settings - Fork 9.8k
/
IdentityEntityFrameworkBuilderExtensions.cs
113 lines (105 loc) · 5.25 KB
/
IdentityEntityFrameworkBuilderExtensions.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Reflection;
using Microsoft.AspNetCore.Identity;
using Microsoft.AspNetCore.Identity.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection.Extensions;
namespace Microsoft.Extensions.DependencyInjection
{
    /// <summary>
    /// Contains extension methods to <see cref="IdentityBuilder"/> for adding entity framework stores.
    /// </summary>
    public static class IdentityEntityFrameworkBuilderExtensions
    {
        /// <summary>
        /// Adds an Entity Framework implementation of identity information stores.
        /// </summary>
        /// <typeparam name="TContext">The Entity Framework database context to use.</typeparam>
        /// <param name="builder">The <see cref="IdentityBuilder"/> instance this method extends.</param>
        /// <returns>The <see cref="IdentityBuilder"/> instance this method extends.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="builder"/> is <c>null</c>.</exception>
        /// <exception cref="InvalidOperationException">
        /// The builder's user type does not derive from <c>IdentityUser&lt;TKey&gt;</c>, or a role type is
        /// configured and it does not derive from <c>IdentityRole&lt;TKey&gt;</c>.
        /// </exception>
        public static IdentityBuilder AddEntityFrameworkStores<TContext>(this IdentityBuilder builder)
            where TContext : DbContext
        {
            // Fail with a descriptive argument exception instead of a NullReferenceException
            // on the first member access below.
            if (builder == null)
            {
                throw new ArgumentNullException(nameof(builder));
            }

            AddStores(builder.Services, builder.UserType, builder.RoleType, typeof(TContext));
            return builder;
        }

        /// <summary>
        /// Registers a scoped <c>IUserStore&lt;TUser&gt;</c> implementation and, when <paramref name="roleType"/>
        /// is not <c>null</c>, a scoped <c>IRoleStore&lt;TRole&gt;</c> implementation. Both are closed over the
        /// supplied user, role and context types. Registrations are added with <c>TryAddScoped</c>, so any
        /// registration already present for the same service type is kept.
        /// </summary>
        /// <param name="services">The service collection to register into.</param>
        /// <param name="userType">The configured user type; must derive from <c>IdentityUser&lt;TKey&gt;</c>.</param>
        /// <param name="roleType">The configured role type, or <c>null</c> when roles are not used.</param>
        /// <param name="contextType">The Entity Framework <c>DbContext</c> type.</param>
        /// <exception cref="InvalidOperationException">
        /// <paramref name="userType"/> is not an identity user, or <paramref name="roleType"/> is non-null and
        /// not an identity role.
        /// </exception>
        private static void AddStores(IServiceCollection services, Type userType, Type roleType, Type contextType)
        {
            var identityUserType = FindGenericBaseType(userType, typeof(IdentityUser<>));
            if (identityUserType == null)
            {
                throw new InvalidOperationException(Resources.NotIdentityUser);
            }

            // Key type taken from the user's IdentityUser<TKey> base class; used when the context
            // does not itself carry the full set of identity generic arguments.
            var keyType = identityUserType.GenericTypeArguments[0];

            if (roleType != null)
            {
                var identityRoleType = FindGenericBaseType(roleType, typeof(IdentityRole<>));
                if (identityRoleType == null)
                {
                    throw new InvalidOperationException(Resources.NotIdentityRole);
                }

                Type userStoreType;
                Type roleStoreType;
                var identityContext = FindGenericBaseType(contextType, typeof(IdentityDbContext<,,,,,,,>));
                if (identityContext == null)
                {
                    // If it's a custom DbContext, we can only add the stores that use the default POCOs.
                    // NOTE(review): the role store is closed over the *user's* key type; if the role's
                    // IdentityRole<TKey> uses a different key, MakeGenericType is expected to reject the arguments.
                    userStoreType = typeof(UserStore<,,,>).MakeGenericType(userType, roleType, contextType, keyType);
                    roleStoreType = typeof(RoleStore<,,>).MakeGenericType(roleType, contextType, keyType);
                }
                else
                {
                    // IdentityDbContext<TUser, TRole, TKey, TUserClaim, TUserRole, TUserLogin, TRoleClaim, TUserToken>
                    var contextArguments = identityContext.GenericTypeArguments;
                    var contextKeyType = contextArguments[2];
                    var userClaimType = contextArguments[3];
                    var userRoleType = contextArguments[4];
                    var userLoginType = contextArguments[5];
                    var roleClaimType = contextArguments[6];
                    var userTokenType = contextArguments[7];

                    // UserStore<TUser, TRole, TContext, TKey, TUserClaim, TUserRole, TUserLogin, TUserToken, TRoleClaim>
                    // (note TUserToken precedes TRoleClaim here, the reverse of the context's order).
                    userStoreType = typeof(UserStore<,,,,,,,,>).MakeGenericType(
                        userType,
                        roleType,
                        contextType,
                        contextKeyType,
                        userClaimType,
                        userRoleType,
                        userLoginType,
                        userTokenType,
                        roleClaimType);

                    // RoleStore<TRole, TContext, TKey, TUserRole, TRoleClaim>
                    roleStoreType = typeof(RoleStore<,,,,>).MakeGenericType(
                        roleType,
                        contextType,
                        contextKeyType,
                        userRoleType,
                        roleClaimType);
                }

                services.TryAddScoped(typeof(IUserStore<>).MakeGenericType(userType), userStoreType);
                services.TryAddScoped(typeof(IRoleStore<>).MakeGenericType(roleType), roleStoreType);
            }
            else
            {
                // No roles: only a user store is registered.
                Type userStoreType;
                var identityContext = FindGenericBaseType(contextType, typeof(IdentityUserContext<,,,,>));
                if (identityContext == null)
                {
                    // If it's a custom DbContext, we can only add the store that uses the default POCOs.
                    userStoreType = typeof(UserOnlyStore<,,>).MakeGenericType(userType, contextType, keyType);
                }
                else
                {
                    // IdentityUserContext<TUser, TKey, TUserClaim, TUserLogin, TUserToken>
                    var contextArguments = identityContext.GenericTypeArguments;
                    var contextKeyType = contextArguments[1];
                    var userClaimType = contextArguments[2];
                    var userLoginType = contextArguments[3];
                    var userTokenType = contextArguments[4];

                    // UserOnlyStore<TUser, TContext, TKey, TUserClaim, TUserLogin, TUserToken>
                    userStoreType = typeof(UserOnlyStore<,,,,,>).MakeGenericType(
                        userType,
                        contextType,
                        contextKeyType,
                        userClaimType,
                        userLoginType,
                        userTokenType);
                }

                services.TryAddScoped(typeof(IUserStore<>).MakeGenericType(userType), userStoreType);
            }
        }

        /// <summary>
        /// Walks the base-class chain of <paramref name="currentType"/>, starting with the type itself, and
        /// returns the first type whose generic type definition is <paramref name="genericBaseType"/>.
        /// </summary>
        /// <param name="currentType">The type to start searching from; may be <c>null</c>.</param>
        /// <param name="genericBaseType">An open generic type definition such as <c>typeof(IdentityUser&lt;&gt;)</c>.</param>
        /// <returns>The matching constructed type, or <c>null</c> if no type in the chain matches.</returns>
        private static TypeInfo FindGenericBaseType(Type currentType, Type genericBaseType)
        {
            for (var type = currentType; type != null; type = type.BaseType)
            {
                if (type.IsGenericType && type.GetGenericTypeDefinition() == genericBaseType)
                {
                    return type.GetTypeInfo();
                }
            }

            return null;
        }
    }
}