Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make [In, Optional] managed struct parameters actually optional in friendly overloads #579

Merged
merged 2 commits into from
Jun 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/ArrayTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs
return new TypeSyntaxAndMarshaling(PointerType(element.Type));
}
}

internal override bool? IsValueType(TypeSyntaxSettings inputs) => false;
}
56 changes: 48 additions & 8 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ public class Generator : IDisposable
private readonly CSharpParseOptions? parseOptions;
private readonly bool canUseSpan;
private readonly bool canCallCreateSpan;
private readonly bool canUseUnsafeAsRef;
private readonly bool canUseUnsafeNullRef;
private readonly bool getDelegateForFunctionPointerGenericExists;
private readonly bool generateSupportedOSPlatformAttributes;
private readonly bool generateSupportedOSPlatformAttributesOnInterfaces; // only supported on net6.0 (https://github.com/dotnet/runtime/pull/48838)
Expand Down Expand Up @@ -367,6 +369,8 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option

this.canUseSpan = this.compilation?.GetTypeByMetadataName(typeof(Span<>).FullName) is not null;
this.canCallCreateSpan = this.compilation?.GetTypeByMetadataName(typeof(MemoryMarshal).FullName)?.GetMembers("CreateSpan").Any() is true;
this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true;
this.canUseUnsafeNullRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("NullRef").Any() is true;
this.getDelegateForFunctionPointerGenericExists = this.compilation?.GetTypeByMetadataName(typeof(Marshal).FullName)?.GetMembers(nameof(Marshal.GetDelegateForFunctionPointer)).Any(m => m is IMethodSymbol { IsGenericMethod: true }) is true;
this.generateDefaultDllImportSearchPathsAttribute = this.compilation?.GetTypeByMetadataName(typeof(DefaultDllImportSearchPathsAttribute).FullName) is object;
if (this.compilation?.GetTypeByMetadataName("System.Runtime.Versioning.SupportedOSPlatformAttribute") is { } attribute
Expand Down Expand Up @@ -4315,6 +4319,14 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
static ExpressionSyntax GetSpanLength(ExpressionSyntax span) => MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, span, IdentifierName(nameof(Span<int>.Length)));
bool isReleaseMethod = this.MetadataIndex.ReleaseMethods.Contains(externMethodDeclaration.Identifier.ValueText);

TypeSyntaxSettings parameterTypeSyntaxSettings = overloadOf switch
{
FriendlyOverloadOf.ExternMethod => this.externSignatureTypeSettings,
FriendlyOverloadOf.StructMethod => this.extensionMethodSignatureTypeSettings,
FriendlyOverloadOf.InterfaceMethod => this.extensionMethodSignatureTypeSettings,
_ => throw new NotSupportedException(overloadOf.ToString()),
};

MethodSignature<TypeHandleInfo> originalSignature = methodDefinition.DecodeSignature(SignatureHandleProvider.Instance, null);
var parameters = externMethodDeclaration.ParameterList.Parameters.Select(StripAttributes).ToList();
var lengthParamUsedBy = new Dictionary<int, int>();
Expand Down Expand Up @@ -4350,7 +4362,10 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}

TypeHandleInfo parameterTypeInfo = originalSignature.ParameterTypes[param.SequenceNumber - 1];
if (this.IsManagedType(parameterTypeInfo) && (externParam.Modifiers.Any(SyntaxKind.OutKeyword) || externParam.Modifiers.Any(SyntaxKind.RefKeyword)))
bool isManagedParameterType = this.IsManagedType(parameterTypeInfo);
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);

if (isManagedParameterType && (externParam.Modifiers.Any(SyntaxKind.OutKeyword) || externParam.Modifiers.Any(SyntaxKind.RefKeyword)))
{
bool hasOut = externParam.Modifiers.Any(SyntaxKind.OutKeyword);
arguments[param.SequenceNumber - 1] = arguments[param.SequenceNumber - 1].WithRefKindKeyword(TokenWithSpace(hasOut ? SyntaxKind.OutKeyword : SyntaxKind.RefKeyword));
Expand All @@ -4361,7 +4376,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
{
signatureChanged = true;

IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax typeDefHandleName = IdentifierName(externParam.Identifier.ValueText + "Local");

// out SafeHandle
Expand All @@ -4370,7 +4384,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
.WithModifiers(TokenList(TokenWithSpace(SyntaxKind.OutKeyword)));

// HANDLE SomeLocal;
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(pointedElementInfo.ToTypeSyntax(this.externSignatureTypeSettings, null).Type).AddVariables(
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, null).Type).AddVariables(
VariableDeclarator(typeDefHandleName.Identifier))));

// Argument: &SomeLocal
Expand All @@ -4387,7 +4401,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}
else if (isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod))
{
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax typeDefHandleName = IdentifierName(externParam.Identifier.ValueText + "Local");
signatureChanged = true;

Expand Down Expand Up @@ -4480,7 +4493,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}
}

