Skip to content

Commit

Permalink
Avoid generating bool as struct field
Browse files Browse the repository at this point in the history
Fixes #126
  • Loading branch information
AArnott committed Feb 18, 2021
1 parent 0ee8941 commit b934e2b
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 18 deletions.
32 changes: 18 additions & 14 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,16 @@ public class Generator : IDisposable
{
{ nameof(System.Runtime.InteropServices.ComTypes.FILETIME), ParseTypeName("System.Runtime.InteropServices.ComTypes.FILETIME") },
{ nameof(Guid), ParseTypeName("System.Guid") },
{ "BOOL", PredefinedType(Token(SyntaxKind.BoolKeyword)) },
{ "OLD_LARGE_INTEGER", PredefinedType(Token(SyntaxKind.LongKeyword)) },
{ "LARGE_INTEGER", PredefinedType(Token(SyntaxKind.LongKeyword)) },
{ "ULARGE_INTEGER", PredefinedType(Token(SyntaxKind.ULongKeyword)) },
};

internal static readonly Dictionary<string, TypeSyntax> AdditionalBclInteropStructsMarshaled = new Dictionary<string, TypeSyntax>(StringComparer.Ordinal)
{
{ "BOOL", PredefinedType(Token(SyntaxKind.BoolKeyword)) },
};

