Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 67 additions & 10 deletions src/coreclr/vm/comcallablewrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4206,6 +4206,69 @@ ComMethodTable* ComCallWrapperTemplate::CreateComMethodTableForBasic(MethodTable
RETURN pComMT;
}

//--------------------------------------------------------------------------
// Returns TRUE if the parent's ComMethodTable for pItfMT can be reused for
// pClassMT. This requires that no class between pClassMT and pParentMT has
// re-implemented pItfMT in its dispatch map, and that the interface methods
// resolve to the same MethodDescs on both pClassMT and pParentMT.
//--------------------------------------------------------------------------
static bool CanShareComMethodTableWithParent(MethodTable* pClassMT, MethodTable* pParentMT, MethodTable* pItfMT)
{
CONTRACTL
{
THROWS;
GC_TRIGGERS;
MODE_ANY;
PRECONDITION(pClassMT != NULL && !pClassMT->IsInterface());
PRECONDITION(pParentMT != NULL && !pParentMT->IsInterface());
PRECONDITION(pItfMT != NULL && pItfMT->IsInterface());
}
CONTRACTL_END;

// Check for explicit interface re-implementations in the dispatch map.
MethodTable* pMT = pClassMT;
do
{
DispatchMap::EncodedMapIterator mapIt(pMT);
for (; mapIt.IsValid(); mapIt.Next())
{
DispatchMapEntry *pEntry = mapIt.Entry();
if (pMT->DispatchMapTypeMatchesMethodTable(pEntry->GetTypeID(), pItfMT))
{
return false;
}
}

pMT = pMT->GetParentMethodTable();
_ASSERTE(pMT != NULL);
}
while (pMT != pParentMT);

// Check that interface methods resolve to the same MethodDescs on both
// this class and pParentMT. With the baked-in dispatch target model, the
// ComMethodTable stores the resolved MethodDesc at layout time, so the
// table can only be shared if the targets are identical.
for (unsigned i = 0; i < pItfMT->GetNumVirtuals(); i++)
{
MethodDesc *pItfMD = pItfMT->GetMethodDescForSlot_NoThrow(i);
Comment thread
elinor-fung marked this conversation as resolved.
_ASSERTE(pItfMD != NULL);

if (pItfMD->IsAsyncMethod())
continue;

Comment thread
elinor-fung marked this conversation as resolved.
DispatchSlot childSlot(pClassMT->FindDispatchSlotForInterfaceMD(pItfMD, FALSE /* throwOnConflict */));
DispatchSlot parentSlot(pParentMT->FindDispatchSlotForInterfaceMD(pItfMD, FALSE /* throwOnConflict */));

if (childSlot.IsNull() || parentSlot.IsNull())
return false;

if (childSlot.GetMethodDesc() != parentSlot.GetMethodDesc())
return false;
}

return true;
}

//--------------------------------------------------------------------------
// Creates a ComMethodTable for an interface and stores it in the m_rgpIPtr array.
//--------------------------------------------------------------------------
Expand All @@ -4222,22 +4285,16 @@ ComMethodTable *ComCallWrapperTemplate::InitializeForInterface(MethodTable *pPar
ComMethodTable *pItfComMT = NULL;
if (m_pParent != NULL)
{
pItfComMT = m_pParent->GetComMTForItf(pItfMT);
if (pItfComMT != NULL)
// Check if we can reuse the parent's ComMethodTable for this interface.
ComMethodTable* pParentComMT = m_pParent->GetComMTForItf(pItfMT);
if (pParentComMT != NULL && CanShareComMethodTableWithParent(m_thClass.GetMethodTable(), pParentMT, pItfMT))
{
// if the parent COM MT is not a trivial aggregate, simple MethodTable slot check is enough
if (!m_thClass.GetMethodTable()->ImplementsInterfaceWithSameSlotsAsParent(pItfMT, pParentMT))
{
// the interface is implemented by parent but this class reimplemented
// its method(s) so we will need to build a new COM vtable for it
pItfComMT = NULL;
}
pItfComMT = pParentComMT;
}
}

if (pItfComMT == NULL)
{
// we couldn't use parent's vtable so we create a new one
pItfComMT = CreateComMethodTableForInterface(pItfMT);
}

Expand Down
33 changes: 0 additions & 33 deletions src/coreclr/vm/methodtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5977,39 +5977,6 @@ UINT32 MethodTable::LookupTypeID()
return AppDomain::GetCurrentDomain()->LookupTypeID(pMT);
}

//==========================================================================================
BOOL MethodTable::ImplementsInterfaceWithSameSlotsAsParent(MethodTable *pItfMT, MethodTable *pParentMT)
{
CONTRACTL
{
THROWS;
GC_TRIGGERS;
PRECONDITION(!IsInterface() && !pParentMT->IsInterface());
PRECONDITION(pItfMT->IsInterface());
} CONTRACTL_END;

MethodTable *pMT = this;
do
{
DispatchMap::EncodedMapIterator it(pMT);
for (; it.IsValid(); it.Next())
{
DispatchMapEntry *pCurEntry = it.Entry();
if (DispatchMapTypeMatchesMethodTable(pCurEntry->GetTypeID(), pItfMT))
{
// this class and its parents up to pParentMT must have no mappings for the interface
return FALSE;
}
}

pMT = pMT->GetParentMethodTable();
_ASSERTE(pMT != NULL);
}
while (pMT != pParentMT);

return TRUE;
}

