Skip to content

Commit

Permalink
Merge pull request #579 from microsoft/fix578
Browse files Browse the repository at this point in the history
Make `[In, Optional]` managed struct parameters actually optional in friendly overloads
  • Loading branch information
AArnott committed Jun 5, 2022
2 parents 305ea3f + 9d451d3 commit 6ea3f44
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/ArrayTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs
return new TypeSyntaxAndMarshaling(PointerType(element.Type));
}
}

internal override bool? IsValueType(TypeSyntaxSettings inputs) => false;
}
56 changes: 48 additions & 8 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ public class Generator : IDisposable
private readonly CSharpParseOptions? parseOptions;
private readonly bool canUseSpan;
private readonly bool canCallCreateSpan;
private readonly bool canUseUnsafeAsRef;
private readonly bool canUseUnsafeNullRef;
private readonly bool getDelegateForFunctionPointerGenericExists;
private readonly bool generateSupportedOSPlatformAttributes;
private readonly bool generateSupportedOSPlatformAttributesOnInterfaces; // only supported on net6.0 (https://github.com/dotnet/runtime/pull/48838)
Expand Down Expand Up @@ -367,6 +369,8 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option

this.canUseSpan = this.compilation?.GetTypeByMetadataName(typeof(Span<>).FullName) is not null;
this.canCallCreateSpan = this.compilation?.GetTypeByMetadataName(typeof(MemoryMarshal).FullName)?.GetMembers("CreateSpan").Any() is true;
this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true;
this.canUseUnsafeNullRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("NullRef").Any() is true;
this.getDelegateForFunctionPointerGenericExists = this.compilation?.GetTypeByMetadataName(typeof(Marshal).FullName)?.GetMembers(nameof(Marshal.GetDelegateForFunctionPointer)).Any(m => m is IMethodSymbol { IsGenericMethod: true }) is true;
this.generateDefaultDllImportSearchPathsAttribute = this.compilation?.GetTypeByMetadataName(typeof(DefaultDllImportSearchPathsAttribute).FullName) is object;
if (this.compilation?.GetTypeByMetadataName("System.Runtime.Versioning.SupportedOSPlatformAttribute") is { } attribute
Expand Down Expand Up @@ -4315,6 +4319,14 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
static ExpressionSyntax GetSpanLength(ExpressionSyntax span) => MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, span, IdentifierName(nameof(Span<int>.Length)));
bool isReleaseMethod = this.MetadataIndex.ReleaseMethods.Contains(externMethodDeclaration.Identifier.ValueText);

TypeSyntaxSettings parameterTypeSyntaxSettings = overloadOf switch
{
FriendlyOverloadOf.ExternMethod => this.externSignatureTypeSettings,
FriendlyOverloadOf.StructMethod => this.extensionMethodSignatureTypeSettings,
FriendlyOverloadOf.InterfaceMethod => this.extensionMethodSignatureTypeSettings,
_ => throw new NotSupportedException(overloadOf.ToString()),
};

MethodSignature<TypeHandleInfo> originalSignature = methodDefinition.DecodeSignature(SignatureHandleProvider.Instance, null);
var parameters = externMethodDeclaration.ParameterList.Parameters.Select(StripAttributes).ToList();
var lengthParamUsedBy = new Dictionary<int, int>();
Expand Down Expand Up @@ -4350,7 +4362,10 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}

TypeHandleInfo parameterTypeInfo = originalSignature.ParameterTypes[param.SequenceNumber - 1];
if (this.IsManagedType(parameterTypeInfo) && (externParam.Modifiers.Any(SyntaxKind.OutKeyword) || externParam.Modifiers.Any(SyntaxKind.RefKeyword)))
bool isManagedParameterType = this.IsManagedType(parameterTypeInfo);
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);

if (isManagedParameterType && (externParam.Modifiers.Any(SyntaxKind.OutKeyword) || externParam.Modifiers.Any(SyntaxKind.RefKeyword)))
{
bool hasOut = externParam.Modifiers.Any(SyntaxKind.OutKeyword);
arguments[param.SequenceNumber - 1] = arguments[param.SequenceNumber - 1].WithRefKindKeyword(TokenWithSpace(hasOut ? SyntaxKind.OutKeyword : SyntaxKind.RefKeyword));
Expand All @@ -4361,7 +4376,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
{
signatureChanged = true;

IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax typeDefHandleName = IdentifierName(externParam.Identifier.ValueText + "Local");

// out SafeHandle
Expand All @@ -4370,7 +4384,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
.WithModifiers(TokenList(TokenWithSpace(SyntaxKind.OutKeyword)));

// HANDLE SomeLocal;
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(pointedElementInfo.ToTypeSyntax(this.externSignatureTypeSettings, null).Type).AddVariables(
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, null).Type).AddVariables(
VariableDeclarator(typeDefHandleName.Identifier))));

