Skip to content

Commit

Permalink
Merge pull request #788 from openmod/better_type_reflection
Browse files Browse the repository at this point in the history
Better type reflection
  • Loading branch information
Trojaner committed Nov 8, 2023
2 parents 5b9ab7b + 17931e6 commit 49b0801
Show file tree
Hide file tree
Showing 19 changed files with 277 additions and 196 deletions.
2 changes: 1 addition & 1 deletion framework/OpenMod.API/Plugins/IPluginAssemblyStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public interface IPluginAssemblyStore
/// Loads plugin assemblies from the given assembly source.
/// </summary>
/// <param name="source">The assemblies source.</param>
/// <returns>The loaded plugin asemblies.</returns>
/// <returns>The loaded plugin assemblies.</returns>
[OpenModInternal]
Task<ICollection<Assembly>> LoadPluginAssembliesAsync(IPluginAssembliesSource source);

Expand Down
2 changes: 1 addition & 1 deletion framework/OpenMod.API/Prioritization/PriorityAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace OpenMod.API.Prioritization
/// <summary>
/// Sets the priority for the given implementation.
/// </summary>
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false)]
[AttributeUsage(AttributeTargets.Class)]
public class PriorityAttribute : Attribute
{
public Priority Priority { get; set; } = Priority.Normal;
Expand Down
68 changes: 61 additions & 7 deletions framework/OpenMod.Common/Helpers/AssemblyExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,43 +2,97 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;

namespace OpenMod.Common.Helpers
{
/// <summary>
/// Extension methods for <see cref="System.Reflection.Assembly"/>.
/// Extension methods for <see cref="Assembly"/>.
/// </summary>
public static class AssemblyExtensions
{
/// <summary>
/// Safely returns the set of loadable types from an assembly.
/// </summary>
/// <param name="assembly">The <see cref="System.Reflection.Assembly"/> from which to load types.</param>
/// <param name="assembly">The <see cref="Assembly"/> from which to load types.</param>
/// <returns>
/// The set of types from the <paramref name="assembly" />, or the subset
/// of types that could be loaded if there was any error.
/// </returns>
/// <exception cref="System.ArgumentNullException">
/// Thrown if <paramref name="assembly" /> is <see langword="null" />.
/// </exception>
/// <remarks>
/// Avoid using this method, unless you don't care about missing types
/// </remarks>
public static IEnumerable<Type> GetLoadableTypes(this Assembly assembly)
{
// Algorithm from StackOverflow answer here:
// https://stackoverflow.com/questions/7889228/how-to-prevent-reflectiontypeloadexception-when-calling-assembly-gettypes
if (assembly == null)
{
throw new ArgumentNullException(nameof(assembly));
}

try
{
// Its important that it is ToList() here or the exception may not get caught
return assembly.DefinedTypes.Select(t => t.AsType()).ToList();
return assembly.GetTypes();
}
catch (ReflectionTypeLoadException ex)
{
return ex.Types.Where(t => t != null)!;
return ex.Types.Where(t => t != null);
}
}

private static readonly Regex s_MissingFileAssemblyVersionRegex =
new("'(?<assembly>\\S+?), Version=(?<version>.+?), ",
RegexOptions.Compiled);

private static readonly Regex s_TypeLoadAssemblyVersionRegex =
new("assembly:(?<assembly>\\S+?), Version=(?<version>.+?), ",
RegexOptions.Compiled);

/// <summary>
/// Gets the missing assembly names based on ReflectionTypeLoadException
/// </summary>
public static IEnumerable<AssemblyName> GetMissingDependencies(this ReflectionTypeLoadException reflectionTypeLoadException)
{
if (reflectionTypeLoadException == null)
{
throw new ArgumentNullException(nameof(reflectionTypeLoadException));
}

var missingAssemblies = new Dictionary<string, Version>();
var loaderExceptions = reflectionTypeLoadException.LoaderExceptions;
foreach (var loaderException in loaderExceptions)
{
//TypeLoadException is just matching with MissingFileAssemblyVersionRegex
var match = s_MissingFileAssemblyVersionRegex.Match(loaderException.Message);
if (!match.Success)
match = s_TypeLoadAssemblyVersionRegex.Match(loaderException.Message);

if (!match.Success)
continue;

var assemblyName = match.Groups["assembly"].Value;
var version = Version.Parse(match.Groups["version"].Value);

if (missingAssemblies.TryGetValue(assemblyName, out var currentVersion) && currentVersion >= version)
continue;

missingAssemblies[assemblyName] = version;
}

return missingAssemblies.Select(s => new AssemblyName(s.Key) { Version = s.Value });
}

public static string GetNameVersion(this Assembly assembly)
{
var assemblyName = assembly.GetName();
return $"{assemblyName.Name} Version={assemblyName.Version}";
}

public static string GetNameVersion(this AssemblyName assemblyName)
{
return $"{assemblyName.Name} Version={assemblyName.Version}";
}
}
}
2 changes: 1 addition & 1 deletion framework/OpenMod.Common/Helpers/FileHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ internal static async Task<byte[]> ReadAllBytesAsync(string path)
{
using var stream = File.Open(path, FileMode.Open);
var buffer = new byte[stream.Length];
await stream.ReadAsync(buffer, offset: 0, count: (int) stream.Length);
_ = await stream.ReadAsync(buffer, offset: 0, count: (int) stream.Length);
return buffer;
}
}
Expand Down
43 changes: 24 additions & 19 deletions framework/OpenMod.Common/Helpers/ReflectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public static class ReflectionExtensions
{
private static readonly Regex s_VersionRegex = new("Version=(?<version>.+?), ", RegexOptions.Compiled);

// ReSharper disable once UnusedMember.Global
public static MethodBase? GetCallingMethod(Type[]? skipTypes = null, MethodBase[]? skipMethods = null, bool applyAsyncMethodPatch = true)
{
var skipList = new List<Type>(skipTypes ?? Type.EmptyTypes) { typeof(ReflectionExtensions) };
Expand All @@ -39,9 +40,7 @@ public static class ReflectionExtensions
frameMethod = frame.GetMethod();

// Check parent frame - if its from AsyncMethodBuilderCore, its definitely an async Task
if (frameMethod is MethodInfo parentFrameMethodInfo &&
(parentFrameMethodInfo.DeclaringType?.Name == "AsyncMethodBuilderCore"
|| parentFrameMethodInfo.DeclaringType?.Name == "AsyncTaskMethodBuilder"))
if (frameMethod is MethodInfo { DeclaringType.Name: "AsyncMethodBuilderCore" or "AsyncTaskMethodBuilder" })
{
frame = st.GetFrame(++tmpIndex);
frameMethod = frame.GetMethod();
Expand Down Expand Up @@ -69,6 +68,7 @@ public static class ReflectionExtensions
return frameTarget?.GetMethod();
}

// ReSharper disable once UnusedMember.Global
public static MethodBase? GetCallingMethod(params Assembly[] skipAssemblies)
{
var st = new StackTrace();
Expand All @@ -85,6 +85,7 @@ public static class ReflectionExtensions
return frameTarget?.GetMethod();
}

// ReSharper disable once UnusedMember.Global
public static IEnumerable<Type> GetTypeHierarchy(this Type type)
{
if (type == null)
Expand All @@ -108,22 +109,25 @@ public static IEnumerable<Type> FindAllTypes(this Assembly assembly, bool includ
throw new ArgumentNullException(nameof(assembly));
}

try
{
return assembly.GetLoadableTypes()
.Where(c => includeAbstractAndInterfaces || !c.IsAbstract && !c.IsInterface);
}
catch (ReflectionTypeLoadException e)
{
return e.Types.Where(t => t != null);
}
var types = assembly.GetLoadableTypes();
return includeAbstractAndInterfaces ? types : types.Where(t => !t.IsAbstract && !t.IsInterface);
}

public static IEnumerable<Type> FindTypes<T>(this Assembly assembly, bool includeAbstractAndInterfaces = false)
{
return assembly.FindAllTypes(includeAbstractAndInterfaces).Where(c => c.IsAssignableFrom(typeof(T)) || c.GetInterfaces().Any(x => x == typeof(T)));
var wantedType = typeof(T);
var types = assembly.FindAllTypes(includeAbstractAndInterfaces);
return types.Where(t => t.IsAssignableFrom(wantedType) || t.GetInterfaces().Any(x => x == wantedType));
}

public static IEnumerable<Type> FindTypes<T>(this IEnumerable<Type> types, bool includeAbstractAndInterfaces = false)
{
var wantedType = typeof(T);
var validTypes = types.Where(type => includeAbstractAndInterfaces || (!type.IsAbstract && !type.IsInterface));
return validTypes.Where(type => type.IsAssignableFrom(wantedType) || type.GetInterfaces().Any(x => x == wantedType));
}

// ReSharper disable once UnusedMember.Global
public static string GetVersionIndependentName(string assemblyName)
{
return GetVersionIndependentName(assemblyName, out _);
Expand All @@ -141,19 +145,20 @@ public static string GetVersionIndependentName(string assemblyName, out string e
return s_VersionRegex.Replace(assemblyName, string.Empty);
}

// ReSharper disable once UnusedMember.Global
public static string GetDebugName(this MethodBase mb)
{
if (mb == null)
{
throw new ArgumentNullException(nameof(mb));
}

if (mb is MemberInfo mi && mi.DeclaringType != null)
if (mb is MemberInfo { DeclaringType: not null } mi)
{
return mi.DeclaringType.Name + "." + mi.Name;
return $"{mi.DeclaringType.Name}.{mi.Name}";
}

return "<anonymous>#" + mb.Name;
return $"<anonymous>#{mb.Name}";
}

public static Task InvokeWithTaskSupportAsync(this MethodBase method, object? instance, object?[] @params)
Expand Down Expand Up @@ -212,9 +217,9 @@ public static T ToObject<T>(this Dictionary<object, object> dict)

public static bool HasConversionOperator(this Type from, Type to)
{
if (@from == null)
if (from == null)
{
throw new ArgumentNullException(nameof(@from));
throw new ArgumentNullException(nameof(from));
}

if (to == null)
Expand All @@ -223,7 +228,7 @@ public static bool HasConversionOperator(this Type from, Type to)
}

UnaryExpression BodyFunction(Expression body) => Expression.Convert(body, to);
ParameterExpression inp = Expression.Parameter(from, "inp");
var inp = Expression.Parameter(from, "inp");
try
{
// If this succeeds then we can cast 'from' type to 'to' type using implicit coercion
Expand Down
11 changes: 10 additions & 1 deletion framework/OpenMod.Common/Hotloading/Hotloader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public static class Hotloader

static Hotloader()
{
s_Assemblies = new(AssemblyNameEqualityComparer.Instance);
s_Assemblies = new Dictionary<AssemblyName, Assembly>(AssemblyNameEqualityComparer.Instance);
s_IsMono = Type.GetType("Mono.Runtime") is not null;
AppDomain.CurrentDomain.AssemblyResolve += OnAssemblyResolve;
}
Expand Down Expand Up @@ -162,6 +162,15 @@ public static void Remove(Assembly assembly)
return s_Assemblies.TryGetValue(name, out var assembly) ? assembly : null;
}

/// <summary>
/// Checks if hotloader contains assembly.
/// </summary>
/// <param name="name">The assembly name to check.</param>
public static bool ContainsAssembly(AssemblyName name)
{
return s_Assemblies.ContainsKey(name);
}

/// <summary>
/// Gets all hotloaded assemblies.
/// </summary>
Expand Down
4 changes: 2 additions & 2 deletions framework/OpenMod.Core/Commands/CommandStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public async Task InvalidateAsync()
}

var commandsData = await m_CommandDataStore.GetRegisteredCommandsAsync();
if (commandsData?.Commands == null)
if (commandsData.Commands == null)
{
throw new Exception("Failed to register commands: command data was null");
}
Expand Down Expand Up @@ -204,7 +204,7 @@ public async ValueTask DisposeAsync()

m_IsDisposing = true;
var commandsData = await m_CommandDataStore.GetRegisteredCommandsAsync();
if (commandsData?.Commands != null && commandsData.Commands.Count > 0)
if (commandsData?.Commands is { Count: > 0 })
{
// clear unknown commands

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ protected override async Task OnExecuteAsync()
await PrintAsync("No plugins found.", Color.Red);
}

var sb = new StringBuilder();
foreach (var plugin in plugins)
{
var sb = new StringBuilder(); // lgtm [cs/stringbuilder-creation-in-loop]
sb.Clear();
sb.Append($"{plugin.DisplayName} v{plugin.Version}");

if (!string.IsNullOrEmpty(plugin.Author))
Expand Down
10 changes: 7 additions & 3 deletions framework/OpenMod.Core/Commands/OpenModComponentCommandSource.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -30,8 +31,10 @@ public OpenModComponentCommandSource(

private void ScanAssemblyForCommmmands(Assembly assembly)
{
var types = assembly.FindTypes<ICommand>();
foreach (var type in types)
var assemblyTypes = assembly.FindAllTypes().ToList();
var commandTypes = assemblyTypes.FindTypes<ICommand>();

foreach (var type in commandTypes)
{
if (type.GetCustomAttribute<DontAutoRegister>(false) != null)
{
Expand All @@ -50,9 +53,10 @@ private void ScanAssemblyForCommmmands(Assembly assembly)
var registatration = new OpenModComponentBoundCommandRegistration(m_OpenModComponent, type);

m_Commands.Add(registatration);

}

foreach (var type in assembly.GetLoadableTypes())
foreach (var type in assemblyTypes)
{
ScanTypeForCommands(type);
}
Expand Down
11 changes: 1 addition & 10 deletions framework/OpenMod.Core/Helpers/SmartFormatterHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,18 @@
using OpenMod.API;
using OpenMod.Core.Localization;
using SmartFormat;
using SmartFormat.Extensions;

namespace OpenMod.Core.Helpers
{
[Obsolete("This class will be set to internal in future releases. To get SmartFormatter use " + nameof(SmartFormatOptions.GetSmartFormatter))]
[OpenModInternal]
public static class SmartFormatterHelper
internal static class SmartFormatterHelper
{
// https://github.com/axuno/SmartFormat/wiki/Async-and-Thread-Safety
// since v3.1.0 Smart.Default is flagged with [ThreadStatic], but it creates the default formatter and it's not possible to override it.
// So we are storing our formatter here
[ThreadStatic]
private static SmartFormatter? s_SmartFormatter;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
[Obsolete("Use " + nameof(SmartFormatOptions.GetSmartFormatter) + " to get SmartFormatter")]
public static SmartFormatter ObtainSmartFormatter()
{
return s_SmartFormatter ??= Smart.CreateDefaultSmartFormat().AddExtensions(new TimeFormatter());
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static SmartFormatter ObtainSmartFormatter(SmartFormatOptions options)
{
Expand Down
Loading

0 comments on commit 49b0801

Please sign in to comment.