IdentifierNameSyntax origName = IdentifierName(parameters[param.SequenceNumber - 1].Identifier.ValueText);
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
if (isArray)
{
Expand Down Expand Up @@ -4612,7 +4624,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}
else if (isIn && !isOut && isConst && externParam.Type is QualifiedNameSyntax { Right: { Identifier: { ValueText: "PCWSTR" } } })
{
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
signatureChanged = true;
parameters[param.SequenceNumber - 1] = externParam
Expand All @@ -4623,7 +4634,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}
else if (isIn && !isOut && isConst && externParam.Type is QualifiedNameSyntax { Right: { Identifier: { ValueText: "PCSTR" } } })
{
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
signatureChanged = true;
parameters[param.SequenceNumber - 1] = externParam
Expand All @@ -4649,7 +4659,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}
else if (isIn && isOut && this.canUseSpan && externParam.Type is QualifiedNameSyntax { Right: { Identifier: { ValueText: "PWSTR" } } })
{
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax localName = IdentifierName("p" + origName);
IdentifierNameSyntax localWstrName = IdentifierName("wstr" + origName);
signatureChanged = true;
Expand Down Expand Up @@ -4689,6 +4698,37 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))),
Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, localWstrName, IdentifierName("Length"))))))));
}
else if (isIn && isOptional && !isOut && isManagedParameterType && parameterTypeInfo is PointerTypeHandleInfo ptrInfo && ptrInfo.ElementType.IsValueType(parameterTypeSyntaxSettings) is true && this.canUseUnsafeAsRef)
{
// The extern method couldn't have exposed the parameter as a pointer because the type is managed.
// It would have exposed as an `in` modifier, and non-optional. But we can expose as optional anyway.
signatureChanged = true;
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
parameters[param.SequenceNumber - 1] = parameters[param.SequenceNumber - 1]
.WithType(NullableType(externParam.Type).WithTrailingTrivia(TriviaList(Space)))
.WithModifiers(TokenList()); // drop the `in` modifier.
leadingStatements.Add(
LocalDeclarationStatement(VariableDeclaration(externParam.Type)
.AddVariables(VariableDeclarator(localName.Identifier).WithInitializer(
EqualsValueClause(ConditionalExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName("HasValue")),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName("Value")),
DefaultExpression(externParam.Type)))))));

// We can't pass in null, but we can be fancy to achieve the same effect.
// Unsafe.NullRef<TParamType>() or Unsafe.AsRef<TParamType>(null), depending on what's available.
ExpressionSyntax nullRef = this.canUseUnsafeNullRef
? InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(Unsafe)), GenericName("NullRef", TypeArgumentList().AddArguments(externParam.Type))),
ArgumentList())
: InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(Unsafe)), GenericName(nameof(Unsafe.AsRef), TypeArgumentList().AddArguments(externParam.Type))),
ArgumentList().AddArguments(Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))));
arguments[param.SequenceNumber - 1] = Argument(ConditionalExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName("HasValue")),
localName,
nullRef));
}
}

TypeSyntax? returnSafeHandleType = originalSignature.ReturnType is HandleTypeHandleInfo returnTypeHandleInfo
Expand Down
47 changes: 47 additions & 0 deletions src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,53 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs
return new TypeSyntaxAndMarshaling(syntax);
}

internal override bool? IsValueType(TypeSyntaxSettings inputs)
{
Generator generator = inputs.Generator ?? throw new ArgumentException("Generator required.");
TypeDefinitionHandle typeDefHandle = default;
switch (this.Handle.Kind)
{
case HandleKind.TypeDefinition:
typeDefHandle = (TypeDefinitionHandle)this.Handle;
break;
case HandleKind.TypeReference:
if (generator.TryGetTypeDefHandle((TypeReferenceHandle)this.Handle, out QualifiedTypeDefinitionHandle qualifiedTypeDefHandle))
{
generator = qualifiedTypeDefHandle.Generator;
typeDefHandle = qualifiedTypeDefHandle.DefinitionHandle;
}

break;
default:
return null;
}

if (typeDefHandle.IsNil)
{
return null;
}

TypeDefinition typeDef = generator.Reader.GetTypeDefinition(typeDefHandle);
generator.GetBaseTypeInfo(typeDef, out StringHandle baseName, out StringHandle baseNamespace);
if (generator.Reader.StringComparer.Equals(baseName, nameof(ValueType)) && generator.Reader.StringComparer.Equals(baseNamespace, nameof(System)))
{
// When marshaling, the VARIANT struct becomes object, which is *not* a value type.
if (inputs.AllowMarshaling && generator.Reader.StringComparer.Equals(typeDef.Name, "VARIANT"))
{
return false;
}

return true;
}

if (generator.Reader.StringComparer.Equals(baseName, nameof(Enum)) && generator.Reader.StringComparer.Equals(baseNamespace, nameof(System)))
{
return true;
}

return false;
}