// Argument: &SomeLocal
Expand All @@ -4387,7 +4401,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}
else if (isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod))
{
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax typeDefHandleName = IdentifierName(externParam.Identifier.ValueText + "Local");
signatureChanged = true;

Expand Down Expand Up @@ -4480,7 +4493,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}
}

IdentifierNameSyntax origName = IdentifierName(parameters[param.SequenceNumber - 1].Identifier.ValueText);
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
if (isArray)
{
Expand Down Expand Up @@ -4612,7 +4624,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}
else if (isIn && !isOut && isConst && externParam.Type is QualifiedNameSyntax { Right: { Identifier: { ValueText: "PCWSTR" } } })
{
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
signatureChanged = true;
parameters[param.SequenceNumber - 1] = externParam
Expand All @@ -4623,7 +4634,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}
else if (isIn && !isOut && isConst && externParam.Type is QualifiedNameSyntax { Right: { Identifier: { ValueText: "PCSTR" } } })
{
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
signatureChanged = true;
parameters[param.SequenceNumber - 1] = externParam
Expand All @@ -4649,7 +4659,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}
else if (isIn && isOut && this.canUseSpan && externParam.Type is QualifiedNameSyntax { Right: { Identifier: { ValueText: "PWSTR" } } })
{
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax localName = IdentifierName("p" + origName);
IdentifierNameSyntax localWstrName = IdentifierName("wstr" + origName);
signatureChanged = true;
Expand Down Expand Up @@ -4689,6 +4698,37 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))),
Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, localWstrName, IdentifierName("Length"))))))));
}
else if (isIn && isOptional && !isOut && isManagedParameterType && parameterTypeInfo is PointerTypeHandleInfo ptrInfo && ptrInfo.ElementType.IsValueType(parameterTypeSyntaxSettings) is true && this.canUseUnsafeAsRef)
{
// The extern method couldn't have exposed the parameter as a pointer because the type is managed.
// It would have exposed as an `in` modifier, and non-optional. But we can expose as optional anyway.
signatureChanged = true;
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
parameters[param.SequenceNumber - 1] = parameters[param.SequenceNumber - 1]
.WithType(NullableType(externParam.Type).WithTrailingTrivia(TriviaList(Space)))
.WithModifiers(TokenList()); // drop the `in` modifier.
leadingStatements.Add(
LocalDeclarationStatement(VariableDeclaration(externParam.Type)
.AddVariables(VariableDeclarator(localName.Identifier).WithInitializer(
EqualsValueClause(ConditionalExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName("HasValue")),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName("Value")),
DefaultExpression(externParam.Type)))))));

// We can't pass in null, but we can be fancy to achieve the same effect.
// Unsafe.NullRef<TParamType>() or Unsafe.AsRef<TParamType>(null), depending on what's available.
ExpressionSyntax nullRef = this.canUseUnsafeNullRef
? InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(Unsafe)), GenericName("NullRef", TypeArgumentList().AddArguments(externParam.Type))),
ArgumentList())
: InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(Unsafe)), GenericName(nameof(Unsafe.AsRef), TypeArgumentList().AddArguments(externParam.Type))),
ArgumentList().AddArguments(Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))));
arguments[param.SequenceNumber - 1] = Argument(ConditionalExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName("HasValue")),
localName,
nullRef));
}
}

TypeSyntax? returnSafeHandleType = originalSignature.ReturnType is HandleTypeHandleInfo returnTypeHandleInfo
Expand Down
47 changes: 47 additions & 0 deletions src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,53 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs
return new TypeSyntaxAndMarshaling(syntax);
}

internal override bool? IsValueType(TypeSyntaxSettings inputs)
{
Generator generator = inputs.Generator ?? throw new ArgumentException("Generator required.");
TypeDefinitionHandle typeDefHandle = default;
switch (this.Handle.Kind)
{
case HandleKind.TypeDefinition:
typeDefHandle = (TypeDefinitionHandle)this.Handle;
break;
case HandleKind.TypeReference:
if (generator.TryGetTypeDefHandle((TypeReferenceHandle)this.Handle, out QualifiedTypeDefinitionHandle qualifiedTypeDefHandle))
{
generator = qualifiedTypeDefHandle.Generator;
typeDefHandle = qualifiedTypeDefHandle.DefinitionHandle;
}

break;
default:
return null;
}

if (typeDefHandle.IsNil)
{
return null;
}

TypeDefinition typeDef = generator.Reader.GetTypeDefinition(typeDefHandle);
generator.GetBaseTypeInfo(typeDef, out StringHandle baseName, out StringHandle baseNamespace);
if (generator.Reader.StringComparer.Equals(baseName, nameof(ValueType)) && generator.Reader.StringComparer.Equals(baseNamespace, nameof(System)))
{
// When marshaling, the VARIANT struct becomes object, which is *not* a value type.
if (inputs.AllowMarshaling && generator.Reader.StringComparer.Equals(typeDef.Name, "VARIANT"))
{
return false;
}

return true;
}

if (generator.Reader.StringComparer.Equals(baseName, nameof(Enum)) && generator.Reader.StringComparer.Equals(baseNamespace, nameof(System)))
{
return true;
}

return false;
}

