Skip to content

Commit

Permalink
Merge pull request #550 from microsoft/fix340
Browse files Browse the repository at this point in the history
Generate constants into their typedef structs wherever possible
  • Loading branch information
AArnott committed May 16, 2022
2 parents 14e48cf + ee42a7c commit d632375
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 69 deletions.
138 changes: 69 additions & 69 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ public class Generator : IDisposable
"PCSTR",
};

private static readonly HashSet<string> TypeDefsThatDoNotNestTheirConstants = new HashSet<string>(SpecialTypeDefNames, StringComparer.Ordinal)
{
"PWSTR",
};

/// <summary>
/// This is the preferred capitalizations for modules and class names.
/// If they are not in this list, the capitalization will come from the metadata assembly.
Expand Down Expand Up @@ -293,8 +298,6 @@ public class Generator : IDisposable
private readonly TypeSyntaxSettings functionPointerTypeSettings;
private readonly TypeSyntaxSettings errorMessageTypeSettings;

private readonly Dictionary<TypeReferenceHandle, TypeDefinitionHandle> refToDefCache = new();

private readonly GeneratorOptions options;
private readonly CSharpCompilation? compilation;
private readonly CSharpParseOptions? parseOptions;
Expand Down Expand Up @@ -427,7 +430,7 @@ select ClassDeclaration(Identifier(this.options.ClassName))
result = result.Concat(new MemberDeclarationSyntax[] { comInterfaceFriendlyExtensionsClass });
}

if (this.committedCode.Fields.Any())
if (this.committedCode.TopLevelFields.Any())
{
result = result.Concat(new MemberDeclarationSyntax[] { this.DeclareConstantDefiningClass() });
}
Expand Down Expand Up @@ -1445,9 +1448,22 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
{
this.volatileCode.GenerateConstant(fieldDefHandle, delegate
{
FieldDeclarationSyntax constantDeclaration = this.DeclareConstant(fieldDefHandle);
constantDeclaration = this.AddApiDocumentation(constantDeclaration.Declaration.Variables[0].Identifier.ValueText, constantDeclaration);
this.volatileCode.AddConstant(fieldDefHandle, constantDeclaration);
FieldDefinition fieldDef = this.Reader.GetFieldDefinition(fieldDefHandle);
FieldDeclarationSyntax constantDeclaration = this.DeclareConstant(fieldDef);

TypeHandleInfo fieldTypeInfo = fieldDef.DecodeSignature<TypeHandleInfo, SignatureHandleProvider.IGenericContext?>(SignatureHandleProvider.Instance, null) with { IsConstantField = true };
TypeDefinitionHandle? fieldType = null;
if (fieldTypeInfo is HandleTypeHandleInfo handleInfo && this.IsTypeDefStruct(handleInfo) && handleInfo.Handle.Kind == HandleKind.TypeReference)
{
TypeReference tr = this.Reader.GetTypeReference((TypeReferenceHandle)handleInfo.Handle);
string fieldTypeName = this.Reader.GetString(tr.Name);
if (!TypeDefsThatDoNotNestTheirConstants.Contains(fieldTypeName) && this.TryGetTypeDefHandle(tr, out TypeDefinitionHandle candidate))
{
fieldType = candidate;
}
}

this.volatileCode.AddConstant(fieldDefHandle, constantDeclaration, fieldType);
});
}

Expand Down Expand Up @@ -1494,7 +1510,7 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
return safeHandleType;
}

