Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recognize the stateful marshaller shape #71355

Merged
merged 10 commits into from
Jun 29, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.Interop
public readonly record struct CustomTypeMarshallerData(
ManagedTypeInfo MarshallerType,
ManagedTypeInfo NativeType,
bool HasState,
MarshallerShape Shape,
bool IsStrictlyBlittable,
ManagedTypeInfo? BufferElementType);
Expand Down Expand Up @@ -73,6 +74,11 @@ public static bool IsLinearCollectionEntryPoint(ITypeSymbol entryPointType)
return false;
}

public static bool HasEntryPointMarshallerAttribute(ITypeSymbol entryPointType)
{
return entryPointType.GetAttributes().Any(attr => attr.AttributeClass.ToDisplayString() == TypeNames.CustomMarshallerAttribute);
}

public static bool TryGetMarshallersFromEntryType(
INamedTypeSymbol entryPointType,
ITypeSymbol managedType,
Expand Down Expand Up @@ -288,7 +294,20 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault

private static CustomTypeMarshallerData? GetMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, Compilation compilation)
{
(MarshallerShape shape, Dictionary<MarshallerShape, IMethodSymbol> methodsByShape) = MarshallerShapeHelper.GetShapeForType(marshallerType, managedType, compilation);
if (marshallerType is { IsStatic: true, TypeKind: TypeKind.Class })
{
return GetStatelessMarshallerDataForType(marshallerType, direction, managedType, compilation);
}
if (marshallerType.IsValueType)
{
return GetStatefulMarshallerDataForType(marshallerType, direction, managedType, compilation);
}
return null;
}

private static CustomTypeMarshallerData? GetStatelessMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, Compilation compilation)
{
(MarshallerShape shape, Dictionary<MarshallerShape, IMethodSymbol> methodsByShape) = StatelessMarshallerShapeHelper.GetShapeForType(marshallerType, managedType, compilation);

ITypeSymbol? nativeType = null;
if (direction.HasFlag(MarshallingDirection.ManagedToUnmanaged))
Expand Down Expand Up @@ -339,6 +358,56 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault
return new CustomTypeMarshallerData(
ManagedTypeInfo.CreateTypeInfoForTypeSymbol(marshallerType),
ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeType),
HasState: false,
shape,
nativeType.IsStrictlyBlittable(),
bufferElementType);
}

private static CustomTypeMarshallerData? GetStatefulMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, Compilation compilation)
{
(MarshallerShape shape, StatefulMarshallerShapeHelper.MarshallerMethods methods) = StatefulMarshallerShapeHelper.GetShapeForType(marshallerType, managedType, compilation);

ITypeSymbol? nativeType = null;
if (direction.HasFlag(MarshallingDirection.ManagedToUnmanaged))
{
if (!shape.HasFlag(MarshallerShape.CallerAllocatedBuffer) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
return null;

if (methods.ToUnmanaged is not null)
{
nativeType = methods.ToUnmanaged.ReturnType;
}
}

if (nativeType is null && direction.HasFlag(MarshallingDirection.UnmanagedToManaged))
{
if (!shape.HasFlag(MarshallerShape.GuaranteedUnmarshal) && !shape.HasFlag(MarshallerShape.ToManaged))
return null;

if (methods.FromUnmanaged is not null)
{
nativeType = methods.FromUnmanaged.Parameters[0].Type;
}
}

// Bidirectional requires ToUnmanaged without the caller-allocated buffer
if (direction.HasFlag(MarshallingDirection.Bidirectional) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
return null;

if (nativeType is null)
return null;

ManagedTypeInfo bufferElementType = null;
if (methods.FromManagedWithBuffer is not null)
{
bufferElementType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(((INamedTypeSymbol)methods.FromManagedWithBuffer.Parameters[1].Type).TypeArguments[0]);
}

return new CustomTypeMarshallerData(
ManagedTypeInfo.CreateTypeInfoForTypeSymbol(marshallerType),
ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeType),
HasState: true,
shape,
nativeType.IsStrictlyBlittable(),
bufferElementType);
Expand Down
Loading