private static bool TryMarshalAsObject(TypeSyntaxSettings inputs, string name, [NotNullWhen(true)] out MarshalAsAttribute? marshalAs)
{
if (inputs.AllowMarshaling)
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/PointerTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs
return new TypeSyntaxAndMarshaling(PointerType(elementTypeDetails.Type));
}

internal override bool? IsValueType(TypeSyntaxSettings inputs) => false;

private bool TryGetElementTypeDefinition(Generator generator, out TypeDefinition typeDef)
{
if (this.ElementType is HandleTypeHandleInfo handleElement)
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.Windows.CsWin32/PrimitiveTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ internal record PrimitiveTypeHandleInfo(PrimitiveTypeCode PrimitiveTypeCode) : T
internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs, CustomAttributeHandleCollection? customAttributes, ParameterAttributes parameterAttributes)
=> new TypeSyntaxAndMarshaling(ToTypeSyntax(this.PrimitiveTypeCode, inputs.PreferNativeInt));

internal override bool? IsValueType(TypeSyntaxSettings inputs)
{
return this.PrimitiveTypeCode is not PrimitiveTypeCode.Object or PrimitiveTypeCode.Void;
}

internal static TypeSyntax ToTypeSyntax(PrimitiveTypeCode typeCode, bool preferNativeInt)
{
return typeCode switch
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/TypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ internal abstract record TypeHandleInfo

internal abstract TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs, CustomAttributeHandleCollection? customAttributes, ParameterAttributes parameterAttributes = default);

internal abstract bool? IsValueType(TypeSyntaxSettings inputs);

protected static bool TryGetSimpleName(TypeSyntax nameSyntax, [NotNullWhen(true)] out string? simpleName)
{
if (nameSyntax is QualifiedNameSyntax qname)
Expand Down
25 changes: 25 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ public GeneratorTests(ITestOutputHelper logger)
new object[] { "net6.0" },
};

public static IEnumerable<object[]> TFMDataNoNetFx35 =>
new object[][]
{
new object[] { "net472" },
new object[] { "netstandard2.0" },
new object[] { "net6.0" },
};

public static Platform[] SpecificCpuArchitectures =>
new Platform[]
{
Expand Down Expand Up @@ -254,6 +262,7 @@ public void COMInterfaceWithSupportedOSPlatform(bool net60, bool allowMarshaling
"CertFreeCertificateChainList", // double pointer extern method
"D3DGetTraceInstructionOffsets", // SizeParamIndex
"PlgBlt", // SizeConst
"IWebBrowser", // Navigate method has an [In, Optional] object parameter
"ENABLE_TRACE_PARAMETERS_V1", // bad xml created at some point.
"JsRuntimeVersion", // An enum that has an extra member in a separate header file.
"ReportEvent", // Failed at one point
Expand Down Expand Up @@ -2723,6 +2732,21 @@ public void OpensMetadataForSharedReading()
Assert.True(this.generator.TryGenerate("CreateFile", CancellationToken.None));
}

[Theory]
[MemberData(nameof(TFMDataNoNetFx35))]
public void MiniDumpWriteDump_AllOptionalPointerParametersAreOptional(string tfm)
{
// We split on TFMs because the generated code is slightly different depending on TFM.
this.compilation = this.starterCompilations[tfm].WithOptions(this.compilation.Options.WithPlatform(Platform.X64));
this.generator = this.CreateGenerator();
Assert.True(this.generator.TryGenerate("MiniDumpWriteDump", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();

MethodDeclarationSyntax externMethod = Assert.Single(this.FindGeneratedMethod("MiniDumpWriteDump"), m => !m.Modifiers.Any(SyntaxKind.ExternKeyword));
Assert.All(externMethod.ParameterList.Parameters.Reverse().Take(3), p => Assert.IsType<NullableTypeSyntax>(p.Type));
}

[Fact]
public void ContainsIllegalCharactersForAPIName_InvisibleCharacters()
{
Expand Down Expand Up @@ -3029,6 +3053,7 @@ private static class MyReferenceAssemblies
#pragma warning disable SA1202 // Elements should be ordered by access
private static readonly ImmutableArray<PackageIdentity> AdditionalPackages = ImmutableArray.Create(
new PackageIdentity("Microsoft.Windows.SDK.Contracts", "10.0.19041.1"),
new PackageIdentity("System.Memory", "4.5.4"),
new PackageIdentity("Microsoft.Win32.Registry", "5.0.0"));

internal static readonly ReferenceAssemblies NetStandard20 = ReferenceAssemblies.NetStandard.NetStandard20.AddPackages(AdditionalPackages.Add(new PackageIdentity("System.Memory", "4.5.4")));
Expand Down

0 comments on commit 6ea3f44

Please sign in to comment.