if (this.FindSymbolIfAlreadyAvailable($"{this.Namespace}.{safeHandleType}") is object)
if (this.FindTypeSymbolIfAlreadyAvailable($"{this.Namespace}.{safeHandleType}") is object)
{
return safeHandleType;
}
Expand Down Expand Up @@ -1607,12 +1623,12 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
{
case "NTSTATUS":
this.TryGenerateConstantOrThrow("STATUS_SUCCESS");
ExpressionSyntax statusSuccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.methodsAndConstantsClassName, IdentifierName("STATUS_SUCCESS"));
ExpressionSyntax statusSuccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseName("winmdroot.Foundation.NTSTATUS"), IdentifierName("STATUS_SUCCESS"));
releaseInvocation = BinaryExpression(SyntaxKind.EqualsExpression, releaseInvocation, statusSuccess);
break;
case "HRESULT":
this.TryGenerateConstantOrThrow("S_OK");
ExpressionSyntax ok = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.methodsAndConstantsClassName, IdentifierName("S_OK"));
ExpressionSyntax ok = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseName("winmdroot.Foundation.HRESULT"), IdentifierName("S_OK"));
releaseInvocation = BinaryExpression(SyntaxKind.EqualsExpression, releaseInvocation, ok);
break;
default:
Expand Down Expand Up @@ -1717,7 +1733,7 @@ internal void GetBaseTypeInfo(TypeDefinition typeDef, out StringHandle baseTypeN
fullyQualifiedName = $"{ns}.{specialName}";

// Skip if the compilation already defines this type or can access it from elsewhere.
if (this.FindSymbolIfAlreadyAvailable(fullyQualifiedName) is object)
if (this.FindTypeSymbolIfAlreadyAvailable(fullyQualifiedName) is object)
{
// The type already exists either in this project or a referenced one.
return null;
Expand Down Expand Up @@ -1816,7 +1832,7 @@ internal bool TryGetTypeDefHandle(TypeReferenceHandle typeRefHandle, out Qualifi
return this.SuperGenerator.TryGetTypeDefinitionHandle(new QualifiedTypeReferenceHandle(this, typeRefHandle), out typeDefHandle);
}

if (this.TryGetTypeDefHandle(typeRefHandle, out TypeDefinitionHandle localTypeDefHandle))
if (this.MetadataIndex.TryGetTypeDefHandle(typeRefHandle, out TypeDefinitionHandle localTypeDefHandle))
{
typeDefHandle = new QualifiedTypeDefinitionHandle(this, localTypeDefHandle);
return true;
Expand All @@ -1826,52 +1842,7 @@ internal bool TryGetTypeDefHandle(TypeReferenceHandle typeRefHandle, out Qualifi
return false;
}

/// <summary>
/// Attempts to translate a <see cref="TypeReferenceHandle"/> to a <see cref="TypeDefinitionHandle"/>.
/// </summary>
/// <param name="typeRefHandle">The reference handle.</param>
/// <param name="typeDefHandle">Receives the type def handle, if one was discovered.</param>
/// <returns><see langword="true"/> if a TypeDefinition was found; otherwise <see langword="false"/>.</returns>
internal bool TryGetTypeDefHandle(TypeReferenceHandle typeRefHandle, out TypeDefinitionHandle typeDefHandle)
{
if (this.refToDefCache.TryGetValue(typeRefHandle, out typeDefHandle))
{
return !typeDefHandle.IsNil;
}

TypeReference typeRef = this.Reader.GetTypeReference(typeRefHandle);
if (typeRef.ResolutionScope.Kind != HandleKind.AssemblyReference)
{
foreach (TypeDefinitionHandle tdh in this.Reader.TypeDefinitions)
{
TypeDefinition typeDef = this.Reader.GetTypeDefinition(tdh);
if (typeDef.Name == typeRef.Name && typeDef.Namespace == typeRef.Namespace)
{
if (typeRef.ResolutionScope.Kind == HandleKind.TypeReference)
{
// The ref is nested. Verify that the type we found is nested in the same type as well.
if (this.TryGetTypeDefHandle((TypeReferenceHandle)typeRef.ResolutionScope, out TypeDefinitionHandle nestingTypeDef) && nestingTypeDef == typeDef.GetDeclaringType())
{
typeDefHandle = tdh;
break;
}
}
else if (typeRef.ResolutionScope.Kind == HandleKind.ModuleDefinition && typeDef.GetDeclaringType().IsNil)
{
typeDefHandle = tdh;
break;
}
else
{
throw new NotSupportedException("Unrecognized ResolutionScope: " + typeRef.ResolutionScope);
}
}
}
}

this.refToDefCache.Add(typeRefHandle, typeDefHandle);
return !typeDefHandle.IsNil;
}
internal bool TryGetTypeDefHandle(TypeReferenceHandle typeRefHandle, out TypeDefinitionHandle typeDefHandle) => this.MetadataIndex.TryGetTypeDefHandle(typeRefHandle, out typeDefHandle);

internal bool TryGetTypeDefHandle(TypeReference typeRef, out TypeDefinitionHandle typeDefHandle) => this.TryGetTypeDefHandle(typeRef.Namespace, typeRef.Name, out typeDefHandle);

Expand Down Expand Up @@ -2662,7 +2633,7 @@ private bool HasObsoleteAttribute(CustomAttributeHandleCollection attributes)
return false;
}

private ISymbol? FindSymbolIfAlreadyAvailable(string fullyQualifiedMetadataName)
private ISymbol? FindTypeSymbolIfAlreadyAvailable(string fullyQualifiedMetadataName)
{
if (this.compilation is object)
{
Expand Down Expand Up @@ -2706,7 +2677,7 @@ private bool HasObsoleteAttribute(CustomAttributeHandleCollection attributes)
string name = this.Reader.GetString(typeDef.Name);
string ns = this.Reader.GetString(typeDef.Namespace);
string fullyQualifiedName = ns + "." + name;
if (this.FindSymbolIfAlreadyAvailable(fullyQualifiedName) is object)
if (this.FindTypeSymbolIfAlreadyAvailable(fullyQualifiedName) is object)
{
// The type already exists either in this project or a referenced one.
return null;
Expand All @@ -2728,7 +2699,7 @@ private bool HasObsoleteAttribute(CustomAttributeHandleCollection attributes)
// Is this a special typedef struct?
if (this.IsTypeDefStruct(typeDef))
{
typeDeclaration = this.DeclareTypeDefStruct(typeDef);
typeDeclaration = this.DeclareTypeDefStruct(typeDef, typeDefHandle);
}
else if (this.IsEmptyStructWithGuid(typeDef))
{
Expand Down Expand Up @@ -2970,9 +2941,8 @@ private void TryGenerateConstantOrThrow(string possiblyQualifiedName)
}
}

private FieldDeclarationSyntax DeclareConstant(FieldDefinitionHandle fieldDefHandle)
private FieldDeclarationSyntax DeclareConstant(FieldDefinition fieldDef)
{
FieldDefinition fieldDef = this.Reader.GetFieldDefinition(fieldDefHandle);
string name = this.Reader.GetString(fieldDef.Name);
try
{
Expand Down Expand Up @@ -3027,6 +2997,8 @@ private FieldDeclarationSyntax DeclareConstant(FieldDefinitionHandle fieldDefHan
VariableDeclarator(Identifier(name)).WithInitializer(EqualsValueClause(value))))
.WithModifiers(modifiers);
result = fieldType.AddMarshalAs(result);
result = this.AddApiDocumentation(result.Declaration.Variables[0].Identifier.ValueText, result);

return result;
}
catch (Exception ex)
Expand All @@ -3041,7 +3013,7 @@ private FieldDeclarationSyntax DeclareConstant(FieldDefinitionHandle fieldDefHan
private ClassDeclarationSyntax DeclareConstantDefiningClass()
{
return ClassDeclaration(this.methodsAndConstantsClassName.Identifier)
.AddMembers(this.committedCode.Fields.ToArray())
.AddMembers(this.committedCode.TopLevelFields.ToArray())
.WithModifiers(TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword)));
}

Expand Down Expand Up @@ -3611,7 +3583,7 @@ private ClassDeclarationSyntax DeclareCocreatableClass(TypeDefinition typeDef)
/// <summary>
/// Creates a struct that emulates a typedef in the C language headers.
/// </summary>
private StructDeclarationSyntax DeclareTypeDefStruct(TypeDefinition typeDef)
private StructDeclarationSyntax DeclareTypeDefStruct(TypeDefinition typeDef, TypeDefinitionHandle typeDefHandle)
{
IdentifierNameSyntax name = IdentifierName(this.Reader.GetString(typeDef.Name));
if (name.Identifier.ValueText == "BOOL")
Expand Down Expand Up @@ -5586,7 +5558,7 @@ private class GeneratedCode
/// </summary>
private readonly Dictionary<TypeDefinitionHandle, MemberDeclarationSyntax> types = new();

private readonly Dictionary<FieldDefinitionHandle, FieldDeclarationSyntax> fieldsToSyntax = new();
private readonly Dictionary<FieldDefinitionHandle, (FieldDeclarationSyntax FieldDeclaration, TypeDefinitionHandle? FieldType)> fieldsToSyntax = new();

private readonly List<ClassDeclarationSyntax> safeHandleTypes = new();

Expand Down Expand Up @@ -5624,13 +5596,15 @@ internal GeneratedCode(GeneratedCode parent)
this.parent = parent;
}

internal IEnumerable<MemberDeclarationSyntax> GeneratedTypes => this.types.Values.Concat(this.specialTypes.Values).Concat(this.safeHandleTypes);
internal IEnumerable<MemberDeclarationSyntax> GeneratedTypes => this.GetTypesWithInjectedFields().Concat(this.specialTypes.Values).Concat(this.safeHandleTypes);

internal IEnumerable<MethodDeclarationSyntax> ComInterfaceExtensions => this.comInterfaceFriendlyExtensionsMembers;

internal IEnumerable<MethodDeclarationSyntax> InlineArrayIndexerExtensions => this.inlineArrayIndexerExtensionsMembers;

internal IEnumerable<FieldDeclarationSyntax> Fields => this.fieldsToSyntax.Values;
internal IEnumerable<FieldDeclarationSyntax> TopLevelFields => from field in this.fieldsToSyntax.Values
where field.FieldType is null || !this.types.ContainsKey(field.FieldType.Value)
select field.FieldDeclaration;

internal IEnumerable<IGrouping<string, MemberDeclarationSyntax>> MembersByModule
{
Expand Down Expand Up @@ -5674,10 +5648,10 @@ internal void AddMemberToModule(string moduleName, IEnumerable<MemberDeclaration
methodsList.AddRange(members);
}

internal void AddConstant(FieldDefinitionHandle fieldDefHandle, FieldDeclarationSyntax constantDeclaration)
internal void AddConstant(FieldDefinitionHandle fieldDefHandle, FieldDeclarationSyntax constantDeclaration, TypeDefinitionHandle? fieldType)
{
this.ThrowIfNotGenerating();
this.fieldsToSyntax.Add(fieldDefHandle, constantDeclaration);
this.fieldsToSyntax.Add(fieldDefHandle, (constantDeclaration, fieldType));
}

internal void AddInlineArrayIndexerExtension(MethodDeclarationSyntax inlineIndexer)
Expand Down Expand Up @@ -5901,6 +5875,32 @@ private void Commit(GeneratedCode? parent)
Commit(this.comInterfaceFriendlyExtensionsMembers, parent?.comInterfaceFriendlyExtensionsMembers);
}

private IEnumerable<MemberDeclarationSyntax> GetTypesWithInjectedFields()
{
var fieldsByType =
(from field in this.fieldsToSyntax
where field.Value.FieldType is not null
group field.Value.FieldDeclaration by field.Value.FieldType into typeGroup
select typeGroup).ToDictionary(k => k.Key!, k => k.ToArray());
foreach (KeyValuePair<TypeDefinitionHandle, MemberDeclarationSyntax> pair in this.types)
{
MemberDeclarationSyntax type = pair.Value;
if (fieldsByType.TryGetValue(pair.Key, out var extraFields))
{
switch (type)
{
case StructDeclarationSyntax structType:
type = structType.AddMembers(extraFields);
break;
default:
throw new NotSupportedException();
}
}

yield return type;
}
}

private void ThrowIfNotGenerating()
{
if (!this.generating)
Expand Down
49 changes: 49 additions & 0 deletions src/Microsoft.Windows.CsWin32/MetadataIndex.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ internal class MetadataIndex : IDisposable

private readonly HashSet<string> releaseMethods = new HashSet<string>(StringComparer.Ordinal);

private readonly Dictionary<TypeReferenceHandle, TypeDefinitionHandle> refToDefCache = new();

/// <summary>
/// The set of names of typedef structs that represent handles where the handle has length of <see cref="IntPtr"/>
/// and is therefore appropriate to wrap in a <see cref="SafeHandle"/>.
Expand Down Expand Up @@ -274,6 +276,53 @@ internal static void Return(MetadataIndex index)
}
}

/// <summary>
/// Attempts to translate a <see cref="TypeReferenceHandle"/> to a <see cref="TypeDefinitionHandle"/>.
/// </summary>
/// <param name="typeRefHandle">The reference handle.</param>
/// <param name="typeDefHandle">Receives the type def handle, if one was discovered.</param>
/// <returns><see langword="true"/> if a TypeDefinition was found; otherwise <see langword="false"/>.</returns>
internal bool TryGetTypeDefHandle(TypeReferenceHandle typeRefHandle, out TypeDefinitionHandle typeDefHandle)
{
if (this.refToDefCache.TryGetValue(typeRefHandle, out typeDefHandle))
{
return !typeDefHandle.IsNil;
}

TypeReference typeRef = this.Reader.GetTypeReference(typeRefHandle);
if (typeRef.ResolutionScope.Kind != HandleKind.AssemblyReference)
{
foreach (TypeDefinitionHandle tdh in this.Reader.TypeDefinitions)
{
TypeDefinition typeDef = this.Reader.GetTypeDefinition(tdh);
if (typeDef.Name == typeRef.Name && typeDef.Namespace == typeRef.Namespace)
{
if (typeRef.ResolutionScope.Kind == HandleKind.TypeReference)
{
// The ref is nested. Verify that the type we found is nested in the same type as well.
if (this.TryGetTypeDefHandle((TypeReferenceHandle)typeRef.ResolutionScope, out TypeDefinitionHandle nestingTypeDef) && nestingTypeDef == typeDef.GetDeclaringType())
{
typeDefHandle = tdh;
break;
}
}
else if (typeRef.ResolutionScope.Kind == HandleKind.ModuleDefinition && typeDef.GetDeclaringType().IsNil)
{
typeDefHandle = tdh;
break;
}
else
{
throw new NotSupportedException("Unrecognized ResolutionScope: " + typeRef.ResolutionScope);
}
}
}
}

this.refToDefCache.Add(typeRefHandle, typeDefHandle);
return !typeDefHandle.IsNil;
}

private static string CommonPrefix(IReadOnlyList<string> ss)
{
if (ss.Count == 0)
Expand Down

0 comments on commit d632375

Please sign in to comment.