internal static readonly Dictionary<string, TypeSyntax> BclInteropSafeHandles = new Dictionary<string, TypeSyntax>(StringComparer.Ordinal)
{
{ "CloseHandle", ParseTypeName("Microsoft.Win32.SafeHandles.SafeFileHandle").WithAdditionalAnnotations(IsManagedTypeAnnotation, IsSafeHandleTypeAnnotation) },
Expand Down Expand Up @@ -219,8 +223,8 @@ public class Generator : IDisposable
private readonly MetadataReader mr;
private readonly SignatureTypeProvider signatureTypeProvider;
private readonly SignatureTypeProvider signatureTypeProviderAlwaysUseIntPtr;
private readonly SignatureTypeProvider signatureTypeProviderNoSafeHandles;
private readonly SignatureTypeProvider signatureTypeProviderNoSafeHandlesOrNint;
private readonly SignatureTypeProvider signatureTypeProviderNoMarshaledTypes;
private readonly SignatureTypeProvider signatureTypeProviderNoMarshaledTypesOrNint;
private readonly CustomAttributeTypeProvider customAttributeTypeProvider;
private readonly Dictionary<string, List<MemberDeclarationSyntax>> modulesAndMembers = new Dictionary<string, List<MemberDeclarationSyntax>>(StringComparer.OrdinalIgnoreCase);

Expand Down Expand Up @@ -281,10 +285,10 @@ public Generator(Stream metadataLibraryStream, GeneratorOptions? options = null,
this.peReader = new PEReader(this.metadataStream);
this.mr = this.peReader.GetMetadataReader();

this.signatureTypeProvider = new SignatureTypeProvider(this, preferNativeInt: this.LanguageVersion >= LanguageVersion.CSharp9, preferSafeHandles: true);
this.signatureTypeProviderAlwaysUseIntPtr = new SignatureTypeProvider(this, preferNativeInt: false, preferSafeHandles: true);
this.signatureTypeProviderNoSafeHandles = new SignatureTypeProvider(this, preferNativeInt: true, preferSafeHandles: false);
this.signatureTypeProviderNoSafeHandlesOrNint = new SignatureTypeProvider(this, preferNativeInt: false, preferSafeHandles: false);
this.signatureTypeProvider = new SignatureTypeProvider(this, preferNativeInt: this.LanguageVersion >= LanguageVersion.CSharp9, preferMarshaledTypes: true);
this.signatureTypeProviderAlwaysUseIntPtr = new SignatureTypeProvider(this, preferNativeInt: false, preferMarshaledTypes: true);
this.signatureTypeProviderNoMarshaledTypes = new SignatureTypeProvider(this, preferNativeInt: true, preferMarshaledTypes: false);
this.signatureTypeProviderNoMarshaledTypesOrNint = new SignatureTypeProvider(this, preferNativeInt: false, preferMarshaledTypes: false);
this.customAttributeTypeProvider = new CustomAttributeTypeProvider();

this.Apis = this.mr.TypeDefinitions.Select(this.mr.GetTypeDefinition).Where(td => this.mr.StringComparer.Equals(td.Name, "Apis")).ToList();
Expand Down Expand Up @@ -722,7 +726,7 @@ internal void GenerateExternMethod(MethodDefinitionHandle methodDefinitionHandle
}

// If this method releases a handle, recreate the method signature such that we take the struct rather than the SafeHandle as a parameter.
var signatureTypeProvider = this.releaseMethods.Contains(entrypoint ?? methodName) ? this.signatureTypeProviderNoSafeHandlesOrNint : this.signatureTypeProvider;
var signatureTypeProvider = this.releaseMethods.Contains(entrypoint ?? methodName) ? this.signatureTypeProviderNoMarshaledTypesOrNint : this.signatureTypeProvider;
MethodSignature<TypeSyntax> signature = methodDefinition.DecodeSignature(signatureTypeProvider, null);

CustomAttributeHandleCollection? returnTypeAttributes = this.GetReturnTypeCustomAttributes(methodDefinition);
Expand Down Expand Up @@ -852,7 +856,7 @@ internal void GenerateConstant(FieldDefinitionHandle fieldDefHandle)
: safeHandleTypeIdentifier;
safeHandleType = safeHandleType.WithAdditionalAnnotations(IsManagedTypeAnnotation, IsSafeHandleTypeAnnotation);

var releaseMethodSignature = releaseMethodDef.DecodeSignature(this.signatureTypeProviderNoSafeHandlesOrNint, null);
var releaseMethodSignature = releaseMethodDef.DecodeSignature(this.signatureTypeProviderNoMarshaledTypesOrNint, null);

// If the release method takes more than one parameter, we can't generate a SafeHandle for it.
if (releaseMethodSignature.RequiredParameterCount != 1)
Expand Down Expand Up @@ -1654,7 +1658,7 @@ private FieldDeclarationSyntax CreateField(FieldDefinitionHandle fieldDefHandle)
string name = this.mr.GetString(fieldDef.Name);
try
{
TypeSyntax fieldType = fieldDef.DecodeSignature(this.signatureTypeProviderNoSafeHandles, null);
TypeSyntax fieldType = fieldDef.DecodeSignature(this.signatureTypeProviderNoMarshaledTypes, null);
Constant constant = this.mr.GetConstant(fieldDef.GetDefaultValue());
ExpressionSyntax value = this.ToExpressionSyntax(constant);
if (fieldType is not PredefinedTypeSyntax)
Expand Down Expand Up @@ -1758,7 +1762,7 @@ private ClassDeclarationSyntax CreateConstantDefiningClass()
string methodName = this.mr.GetString(methodDefinition.Name);
IdentifierNameSyntax innerMethodName = IdentifierName($"{methodName}_{methodCounter}");

MethodSignature<TypeSyntax> signature = methodDefinition.DecodeSignature(this.signatureTypeProviderNoSafeHandles, null);
MethodSignature<TypeSyntax> signature = methodDefinition.DecodeSignature(this.signatureTypeProvider, null);
CustomAttributeHandleCollection? returnTypeAttributes = this.GetReturnTypeCustomAttributes(methodDefinition);

TypeSyntax returnType = signature.ReturnType;
Expand Down Expand Up @@ -1948,7 +1952,7 @@ private StructDeclarationSyntax CreateInteropStruct(TypeDefinition typeDef)
}
else
{
var fieldInfo = this.ReinterpretFieldType(fieldDeclarator.Identifier.ValueText, fieldDef.DecodeSignature(this.signatureTypeProviderNoSafeHandles, null), fieldDef.GetCustomAttributes());
var fieldInfo = this.ReinterpretFieldType(fieldDeclarator.Identifier.ValueText, fieldDef.DecodeSignature(this.signatureTypeProviderNoMarshaledTypes, null), fieldDef.GetCustomAttributes());
if (fieldInfo.AdditionalMembers.Count > 0)
{
fieldDeclarator = fieldDeclarator.WithIdentifier(Identifier(GetHiddenFieldName(fieldDeclarator.Identifier.ValueText)));
Expand Down Expand Up @@ -2880,11 +2884,11 @@ private ParameterSyntax CreateParameter(MethodSignature<TypeSyntax> methodSignat
// If the field is a delegate type, we have to replace that with a native function pointer to avoid the struct becoming a 'managed type'.
if (originalType is PointerTypeSyntax { ElementType: IdentifierNameSyntax idName } && this.IsDelegateReference(idName, out TypeDefinition typeDef))
{
return (this.FunctionPointer(typeDef, this.signatureTypeProviderNoSafeHandles), default);
return (this.FunctionPointer(typeDef, this.signatureTypeProviderNoMarshaledTypes), default);
}
else if (originalType is IdentifierNameSyntax idName2 && this.IsDelegateReference(idName2, out typeDef))
{
return (this.FunctionPointer(typeDef, this.signatureTypeProviderNoSafeHandles), default);
return (this.FunctionPointer(typeDef, this.signatureTypeProviderNoMarshaledTypes), default);
}

return (originalType, default);
Expand Down
18 changes: 14 additions & 4 deletions src/Microsoft.Windows.CsWin32/SignatureTypeProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ internal class SignatureTypeProvider : ISignatureTypeProvider<TypeSyntax, IGener
{
private readonly Generator owner;
private readonly bool preferNativeInt;
private readonly bool preferSafeHandles;
private readonly bool preferMarshaledTypes;

internal SignatureTypeProvider(Generator owner, bool preferNativeInt, bool preferSafeHandles)
internal SignatureTypeProvider(Generator owner, bool preferNativeInt, bool preferMarshaledTypes)
{
this.owner = owner;
this.preferNativeInt = preferNativeInt;
this.preferSafeHandles = preferSafeHandles;
this.preferMarshaledTypes = preferMarshaledTypes;
}

/// <inheritdoc/>
Expand Down Expand Up @@ -67,6 +67,11 @@ public TypeSyntax GetTypeFromDefinition(MetadataReader reader, TypeDefinitionHan
return bclType;
}

if (this.preferMarshaledTypes && Generator.AdditionalBclInteropStructsMarshaled.TryGetValue(name, out bclType))
{
return bclType;
}

this.owner.GenerateInteropType(handle);
TypeSyntax identifier = IdentifierName(name);

Expand All @@ -90,10 +95,15 @@ public TypeSyntax GetTypeFromReference(MetadataReader reader, TypeReferenceHandl
return bclType;
}

if (this.preferMarshaledTypes && Generator.AdditionalBclInteropStructsMarshaled.TryGetValue(name, out bclType))
{
return bclType;
}

TypeDefinitionHandle? typeDefHandle = this.owner.GenerateInteropType(handle);
if (typeDefHandle.HasValue)
{
if (this.preferSafeHandles && this.owner.TryGetHandleReleaseMethod(name, out string? releaseMethod) && this.owner.GenerateSafeHandle(releaseMethod) is TypeSyntax safeHandleType)
if (this.preferMarshaledTypes && this.owner.TryGetHandleReleaseMethod(name, out string? releaseMethod) && this.owner.GenerateSafeHandle(releaseMethod) is TypeSyntax safeHandleType)
{
// Return the safe handle instead.
return safeHandleType;
Expand Down
27 changes: 27 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,33 @@ public void BOOL_ReturnTypeBecomes_Boolean()
Assert.Equal(SyntaxKind.BoolKeyword, Assert.IsType<PredefinedTypeSyntax>(createFileMethod!.ReturnType).Keyword.Kind());
}

[Fact]
public void BOOL_ReturnTypeBecomes_Boolean_InCOMInterface()
{
this.generator = new Generator(this.metadataStream, compilation: this.compilation, parseOptions: this.parseOptions);
Assert.True(this.generator.TryGenerate("ISpellCheckerFactory", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
MethodDeclarationSyntax? method = this.FindGeneratedMethod("IsSupported");
Assert.NotNull(method);
Assert.Equal(SyntaxKind.BoolKeyword, Assert.IsType<PredefinedTypeSyntax>(method!.ParameterList.Parameters.Last().Type).Keyword.Kind());
}

/// <summary>
/// Verifies that fields are not converted from BOOL to bool.
/// </summary>
[Fact]
public void BOOL_FieldRemainsBOOL()
{
this.generator = new Generator(this.metadataStream, compilation: this.compilation, parseOptions: this.parseOptions);
Assert.True(this.generator.TryGenerate("ICONINFO", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
var theStruct = (StructDeclarationSyntax?)this.FindGeneratedType("ICONINFO");
Assert.NotNull(theStruct);
Assert.Equal("BOOL", theStruct!.Members.OfType<FieldDeclarationSyntax>().Select(m => m.Declaration).Single(d => d.Variables.Any(v => v.Identifier.ValueText == "fIcon")).Type.ToString());
}

[Fact]
public void BSTR_FieldsDoNotBecomeSafeHandles()
{
Expand Down

0 comments on commit b934e2b

Please sign in to comment.