private static bool TryMarshalAsObject(TypeSyntaxSettings inputs, string name, [NotNullWhen(true)] out MarshalAsAttribute? marshalAs)
{
if (inputs.AllowMarshaling)
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/PointerTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs
return new TypeSyntaxAndMarshaling(PointerType(elementTypeDetails.Type));
}

internal override bool? IsValueType(TypeSyntaxSettings inputs) => false;

private bool TryGetElementTypeDefinition(Generator generator, out TypeDefinition typeDef)
{
if (this.ElementType is HandleTypeHandleInfo handleElement)
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.Windows.CsWin32/PrimitiveTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ internal record PrimitiveTypeHandleInfo(PrimitiveTypeCode PrimitiveTypeCode) : T
internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs, CustomAttributeHandleCollection? customAttributes, ParameterAttributes parameterAttributes)
=> new TypeSyntaxAndMarshaling(ToTypeSyntax(this.PrimitiveTypeCode, inputs.PreferNativeInt));

internal override bool? IsValueType(TypeSyntaxSettings inputs)
{
return this.PrimitiveTypeCode is not PrimitiveTypeCode.Object or PrimitiveTypeCode.Void;
}

internal static TypeSyntax ToTypeSyntax(PrimitiveTypeCode typeCode, bool preferNativeInt)
{
return typeCode switch
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/TypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ internal abstract record TypeHandleInfo

internal abstract TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs, CustomAttributeHandleCollection? customAttributes, ParameterAttributes parameterAttributes = default);

internal abstract bool? IsValueType(TypeSyntaxSettings inputs);

protected static bool TryGetSimpleName(TypeSyntax nameSyntax, [NotNullWhen(true)] out string? simpleName)
{
if (nameSyntax is QualifiedNameSyntax qname)
Expand Down
25 changes: 25 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ public GeneratorTests(ITestOutputHelper logger)
new object[] { "net6.0" },
};

public static IEnumerable<object[]> TFMDataNoNetFx35 =>
new object[][]
{
new object[] { "net472" },
new object[] { "netstandard2.0" },
new object[] { "net6.0" },
};

public static Platform[] SpecificCpuArchitectures =>
new Platform[]
{
Expand Down Expand Up @@ -254,6 +262,7 @@ public void COMInterfaceWithSupportedOSPlatform(bool net60, bool allowMarshaling
"CertFreeCertificateChainList", // double pointer extern method
"D3DGetTraceInstructionOffsets", // SizeParamIndex
"PlgBlt", // SizeConst
"IWebBrowser", // Navigate method has an [In, Optional] object parameter
"ENABLE_TRACE_PARAMETERS_V1", // bad xml created at some point.
"JsRuntimeVersion", // An enum that has an extra member in a separate header file.
"ReportEvent", // Failed at one point
Expand Down Expand Up @@ -2723,6 +2732,21 @@ public void OpensMetadataForSharedReading()
Assert.True(this.generator.TryGenerate("CreateFile", CancellationToken.None));
}

[Theory]
[MemberData(nameof(TFMDataNoNetFx35))]
public void MiniDumpWriteDump_AllOptionalPointerParametersAreOptional(string tfm)
{
// We split on TFMs because the generated code is slightly different depending on TFM.
this.compilation = this.starterCompilations[tfm].WithOptions(this.compilation.Options.WithPlatform(Platform.X64));
this.generator = this.CreateGenerator();
Assert.True(this.generator.TryGenerate("MiniDumpWriteDump", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();

MethodDeclarationSyntax externMethod = Assert.Single(this.FindGeneratedMethod("MiniDumpWriteDump"), m => !m.Modifiers.Any(SyntaxKind.ExternKeyword));
Assert.All(externMethod.ParameterList.Parameters.Reverse().Take(3), p => Assert.IsType<NullableTypeSyntax>(p.Type));
}

[Fact]
public void ContainsIllegalCharactersForAPIName_InvisibleCharacters()
{
Expand Down Expand Up @@ -3029,6 +3053,7 @@ private static class MyReferenceAssemblies
#pragma warning disable SA1202 // Elements should be ordered by access
private static readonly ImmutableArray<PackageIdentity> AdditionalPackages = ImmutableArray.Create(
new PackageIdentity("Microsoft.Windows.SDK.Contracts", "10.0.19041.1"),
new PackageIdentity("System.Memory", "4.5.4"),
new PackageIdentity("Microsoft.Win32.Registry", "5.0.0"));

internal static readonly ReferenceAssemblies NetStandard20 = ReferenceAssemblies.NetStandard.NetStandard20.AddPackages(AdditionalPackages.Add(new PackageIdentity("System.Memory", "4.5.4")));
Expand Down