Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make NativeLibrary Load/TryLoad use ALC extension points for the specified assembly #34519

Merged
merged 8 commits into from
Apr 8, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
901 changes: 481 additions & 420 deletions src/coreclr/src/vm/dllimport.cpp

Large diffs are not rendered by default.

8 changes: 0 additions & 8 deletions src/coreclr/src/vm/dllimport.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,6 @@ class NDirect

private:
NDirect() {LIMITED_METHOD_CONTRACT;}; // prevent "new"'s on this class

elinor-fung marked this conversation as resolved.
Show resolved Hide resolved
static NATIVE_LIBRARY_HANDLE LoadFromNativeDllSearchDirectories(LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker);
static NATIVE_LIBRARY_HANDLE LoadFromPInvokeAssemblyDirectory(Assembly *pAssembly, LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker);
static NATIVE_LIBRARY_HANDLE LoadLibraryModuleViaHost(NDirectMethodDesc * pMD, LPCWSTR wszLibName);
static NATIVE_LIBRARY_HANDLE LoadLibraryModuleViaEvent(NDirectMethodDesc * pMD, LPCWSTR wszLibName);
static NATIVE_LIBRARY_HANDLE LoadLibraryModuleViaCallback(NDirectMethodDesc * pMD, LPCWSTR wszLibName);
static NATIVE_LIBRARY_HANDLE LoadLibraryModuleBySearch(NDirectMethodDesc * pMD, LoadLibErrorTracker * pErrorTracker, LPCWSTR wszLibName);
static NATIVE_LIBRARY_HANDLE LoadLibraryModuleBySearch(Assembly *callingAssembly, BOOL searchAssemblyDirectory, DWORD dllImportSearchPathFlags, LoadLibErrorTracker * pErrorTracker, LPCWSTR wszLibName);
};

