Skip to content

Commit

Permalink
StringMarshalling behavior override tests (#86963)
Browse files Browse the repository at this point in the history
Adds tests to make sure MarshalAs and MarshalUsing override the default string marshalling behavior set by the interface-wide StringMarshalling. No changes were necessary in the generator.
  • Loading branch information
jtschuster committed Jun 2, 2023
1 parent bebd644 commit 8ee61fe
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public unsafe partial class StringMarshallingTests
[LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_utf16_marshalling")]
public static partial void* NewIUtf16Marshalling();

[LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_string_marshalling_override")]
public static partial void* NewStringMarshallingOverride();

[GeneratedComClass]
internal partial class Utf8MarshalledClass : IUTF8Marshalling
{
Expand Down Expand Up @@ -107,5 +110,26 @@ public void RcwToCcw()
customUtf16ComObject.SetString("Set from COM object");
Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString());
}

[Fact]
public void MarshalAsAndMarshalUsingOverrideStringMarshalling()
{
var ptr = NewStringMarshallingOverride();
var cw = new StrategyBasedComWrappers();
var obj = cw.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
var stringMarshallingOverride = (IStringMarshallingOverride)obj;
Assert.Equal("Your string: MyUtf8String", stringMarshallingOverride.StringMarshallingUtf8("MyUtf8String"));
Assert.Equal("Your string: MyLPWStrString", stringMarshallingOverride.MarshalAsLPWString("MyLPWStrString"));
Assert.Equal("Your string: MyUtf16String", stringMarshallingOverride.MarshalUsingUtf16("MyUtf16String"));

// Make sure the shadowing methods generated for the derived interface also follow the rules
var stringMarshallingOverrideDerived = (IStringMarshallingOverrideDerived)obj;
Assert.Equal("Your string: MyUtf8String", stringMarshallingOverrideDerived.StringMarshallingUtf8("MyUtf8String"));
Assert.Equal("Your string: MyLPWStrString", stringMarshallingOverrideDerived.MarshalAsLPWString("MyLPWStrString"));
Assert.Equal("Your string: MyUtf16String", stringMarshallingOverrideDerived.MarshalUsingUtf16("MyUtf16String"));
Assert.Equal("Your string 2: MyUtf8String", stringMarshallingOverrideDerived.StringMarshallingUtf8_2("MyUtf8String"));
Assert.Equal("Your string 2: MyLPWStrString", stringMarshallingOverrideDerived.MarshalAsLPWString_2("MyLPWStrString"));
Assert.Equal("Your string 2: MyUtf16String", stringMarshallingOverrideDerived.MarshalUsingUtf16_2("MyUtf16String"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using System.Text;
using System.Threading.Tasks;
using SharedTypes.ComInterfaces;
using static System.Runtime.InteropServices.ComWrappers;

namespace NativeExports.ComInterfaceGenerator
{
public unsafe partial class StringMarshallingOverride
{
[UnmanagedCallersOnly(EntryPoint = "new_string_marshalling_override")]
public static void* CreateStringMarshallingOverrideObject()
{
MyComWrapper cw = new();
var myObject = new Implementation();
nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
return (void*)ptr;
}

class MyComWrapper : ComWrappers
{
static void* _s_comInterfaceVTable = null;
static void* S_VTable
{
get
{
if (_s_comInterfaceVTable != null)
return _s_comInterfaceVTable;
void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 6);
GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
vtable[0] = (void*)fpQueryInterface;
vtable[1] = (void*)fpAddReference;
vtable[2] = (void*)fpRelease;
vtable[3] = (delegate* unmanaged<void*, byte*, byte**, int>)&Implementation.ABI.StringMarshallingUtf8;
vtable[4] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalAsLPWStr;
vtable[5] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalUsingUtf16;
_s_comInterfaceVTable = vtable;
return _s_comInterfaceVTable;
}
}

static void* _s_derivedVTable = null;
static void* S_DerivedVTable
{
get
{
if (_s_comInterfaceVTable != null)
return _s_comInterfaceVTable;
void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 9);
GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
vtable[0] = (void*)fpQueryInterface;
vtable[1] = (void*)fpAddReference;
vtable[2] = (void*)fpRelease;
vtable[3] = (delegate* unmanaged<void*, byte*, byte**, int>)&Implementation.ABI.StringMarshallingUtf8;
vtable[4] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalAsLPWStr;
vtable[5] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalUsingUtf16;
vtable[6] = (delegate* unmanaged<void*, byte*, byte**, int>)&Implementation.ABI.StringMarshallingUtf8_2;
vtable[7] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalAsLPWStr_2;
vtable[8] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalUsingUtf16_2;
_s_comInterfaceVTable = vtable;
return _s_comInterfaceVTable;
}
}

protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
{
if (obj is IStringMarshallingOverrideDerived)
{
ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Implementation), sizeof(ComInterfaceEntry) * 2);
comInterfaceEntry[0].IID = new Guid(IStringMarshallingOverrideDerived._guid);
comInterfaceEntry[0].Vtable = (nint)S_DerivedVTable;
comInterfaceEntry[1].IID = new Guid(IStringMarshallingOverride._guid);
comInterfaceEntry[1].Vtable = (nint)S_VTable;
count = 2;
return comInterfaceEntry;
}
if (obj is IStringMarshallingOverride)
{
ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Implementation), sizeof(ComInterfaceEntry));
comInterfaceEntry->IID = new Guid(IStringMarshallingOverride._guid);
comInterfaceEntry->Vtable = (nint)S_VTable;
count = 1;
return comInterfaceEntry;
}
count = 0;
return null;
}

protected override object? CreateObject(nint externalComObject, CreateObjectFlags flags) => throw new NotImplementedException();
protected override void ReleaseObjects(IEnumerable objects) => throw new NotImplementedException();
}

partial class Implementation : IStringMarshallingOverride, IStringMarshallingOverrideDerived
{
string _data = "Your string: ";
string IStringMarshallingOverride.StringMarshallingUtf8(string input) => _data + input;
string IStringMarshallingOverride.MarshalAsLPWString(string input) => _data + input;
string IStringMarshallingOverride.MarshalUsingUtf16(string input) => _data + input;

string _data2 = "Your string 2: ";
string IStringMarshallingOverrideDerived.StringMarshallingUtf8_2(string input) => _data2 + input;
string IStringMarshallingOverrideDerived.MarshalAsLPWString_2(string input) => _data2 + input;
string IStringMarshallingOverrideDerived.MarshalUsingUtf16_2(string input) => _data2 + input;

// Provides function pointers in the COM format to use in COM VTables
public static class ABI
{
[UnmanagedCallersOnly]
public static int StringMarshallingUtf8(void* @this, byte* input, byte** output)
{
try
{
string inputStr = Utf8StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverride>((ComInterfaceDispatch*)@this).StringMarshallingUtf8(inputStr);
*output = Utf8StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int MarshalAsLPWStr(void* @this, ushort* input, ushort** output)
{
try
{
string inputStr = Utf16StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverride>((ComInterfaceDispatch*)@this).MarshalAsLPWString(inputStr);
*output = Utf16StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int MarshalUsingUtf16(void* @this, ushort* input, ushort** output)
{
try
{
string inputStr = Utf16StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverride>((ComInterfaceDispatch*)@this).MarshalUsingUtf16(inputStr);
*output = Utf16StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int StringMarshallingUtf8_2(void* @this, byte* input, byte** output)
{
try
{
string inputStr = Utf8StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverrideDerived>((ComInterfaceDispatch*)@this).StringMarshallingUtf8_2(inputStr);
*output = Utf8StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int MarshalAsLPWStr_2(void* @this, ushort* input, ushort** output)
{
try
{
string inputStr = Utf16StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverrideDerived>((ComInterfaceDispatch*)@this).MarshalAsLPWString_2(inputStr);
*output = Utf16StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int MarshalUsingUtf16_2(void* @this, ushort* input, ushort** output)
{
try
{
string inputStr = Utf16StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverrideDerived>((ComInterfaceDispatch*)@this).MarshalUsingUtf16_2(inputStr);
*output = Utf16StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using System.Text;
using System.Threading.Tasks;

namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface(StringMarshalling = System.Runtime.InteropServices.StringMarshalling.Utf8)]
[Guid(_guid)]
internal partial interface IStringMarshallingOverride
{
public const string _guid = "5146B7DB-0588-469B-B8E5-B38090A2FC15";
string StringMarshallingUtf8(string input);

[return: MarshalAs(UnmanagedType.LPWStr)]
string MarshalAsLPWString([MarshalAs(UnmanagedType.LPWStr)] string input);

[return: MarshalUsing(typeof(Utf16StringMarshaller))]
string MarshalUsingUtf16([MarshalUsing(typeof(Utf16StringMarshaller))] string input);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices.Marshalling;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;

namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface(StringMarshalling = StringMarshalling.Utf8)]
[Guid(_guid)]
internal partial interface IStringMarshallingOverrideDerived : IStringMarshallingOverride
{
public new const string _guid = "3AFFE3FD-D11E-4195-8250-0C73321977A0";
string StringMarshallingUtf8_2(string input);

[return: MarshalAs(UnmanagedType.LPWStr)]
string MarshalAsLPWString_2([MarshalAs(UnmanagedType.LPWStr)] string input);

[return: MarshalUsing(typeof(Utf16StringMarshaller))]
string MarshalUsingUtf16_2([MarshalUsing(typeof(Utf16StringMarshaller))] string input);
}
}

0 comments on commit 8ee61fe

Please sign in to comment.