diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index 9df46043178eda..d4571f02fa8b66 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -1287,6 +1287,11 @@ private static void AddWrapperToReferenceTrackerHandleCache(NativeObjectWrapper } } + internal void RemoveWrappersFromCache(IEnumerable wrappers) + { + _rcwCache.RemoveAll(wrappers); + } + private sealed class RcwCache { private readonly Lock _lock = new Lock(useTrivialWaits: true); @@ -1363,20 +1368,37 @@ public void Remove(IntPtr comPointer, NativeObjectWrapper wrapper) { lock (_lock) { - // TryGetOrCreateObjectForComInstanceInternal may have put a new entry into the cache - // in the time between the GC cleared the contents of the GC handle but before the - // NativeObjectWrapper finalizer ran. - // Only remove the entry if the target of the GC handle is the NativeObjectWrapper - // or is null (indicating that the corresponding NativeObjectWrapper has been scheduled for finalization). - if (_cache.TryGetValue(comPointer, out GCHandle cachedRef) - && (wrapper == cachedRef.Target - || cachedRef.Target is null)) + Remove_Locked(comPointer, wrapper); + } + } + + public void RemoveAll(IEnumerable wrappers) + { + lock (_lock) + { + foreach (NativeObjectWrapper wrapper in wrappers) { - _cache.Remove(comPointer); - cachedRef.Free(); + Remove_Locked(wrapper.ExternalComObject, wrapper); } } } + + private void Remove_Locked(IntPtr comPointer, NativeObjectWrapper wrapper) + { + // This method is used in a scenario where we already have a lock on the cache, so we can skip acquiring the lock again. + // TryGetOrCreateObjectForComInstanceInternal may have put a new entry into the cache + // in the time between the GC cleared the contents of the GC handle but before the + // NativeObjectWrapper finalizer ran. + // Only remove the entry if the target of the GC handle is the NativeObjectWrapper + // or is null (indicating that the corresponding NativeObjectWrapper has been scheduled for finalization). + if (_cache.TryGetValue(comPointer, out GCHandle cachedRef) + && (wrapper == cachedRef.Target + || cachedRef.Target is null)) + { + _cache.Remove(comPointer); + cachedRef.Free(); + } + } } /// diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/TrackerObjectManager.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/TrackerObjectManager.cs index df0053629c9ca6..385ff438a3d44d 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/TrackerObjectManager.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/TrackerObjectManager.cs @@ -61,7 +61,8 @@ internal static void ReleaseExternalObjectsFromCurrentThread() IntPtr contextToken = GetContextToken(); - List objects = new List(); + List wrappersToRemove = []; + List objects = []; // Here we aren't part of a GC callback, so other threads can still be running // who are adding and removing from the collection. This means we can possibly race @@ -76,10 +77,20 @@ internal static void ReleaseExternalObjectsFromCurrentThread() if (nativeObjectWrapper != null && nativeObjectWrapper._contextToken == contextToken) { - object? target = nativeObjectWrapper.ProxyHandle.Target; - if (target != null) + // If this object is associated with the global instance for tracker support, + // then we can request that instance to clear out the native object wrapper's state + // to ensure the object gets released now. + // Also, we will remove the wrappers from the cache to ensure a stale wrapper + // isn't returned in the future. + if (nativeObjectWrapper.ComWrappers == GlobalInstanceForTrackerSupport) { - objects.Add(target); + wrappersToRemove.Add(nativeObjectWrapper); + + object? target = nativeObjectWrapper.ProxyHandle.Target; + if (target != null) + { + objects.Add(target); + } } // Separate the wrapper from the tracker runtime prior to @@ -89,6 +100,10 @@ internal static void ReleaseExternalObjectsFromCurrentThread() } } + // Remove the native object wrappers from the cache + // so we don't return released wrappers to the user if the native COM object + // happens to be reused. + GlobalInstanceForTrackerSupport.RemoveWrappersFromCache(wrappersToRemove); GlobalInstanceForTrackerSupport.ReleaseObjects(objects); } diff --git a/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs b/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs index eaa93ee94fbd86..810b8215915818 100644 --- a/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs +++ b/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs @@ -177,7 +177,10 @@ protected override void ReleaseObjects(IEnumerable objects) Assert.NotNull(o); } - throw new Exception() { HResult = ReleaseObjectsCallAck }; + if (ReturnInvalid) + { + throw new Exception() { HResult = ReleaseObjectsCallAck }; + } } private unsafe ComInterfaceEntry* ComputeVtablesForTestObject(Test obj, out int count) @@ -461,6 +464,34 @@ private static void ValidateNotifyEndOfReferenceTrackingOnThread() // Trigger the thread lifetime end API and verify the callback occurs. int hr = MockReferenceTrackerRuntime.Trigger_NotifyEndOfReferenceTrackingOnThread(); Assert.Equal(GlobalComWrappers.ReleaseObjectsCallAck, hr); + + // Validate that the RCW cache gets cleared when we call NotifyEndOfReferenceTrackingOnThread + GlobalComWrappers.Instance.ReturnInvalid = false; + IntPtr tracker = MockReferenceTrackerRuntime.CreateTrackerObject(); + try + { + object rcw = GlobalComWrappers.Instance.GetOrCreateObjectForComInstance(tracker, CreateObjectFlags.TrackerObject); + + // Make sure that we keep the tracker object alive even after we notify end of reference tracking on this thread. + Marshal.AddRef(tracker); + + const int S_OK = 0; + Assert.Equal(S_OK, MockReferenceTrackerRuntime.Trigger_NotifyEndOfReferenceTrackingOnThread()); + + // We should get a new RCW after we've released the reference tracked objects on this thread. + object rcwNew = GlobalComWrappers.Instance.GetOrCreateObjectForComInstance(tracker, CreateObjectFlags.TrackerObject); + + Assert.NotSame(rcw, rcwNew); + } + finally + { + if (tracker != IntPtr.Zero) + { + // Release the extra ref we added above and the original ref from CreateTrackerObject. + Marshal.Release(tracker); + Marshal.Release(tracker); + } + } } } }