Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,11 @@ private static void AddWrapperToReferenceTrackerHandleCache(NativeObjectWrapper
}
}

internal void RemoveWrappersFromCache(IEnumerable<NativeObjectWrapper> wrappers)
{
_rcwCache.RemoveAll(wrappers);
}

private sealed class RcwCache
{
private readonly Lock _lock = new Lock(useTrivialWaits: true);
Expand Down Expand Up @@ -1363,20 +1368,37 @@ public void Remove(IntPtr comPointer, NativeObjectWrapper wrapper)
{
lock (_lock)
{
// TryGetOrCreateObjectForComInstanceInternal may have put a new entry into the cache
// in the time between the GC cleared the contents of the GC handle but before the
// NativeObjectWrapper finalizer ran.
// Only remove the entry if the target of the GC handle is the NativeObjectWrapper
// or is null (indicating that the corresponding NativeObjectWrapper has been scheduled for finalization).
if (_cache.TryGetValue(comPointer, out GCHandle cachedRef)
&& (wrapper == cachedRef.Target
|| cachedRef.Target is null))
Remove_Locked(comPointer, wrapper);
}
}

public void RemoveAll(IEnumerable<NativeObjectWrapper> wrappers)
{
lock (_lock)
{
foreach (NativeObjectWrapper wrapper in wrappers)
{
_cache.Remove(comPointer);
cachedRef.Free();
Remove_Locked(wrapper.ExternalComObject, wrapper);
}
}
}

private void Remove_Locked(IntPtr comPointer, NativeObjectWrapper wrapper)
{
// This method is used in a scenario where we already have a lock on the cache, so we can skip acquiring the lock again.
// TryGetOrCreateObjectForComInstanceInternal may have put a new entry into the cache
// in the time between the GC cleared the contents of the GC handle but before the
// NativeObjectWrapper finalizer ran.
// Only remove the entry if the target of the GC handle is the NativeObjectWrapper
// or is null (indicating that the corresponding NativeObjectWrapper has been scheduled for finalization).
if (_cache.TryGetValue(comPointer, out GCHandle cachedRef)
&& (wrapper == cachedRef.Target
|| cachedRef.Target is null))
{
_cache.Remove(comPointer);
cachedRef.Free();
}
}
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ internal static void ReleaseExternalObjectsFromCurrentThread()

IntPtr contextToken = GetContextToken();

List<object> objects = new List<object>();
List<ReferenceTrackerNativeObjectWrapper> wrappersToRemove = [];
List<object> objects = [];

// Here we aren't part of a GC callback, so other threads can still be running
// who are adding and removing from the collection. This means we can possibly race
Expand All @@ -76,10 +77,20 @@ internal static void ReleaseExternalObjectsFromCurrentThread()
if (nativeObjectWrapper != null &&
nativeObjectWrapper._contextToken == contextToken)
{
object? target = nativeObjectWrapper.ProxyHandle.Target;
if (target != null)
// If this object is associated with the global instance for tracker support,
// then we can request that instance to clear out the native object wrapper's state
// to ensure the object gets released now.
// Also, we will remove the wrappers from the cache to ensure a stale wrapper
// isn't returned in the future.
if (nativeObjectWrapper.ComWrappers == GlobalInstanceForTrackerSupport)
{
objects.Add(target);
wrappersToRemove.Add(nativeObjectWrapper);

object? target = nativeObjectWrapper.ProxyHandle.Target;
if (target != null)
{
objects.Add(target);
}
}

// Separate the wrapper from the tracker runtime prior to
Expand All @@ -89,6 +100,10 @@ internal static void ReleaseExternalObjectsFromCurrentThread()
}
}

// Remove the native object wrappers from the cache
// so we don't return released wrappers to the user if the native COM object
// happens to be reused.
GlobalInstanceForTrackerSupport.RemoveWrappersFromCache(wrappersToRemove);
GlobalInstanceForTrackerSupport.ReleaseObjects(objects);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,10 @@ protected override void ReleaseObjects(IEnumerable objects)
Assert.NotNull(o);
}

throw new Exception() { HResult = ReleaseObjectsCallAck };
if (ReturnInvalid)
{
throw new Exception() { HResult = ReleaseObjectsCallAck };
}
}

private unsafe ComInterfaceEntry* ComputeVtablesForTestObject(Test obj, out int count)
Expand Down Expand Up @@ -461,6 +464,34 @@ private static void ValidateNotifyEndOfReferenceTrackingOnThread()
// Trigger the thread lifetime end API and verify the callback occurs.
int hr = MockReferenceTrackerRuntime.Trigger_NotifyEndOfReferenceTrackingOnThread();
Assert.Equal(GlobalComWrappers.ReleaseObjectsCallAck, hr);

// Validate that the RCW cache gets cleared when we call NotifyEndOfReferenceTrackingOnThread
GlobalComWrappers.Instance.ReturnInvalid = false;
IntPtr tracker = MockReferenceTrackerRuntime.CreateTrackerObject();
try
{
object rcw = GlobalComWrappers.Instance.GetOrCreateObjectForComInstance(tracker, CreateObjectFlags.TrackerObject);

// Make sure that we keep the tracker object alive even after we notify end of reference tracking on this thread.
Marshal.AddRef(tracker);

const int S_OK = 0;
Assert.Equal(S_OK, MockReferenceTrackerRuntime.Trigger_NotifyEndOfReferenceTrackingOnThread());

// We should get a new RCW after we've released the reference tracked objects on this thread.
object rcwNew = GlobalComWrappers.Instance.GetOrCreateObjectForComInstance(tracker, CreateObjectFlags.TrackerObject);

Assert.NotSame(rcw, rcwNew);
}
finally
{
if (tracker != IntPtr.Zero)
{
// Release the extra ref we added above and the original ref from CreateTrackerObject.
Marshal.Release(tracker);
Marshal.Release(tracker);
}
}
}
}
}
Expand Down
Loading