//----------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,21 @@ public class ALC : AssemblyLoadContext
{
public bool LoadUnmanagedDllCalled { get; private set; }

public void Reset()
{
LoadUnmanagedDllCalled = false;
}

protected override IntPtr LoadUnmanagedDll(string unmanagedDllName)
{
LoadUnmanagedDllCalled = true;

if (string.Equals(unmanagedDllName, NativeLibraryToLoad.InvalidName))
return LoadUnmanagedDllFromPath(NativeLibraryToLoad.GetFullPath());

if (string.Equals(unmanagedDllName, FakeNativeLibrary.Name))
return FakeNativeLibrary.Handle;

return IntPtr.Zero;
}
}
Expand Down Expand Up @@ -50,12 +58,28 @@ public static void ValidateLoadUnmanagedDll()
{
Console.WriteLine($"Running {nameof(ValidateLoadUnmanagedDll)}...");

ALC alc = new ALC();
var asm = alc.LoadFromAssemblyPath(Assembly.GetExecutingAssembly().Location);

Console.WriteLine(" -- Validate explicit load...");
IntPtr ptr = NativeLibrary.Load(FakeNativeLibrary.Name, asm, null);
Assert.IsTrue(alc.LoadUnmanagedDllCalled, "AssemblyLoadContext.LoadUnmanagedDll should have been called.");
Assert.AreEqual(FakeNativeLibrary.Handle, ptr, $"Unexpected return value for {nameof(NativeLibrary.Load)}");

alc.Reset();
ptr = IntPtr.Zero;

bool success = NativeLibrary.TryLoad(FakeNativeLibrary.Name, asm, null, out ptr);
Assert.IsTrue(success, $"NativeLibrary.TryLoad should have succeeded");
Assert.IsTrue(alc.LoadUnmanagedDllCalled, "AssemblyLoadContext.LoadUnmanagedDll should have been called.");
Assert.AreEqual(FakeNativeLibrary.Handle, ptr, $"Unexpected return value for {nameof(NativeLibrary.Load)}");
alc.Reset();

Console.WriteLine(" -- Validate p/invoke...");
int addend1 = rand.Next(int.MaxValue / 2);
int addend2 = rand.Next(int.MaxValue / 2);
int expected = addend1 + addend2;

ALC alc = new ALC();
int value = NativeSumInAssemblyLoadContext(alc, addend1, addend2);
Assert.IsTrue(alc.LoadUnmanagedDllCalled, "AssemblyLoadContext.LoadUnmanagedDll should have been called.");
Assert.AreEqual(expected, value, $"Unexpected return value for {nameof(NativeSum)}");
Expand All @@ -65,14 +89,39 @@ public static void ValidateResolvingUnmanagedDllEvent()
{
Console.WriteLine($"Running {nameof(ValidateResolvingUnmanagedDllEvent)}...");

Console.WriteLine(" -- Validate explicit load: custom ALC...");
AssemblyLoadContext alcExplicitLoad = new AssemblyLoadContext(nameof(ValidateResolvingUnmanagedDllEvent));
var asm = alcExplicitLoad.LoadFromAssemblyPath(Assembly.GetExecutingAssembly().Location);
ValidateResolvingUnmanagedDllEvent_ExplicitLoad(asm);

Console.WriteLine(" -- Validate explicit load: default ALC...");
ValidateResolvingUnmanagedDllEvent_ExplicitLoad(Assembly.GetExecutingAssembly());

Console.WriteLine(" -- Validate p/invoke: custom ALC...");
AssemblyLoadContext alc = new AssemblyLoadContext(nameof(ValidateResolvingUnmanagedDllEvent));
ValidateResolvingUnmanagedDllEvent_PInvoke(alc);
AssemblyLoadContext alcPInvoke = new AssemblyLoadContext(nameof(ValidateResolvingUnmanagedDllEvent));
ValidateResolvingUnmanagedDllEvent_PInvoke(alcPInvoke);

Console.WriteLine(" -- Validate p/invoke: default ALC...");
ValidateResolvingUnmanagedDllEvent_PInvoke(AssemblyLoadContext.Default);
}

private static void ValidateResolvingUnmanagedDllEvent_ExplicitLoad(Assembly assembly)
{
AssemblyLoadContext alc = AssemblyLoadContext.GetLoadContext(assembly);
using (var handler = new Handlers(alc, returnValid: false))
{
Assert.Throws<DllNotFoundException>(() => NativeLibrary.Load(FakeNativeLibrary.Name, assembly, null));
Assert.IsTrue(handler.EventHandlerInvoked, "Event handler should have been invoked");
}

using (var handler = new Handlers(alc, returnValid: true))
{
IntPtr ptr = NativeLibrary.Load(FakeNativeLibrary.Name, assembly, null);
Assert.IsTrue(handler.EventHandlerInvoked, "Event handler should have been invoked");
Assert.AreEqual(FakeNativeLibrary.Handle, ptr, $"Unexpected return value for {nameof(NativeLibrary.Load)}");
}
}

private static void ValidateResolvingUnmanagedDllEvent_PInvoke(AssemblyLoadContext alc)
{
int addend1 = rand.Next(int.MaxValue / 2);
Expand Down Expand Up @@ -155,6 +204,9 @@ private IntPtr OnResolvingUnmanagedDll(Assembly assembly, string libraryName)
if (string.Equals(libraryName, NativeLibraryToLoad.InvalidName))
return NativeLibrary.Load(NativeLibraryToLoad.Name, assembly, null);

if (string.Equals(libraryName, FakeNativeLibrary.Name))
return FakeNativeLibrary.Handle;

return IntPtr.Zero;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,26 @@ public class CallbackStressTest
public static void SetResolve()
{
Console.WriteLine("Setting PInvoke Resolver");

DllImportResolver resolver =
(string libraryName, Assembly asm, DllImportSearchPath? dllImportSearchPath) =>
{
if (dllImportSearchPath != DllImportSearchPath.System32)
if (string.Equals(libraryName, NativeLibraryToLoad.InvalidName))
{
Console.WriteLine($"Unexpected dllImportSearchPath: {dllImportSearchPath.ToString()}");
throw new ArgumentException();
if (dllImportSearchPath != DllImportSearchPath.System32)
{
Console.WriteLine($"Unexpected dllImportSearchPath: {dllImportSearchPath.ToString()}");
throw new ArgumentException();
}

return NativeLibrary.Load(NativeLibraryToLoad.Name, asm, null);
}

return NativeLibrary.Load(NativeLibraryToLoad.Name, asm, null);
return IntPtr.Zero;
};

NativeLibrary.SetDllImportResolver(
Assembly.GetExecutingAssembly(),
Assembly.GetExecutingAssembly(),
resolver);
}

Expand All @@ -61,7 +66,7 @@ public static void DoCallTryCatch(bool shouldThrow)
s_PInvokesExecuted += (a == 20 ? 1 : 0);
}
catch (DllNotFoundException) { s_CatchCalled++; }

throw new ArgumentException();
}

