From 7a680b632f6477ce759c49cef865f48d783808aa Mon Sep 17 00:00:00 2001 From: labbbirder <502100554@qq.com> Date: Tue, 30 Jan 2024 13:21:57 +0800 Subject: [PATCH] fix: fail to inject when the methods are sharing the same name --- Editor/InjectHelper.cs | 58 +++++++++++++++++++++++--------------- Editor/UnityInjectUtils.cs | 5 ++-- Runtime/Constants.cs | 30 +++++++++++++------- Runtime/FixHelper.cs | 16 +++++++---- package.json | 2 +- 5 files changed, 70 insertions(+), 41 deletions(-) diff --git a/Editor/InjectHelper.cs b/Editor/InjectHelper.cs index e83e326..12ad31f 100644 --- a/Editor/InjectHelper.cs +++ b/Editor/InjectHelper.cs @@ -53,8 +53,8 @@ internal static bool InjectAssembly(InjectionInfo[] injections, string inputAsse //mark check var injected = targetAssembly.MainModule.Types.Any(t => - Constants.InjectedMarkName == t.Name && - Constants.InjectedMarkNamespace == t.Namespace); + Constants.INJECTED_MARK_NAME == t.Name && + Constants.INJECTED_MARK_NAMESPACE == t.Namespace); if (injected) { targetAssembly.Release(); @@ -72,7 +72,7 @@ internal static bool InjectAssembly(InjectionInfo[] injections, string inputAsse { throw new($"Cannot find Type `{type}` in target assembly {inputAssemblyPath}"); } - var targetMethod = targetType.FindMethod(methodName).Resolve(); + var targetMethod = targetType.FindMethod(injectedMethod.GetSignature()).Resolve(); if (targetMethod is null) { throw new($"Cannot find Method `{methodName}` in Type `{type}`"); @@ -88,8 +88,8 @@ internal static bool InjectAssembly(InjectionInfo[] injections, string inputAsse //mark make var InjectedMark = new TypeDefinition( - Constants.InjectedMarkNamespace, - Constants.InjectedMarkName, + Constants.INJECTED_MARK_NAMESPACE, + Constants.INJECTED_MARK_NAME, TypeAttributes.Class, targetAssembly.MainModule.TypeSystem.Object); targetAssembly.MainModule.Types.Add(InjectedMark); @@ -127,9 +127,11 @@ static IEnumerable GetContainingTypes(Type type) } } } + + static MethodDefinition DuplicateOriginalMethod(this TypeDefinition targetType, MethodDefinition targetMethod) { - var originName = Constants.GetOriginMethodName(targetMethod.Name); + var originName = Constants.GetOriginMethodName(targetMethod.Name, targetMethod.GetSignature()); var duplicatedMethod = targetType.Methods.FirstOrDefault(m => m.Name == originName); if (duplicatedMethod is null) { @@ -140,6 +142,8 @@ static MethodDefinition DuplicateOriginalMethod(this TypeDefinition targetType, } return duplicatedMethod; } + + static void Release(this AssemblyDefinition assemblyDefinition) { if (assemblyDefinition == null) return; @@ -147,9 +151,11 @@ static void Release(this AssemblyDefinition assemblyDefinition) assemblyDefinition.MainModule.SymbolReader?.Dispose(); assemblyDefinition.Dispose(); } + + static (FieldDefinition field, MethodReference fieldInvokeMethod) AddInjectField(this TypeDefinition targetType, MethodDefinition targetMethod, string methodName) { - var injectionName = Constants.GetInjectedFieldName(methodName); + var injectionName = Constants.GetInjectedFieldName(methodName, targetMethod.GetSignature()); var HasThis = targetMethod.HasThis; var Parameters = targetMethod.Parameters; var GenericParameters = targetMethod.GenericParameters; @@ -194,7 +200,7 @@ static void Release(this AssemblyDefinition assemblyDefinition) // FieldAttributes.Private|FieldAttributes.Static|FieldAttributes.Assembly, // targetType.Module.ImportReference(typeof(Delegate))); // var resMth = genInst.Resolve(); - var genMtd = rawGenType.FindMethod("Invoke"); + var genMtd = rawGenType.FindMethodByName("Invoke"); // genMtd.DeclaringType = genInst; var mnlMth = new MethodReference(genMtd.Name, genMtd.ReturnType, genInst) { @@ -208,6 +214,7 @@ static void Release(this AssemblyDefinition assemblyDefinition) return (sfldInject, mnlMth); } + static void AddInjectionMethod( this TypeDefinition targetType, MethodDefinition targetMethod, MethodDefinition originalMethod, @@ -285,19 +292,8 @@ static void Release(this AssemblyDefinition assemblyDefinition) ilProcessor.Append(Instruction.Create(OpCodes.Nop)); ilProcessor.Append(Instruction.Create(OpCodes.Ret)); } - // static void InjectCctor(this TypeDefinition targetType,FieldDefinition field){ - // var cctorMethod = targetType.Methods.FirstOrDefault(m=>m.Name==".cctor"); - // if(cctorMethod is null){ - // cctorMethod = new MethodDefinition(".cctor",MethodAttributes.Static|MethodAttributes.Private,targetType.Module.TypeSystem.Void); - // targetType.Methods.Add(cctorMethod); - // } - // var ilProcessor = cctorMethod.Body.GetILProcessor(); - // var bdis = cctorMethod.Body.Instructions; - // var insertPoint = bdis[0]; - // ilProcessor.InsertBefore(insertPoint,Instruction.Create(OpCodes.Ldsfld,field)); - // ilProcessor.InsertBefore(insertPoint,Instruction.Create(OpCodes.Ldsfld,field)); - // } - // =>md.GetType(type.ToString(),true); + + static Instruction createLdarg(this ILProcessor ilProcessor, int i) { if (i < s_ldargs.Length) @@ -314,6 +310,7 @@ static Instruction createLdarg(this ILProcessor ilProcessor, int i) } } + /// /// Create a clone of the given method definition /// @@ -337,6 +334,7 @@ public static MethodDefinition Clone(this MethodDefinition source) return result; } + /// /// Create a clone of the given method body /// @@ -362,6 +360,7 @@ public static MethodBody Clone(this MethodBody source, MethodDefinition target) return result; } + internal static bool IsReturnVoid(this MethodDefinition md) => md.ReturnType.ToString() == voidType.ToString(); internal static bool IsReturnValueType(this MethodDefinition md) @@ -370,8 +369,15 @@ internal static bool IsComplexValueType(this TypeReference td) => td.ToString() != voidType.ToString() && !td.IsPrimitive; internal static Type GetUnderlyingType(this TypeReference td) => td.IsPrimitive ? Type.GetType(td.Name) : objType; - internal static MethodReference FindMethod(this TypeDefinition td, string methodName) - => td.Module.ImportReference(td.Methods.FirstOrDefault(m => m.Name == methodName)); + + internal static string GetSignature(this MethodDefinition md) + => $"{md.Name}({string.Join(",", md.Parameters.Select(p => p.ParameterType.FullName))})"; + + internal static MethodReference FindMethod(this TypeDefinition td, string methodSignature) + => td.Module.ImportReference(td.Methods.FirstOrDefault(m => m.GetSignature().Equals(methodSignature))); + + internal static MethodReference FindMethodByName(this TypeDefinition td, string methodName) + => td.Module.ImportReference(td.Methods.FirstOrDefault(m => m.Name.Equals(methodName))); internal static TypeDefinition FindType(this ModuleDefinition md, Type type) { Assert.IsNotNull(type); @@ -421,11 +427,15 @@ void AddModule(ModuleDefinition md) } // return new TypeReference(type.Namespace,type.Name,md,md.TypeSystem.CoreLibrary); } + + internal static TypeReference FindType(this ModuleDefinition md) { return FindType(md, typeof(T)); // return new TypeReference(typeof(T).Namespace,typeof(T).Name,md,md.TypeSystem.CoreLibrary); } + + internal static TypeDefinition CreateDelegateType(this ModuleDefinition assembly, string name, TypeDefinition declaringType, TypeReference returnType, IEnumerable parameters) { @@ -481,6 +491,8 @@ internal static TypeReference FindType(this ModuleDefinition md) return dt; } + + static Type voidType = typeof(void); static Type objType = typeof(object); static OpCode[] s_ldargs = new[] { OpCodes.Ldarg_0, OpCodes.Ldarg_1, OpCodes.Ldarg_2, OpCodes.Ldarg_3 }; diff --git a/Editor/UnityInjectUtils.cs b/Editor/UnityInjectUtils.cs index 772dfbc..92e7ceb 100644 --- a/Editor/UnityInjectUtils.cs +++ b/Editor/UnityInjectUtils.cs @@ -235,13 +235,14 @@ static bool VisitAssembly(string assemblyPath, InjectionInfo[] injections, bool Debug.Log($"Inject success: {assemblyPath}"); return isWritten; } - catch + catch (Exception e) { + Debug.LogException(e); if (IsEngineAssembly) { File.Copy(backPath, assemblyPath, true); } - throw; + throw e; } } diff --git a/Runtime/Constants.cs b/Runtime/Constants.cs index 34efaec..3adcfad 100644 --- a/Runtime/Constants.cs +++ b/Runtime/Constants.cs @@ -1,18 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Security.Cryptography; using System.Text; -namespace com.bbbirder.injection { - public static class Constants{ - public const string InjectedMarkNamespace = "com.bbbirder"; - public const string InjectedMarkName = "InjectedMarkAttribute"; - public static string GetDelegateTypeName(string methodName) - => strBuilder.Clear().Append("__").Append(methodName).Append("Delegate").ToString(); +namespace com.bbbirder.injection +{ + public static class Constants + { + public const string INJECTED_MARK_NAMESPACE = "com.bbbirder"; + public const string INJECTED_MARK_NAME = "InjectedMarkAttribute"; - public static string GetInjectedFieldName(string methodName) - => strBuilder.Clear().Append("s_").Append(methodName).Append("_injection").ToString(); + public static string GetInjectedFieldName(string methodName, string methodSignature) + => strBuilder.Clear().Append("_injection_field+").Append(methodName).Append(MD5Hash(methodSignature)).ToString(); - public static string GetOriginMethodName(string methodName) - => strBuilder.Clear().Append("origin_").Append(methodName).ToString(); + public static string GetOriginMethodName(string methodName, string methodSignature) + => strBuilder.Clear().Append("_injection_origin+").Append(methodName).Append(MD5Hash(methodSignature)).ToString(); + static string MD5Hash(string rawContent) + { + var md5 = MD5.Create(); + var buffer = md5.ComputeHash(Encoding.UTF8.GetBytes(rawContent)); + return string.Concat(buffer.Select(b => b.ToString("X"))); + } static StringBuilder strBuilder = new(); } diff --git a/Runtime/FixHelper.cs b/Runtime/FixHelper.cs index 4264723..c6df07b 100644 --- a/Runtime/FixHelper.cs +++ b/Runtime/FixHelper.cs @@ -51,16 +51,17 @@ public static void Install(Assembly assembly) /// public static bool IsInjected(Type type) { - var mark = type.Assembly.GetType($"{Constants.InjectedMarkNamespace}.{Constants.InjectedMarkName}"); + var mark = type.Assembly.GetType($"{Constants.INJECTED_MARK_NAMESPACE}.{Constants.INJECTED_MARK_NAME}"); return mark != null; } public static MethodInfo GetOriginMethodFor(MethodInfo targetMethod) { - var oriName = Constants.GetOriginMethodName(targetMethod.Name); + var oriName = Constants.GetOriginMethodName(targetMethod.Name, targetMethod.GetSignature()); return targetMethod.DeclaringType.GetMethod(oriName, bindingFlags); } - static void FixMethod(InjectionInfo injection) + + public static void FixMethod(InjectionInfo injection) { injection.onStartFix?.Invoke(); var targetMethod = injection.InjectedMethod; @@ -80,7 +81,8 @@ static void FixMethod(InjectionInfo injection) try { - sfld = targetType.GetField(Constants.GetInjectedFieldName(methodName), bindingFlags ^ BindingFlags.Instance); + var sfldName = Constants.GetInjectedFieldName(methodName, targetMethod.GetSignature()); + sfld = targetType.GetField(sfldName, bindingFlags ^ BindingFlags.Instance); } catch (Exception e) { @@ -112,7 +114,8 @@ static void FixMethod(InjectionInfo injection) } // set overwrite origin field - var originMethod = targetType.GetMethod(Constants.GetOriginMethodName(methodName), bindingFlags); + var originName = Constants.GetOriginMethodName(methodName, targetMethod.GetSignature()); + var originMethod = targetType.GetMethod(originName, bindingFlags); try { var oriDelegate = originMethod.CreateDelegate(sfld.FieldType); @@ -218,6 +221,9 @@ public static string GetAssemblyPath(this Assembly assembly) return null; } + public static string GetSignature(this MethodBase m) + => $"{m.Name}({string.Join(",", m.GetParameters().Select(p => p.ParameterType.FullName))})"; + static InjectionInfo[] m_allInjections; public static InjectionInfo[] allInjections => m_allInjections ??= GetAllInjections(); diff --git a/package.json b/package.json index ac70a50..c1c8f24 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "com.bbbirder.injection", "displayName": "Unity Injection", "description": "Unity注入模块,可以运行时改变被注入函数实现。", - "version": "1.3.21", + "version": "1.3.24", "hideInEditor": false, "author": "bbbirder <502100554@qq.com>", "dependencies": {