#endif // !DACCESS_COMPILE

//==========================================================================================
Expand Down
5 changes: 0 additions & 5 deletions src/coreclr/vm/methodtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -2572,11 +2572,6 @@ class MethodTable
MethodTable *LookupDispatchMapType(DispatchMapTypeID typeID);
bool DispatchMapTypeMatchesMethodTable(DispatchMapTypeID typeID, MethodTable* pMT);

// Determines whether all methods in the given interface have their final implementing
// slot in a parent class. I.e. if this returns TRUE, it is trivial (no VSD lookup) to
// dispatch pItfMT methods on this class if one knows how to dispatch them on pParentMT.
BOOL ImplementsInterfaceWithSameSlotsAsParent(MethodTable *pItfMT, MethodTable *pParentMT);

// Try to resolve a given static virtual method override on this type. Return nullptr
// when not found.
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags, ClassLoadLevel level);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using Xunit;

[ComVisible(true)]
[Guid("A1111111-0000-0000-0000-000000000001")]
public interface IFoo
{
void DoWork();
}

[ComVisible(true)]
[Guid("A1111111-0000-0000-0000-000000000002")]
[ComDefaultInterface(typeof(IFoo))]
public class Foo : IFoo
{
public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Foo);
}

[ComVisible(true)]
[Guid("A1111111-0000-0000-0000-000000000003")]
[ComDefaultInterface(typeof(IFoo))]
public class FooDerived : Foo
{
public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(FooDerived);
}

[ComVisible(true)]
[Guid("B2222222-0000-0000-0000-000000000001")]
public interface IBar
{
void DoWork();
}

[ComVisible(true)]
[Guid("B2222222-0000-0000-0000-000000000002")]
[ComDefaultInterface(typeof(IBar))]
public class Bar : IBar
{
public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Bar);
}

[ComVisible(true)]
[Guid("B2222222-0000-0000-0000-000000000003")]
[ComDefaultInterface(typeof(IBar))]
public class BarDerived : Bar
{
public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(BarDerived);
}

/// <summary>
/// Tests that COM-to-CLR dispatch correctly resolves virtual method overrides
/// regardless of whether the base or derived class is accessed via COM first.
/// </summary>
public class VirtualMethodOverrideTest
{
internal static string? LastCalledType;

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
delegate int DoWorkDelegate(IntPtr pThis);

private static int CallDoWork(IntPtr pInterface, int slot)
{
IntPtr vtbl = Marshal.ReadIntPtr(pInterface);
IntPtr fnPtr = Marshal.ReadIntPtr(vtbl, slot * IntPtr.Size);
Assert.NotEqual(IntPtr.Zero, fnPtr);

var fn = Marshal.GetDelegateForFunctionPointer<DoWorkDelegate>(fnPtr);
return fn(pInterface);
}

[Fact]
public static void DerivedFirst()
{
int doWorkSlot = Marshal.GetStartComSlot(typeof(IFoo));
IntPtr pDerived = IntPtr.Zero;
IntPtr pBase = IntPtr.Zero;
try
{
pDerived = Marshal.GetComInterfaceForObject(new FooDerived(), typeof(IFoo));
pBase = Marshal.GetComInterfaceForObject(new Foo(), typeof(IFoo));

LastCalledType = null;
Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0);
Assert.Equal(nameof(FooDerived), LastCalledType);

LastCalledType = null;
Assert.True(CallDoWork(pBase, doWorkSlot) >= 0);
Assert.Equal(nameof(Foo), LastCalledType);
}
finally
{
if (pDerived != IntPtr.Zero)
Marshal.Release(pDerived);

if (pBase != IntPtr.Zero)
Marshal.Release(pBase);
}
Comment thread
elinor-fung marked this conversation as resolved.
}

[Fact]
public static void BaseFirst()
{
int doWorkSlot = Marshal.GetStartComSlot(typeof(IBar));
IntPtr pBase = IntPtr.Zero;
IntPtr pDerived = IntPtr.Zero;
try
{
pBase = Marshal.GetComInterfaceForObject(new Bar(), typeof(IBar));
pDerived = Marshal.GetComInterfaceForObject(new BarDerived(), typeof(IBar));

LastCalledType = null;
Assert.True(CallDoWork(pBase, doWorkSlot) >= 0);
Assert.Equal(nameof(Bar), LastCalledType);

LastCalledType = null;
Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0);
Assert.Equal(nameof(BarDerived), LastCalledType);
}
finally
{
if (pBase != IntPtr.Zero) Marshal.Release(pBase);
if (pDerived != IntPtr.Zero) Marshal.Release(pDerived);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<RequiresProcessIsolation>true</RequiresProcessIsolation>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<NativeAotIncompatible>true</NativeAotIncompatible>
</PropertyGroup>
<ItemGroup>
<Compile Include="VirtualMethodOverrideTest.cs" />
</ItemGroup>
</Project>