Expand Down Expand Up @@ -97,7 +102,7 @@ public static void DoCallTryFinally()
}
finally { s_FinallyCalled++; }
}

[MethodImpl(MethodImplOptions.NoInlining)]
public static void ManualRaiseException()
{
Expand All @@ -111,7 +116,7 @@ public static void ManualRaiseException()
// TODO: test on Unix when implementing pinvoke inlining
s_SEHExceptionCatchCalled++;
#endif
}
}

public static int Main()
{
Expand All @@ -123,13 +128,13 @@ public static int Main()
s_WrongPInvokesExecuted++;
}
catch (DllNotFoundException) { GC.Collect(); s_CatchCalled++; }

try { DoCall(); }
catch (DllNotFoundException) { GC.Collect(); s_CatchCalled++; }

try { DoCallTryFinally(); }
catch (DllNotFoundException) { GC.Collect(); s_CatchCalled++; }

try { DoCallTryCatch(true); }
catch (ArgumentException) { GC.Collect(); s_OtherExceptionCatchCalled++; }

Expand All @@ -138,10 +143,10 @@ public static int Main()

try { DoCallTryRethrowDifferentExceptionInCatch(); }
catch (InvalidOperationException) { GC.Collect(); s_OtherExceptionCatchCalled++; }

ManualRaiseException();
}

SetResolve();

for(int i = 0; i < s_LoopCounter; i++)
Expand All @@ -152,11 +157,11 @@ public static int Main()

try { DoCallTryCatch(false); }
catch (ArgumentException) { GC.Collect(); s_OtherExceptionCatchCalled++; }

ManualRaiseException();
}
if (s_FinallyCalled == s_LoopCounter &&

if (s_FinallyCalled == s_LoopCounter &&
s_CatchCalled == (s_LoopCounter * 7) &&
s_OtherExceptionCatchCalled == (s_LoopCounter * 3) &&
s_WrongPInvokesExecuted == 0 &&
Expand All @@ -166,7 +171,7 @@ public static int Main()
Console.WriteLine("PASS");
return 100;
}

Console.WriteLine("s_FinallyCalled = " + s_FinallyCalled);
Console.WriteLine("s_CatchCalled = " + s_CatchCalled);
Console.WriteLine("s_OtherExceptionCatchCalled = " + s_OtherExceptionCatchCalled);
Expand All @@ -179,7 +184,7 @@ public static int Main()
[DllImport(NativeLibraryToLoad.InvalidName)]
[DefaultDllImportSearchPaths(DllImportSearchPath.System32)]
static extern int NativeSum(int arg1, int arg2);

#if WINDOWS
[DllImport("kernel32")]
static extern void RaiseException(uint dwExceptionCode, uint dwExceptionFlags, uint nNumberOfArguments, IntPtr lpArguments);
Expand Down
Loading