Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate constants into their typedef structs wherever possible #550

Merged
merged 3 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
138 changes: 69 additions & 69 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ public class Generator : IDisposable
"PCSTR",
};

private static readonly HashSet<string> TypeDefsThatDoNotNestTheirConstants = new HashSet<string>(SpecialTypeDefNames, StringComparer.Ordinal)
{
"PWSTR",
};

/// <summary>
/// This is the preferred capitalizations for modules and class names.
/// If they are not in this list, the capitalization will come from the metadata assembly.
Expand Down Expand Up @@ -293,8 +298,6 @@ public class Generator : IDisposable
private readonly TypeSyntaxSettings functionPointerTypeSettings;
private readonly TypeSyntaxSettings errorMessageTypeSettings;

private readonly Dictionary<TypeReferenceHandle, TypeDefinitionHandle> refToDefCache = new();

private readonly GeneratorOptions options;
private readonly CSharpCompilation? compilation;
private readonly CSharpParseOptions? parseOptions;
Expand Down Expand Up @@ -427,7 +430,7 @@ select ClassDeclaration(Identifier(this.options.ClassName))
result = result.Concat(new MemberDeclarationSyntax[] { comInterfaceFriendlyExtensionsClass });
}

if (this.committedCode.Fields.Any())
if (this.committedCode.TopLevelFields.Any())
{
result = result.Concat(new MemberDeclarationSyntax[] { this.DeclareConstantDefiningClass() });
}
Expand Down Expand Up @@ -1445,9 +1448,22 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
{
this.volatileCode.GenerateConstant(fieldDefHandle, delegate
{
FieldDeclarationSyntax constantDeclaration = this.DeclareConstant(fieldDefHandle);
constantDeclaration = this.AddApiDocumentation(constantDeclaration.Declaration.Variables[0].Identifier.ValueText, constantDeclaration);
this.volatileCode.AddConstant(fieldDefHandle, constantDeclaration);
FieldDefinition fieldDef = this.Reader.GetFieldDefinition(fieldDefHandle);
FieldDeclarationSyntax constantDeclaration = this.DeclareConstant(fieldDef);

TypeHandleInfo fieldTypeInfo = fieldDef.DecodeSignature<TypeHandleInfo, SignatureHandleProvider.IGenericContext?>(SignatureHandleProvider.Instance, null) with { IsConstantField = true };
TypeDefinitionHandle? fieldType = null;
if (fieldTypeInfo is HandleTypeHandleInfo handleInfo && this.IsTypeDefStruct(handleInfo) && handleInfo.Handle.Kind == HandleKind.TypeReference)
{
TypeReference tr = this.Reader.GetTypeReference((TypeReferenceHandle)handleInfo.Handle);
string fieldTypeName = this.Reader.GetString(tr.Name);
if (!TypeDefsThatDoNotNestTheirConstants.Contains(fieldTypeName) && this.TryGetTypeDefHandle(tr, out TypeDefinitionHandle candidate))
{
fieldType = candidate;
}
}

this.volatileCode.AddConstant(fieldDefHandle, constantDeclaration, fieldType);
});
}

Expand Down Expand Up @@ -1494,7 +1510,7 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
return safeHandleType;
}

if (this.FindSymbolIfAlreadyAvailable($"{this.Namespace}.{safeHandleType}") is object)
if (this.FindTypeSymbolIfAlreadyAvailable($"{this.Namespace}.{safeHandleType}") is object)
{
return safeHandleType;
}
Expand Down Expand Up @@ -1607,12 +1623,12 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
{
case "NTSTATUS":
this.TryGenerateConstantOrThrow("STATUS_SUCCESS");
ExpressionSyntax statusSuccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.methodsAndConstantsClassName, IdentifierName("STATUS_SUCCESS"));
ExpressionSyntax statusSuccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseName("winmdroot.Foundation.NTSTATUS"), IdentifierName("STATUS_SUCCESS"));
releaseInvocation = BinaryExpression(SyntaxKind.EqualsExpression, releaseInvocation, statusSuccess);
break;
case "HRESULT":
this.TryGenerateConstantOrThrow("S_OK");
ExpressionSyntax ok = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.methodsAndConstantsClassName, IdentifierName("S_OK"));
ExpressionSyntax ok = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseName("winmdroot.Foundation.HRESULT"), IdentifierName("S_OK"));
releaseInvocation = BinaryExpression(SyntaxKind.EqualsExpression, releaseInvocation, ok);
break;
default:
Expand Down Expand Up @@ -1717,7 +1733,7 @@ internal void GetBaseTypeInfo(TypeDefinition typeDef, out StringHandle baseTypeN
fullyQualifiedName = $"{ns}.{specialName}";

// Skip if the compilation already defines this type or can access it from elsewhere.
if (this.FindSymbolIfAlreadyAvailable(fullyQualifiedName) is object)
if (this.FindTypeSymbolIfAlreadyAvailable(fullyQualifiedName) is object)
{
// The type already exists either in this project or a referenced one.
return null;
Expand Down Expand Up @@ -1816,7 +1832,7 @@ internal bool TryGetTypeDefHandle(TypeReferenceHandle typeRefHandle, out Qualifi
return this.SuperGenerator.TryGetTypeDefinitionHandle(new QualifiedTypeReferenceHandle(this, typeRefHandle), out typeDefHandle);
}

if (this.TryGetTypeDefHandle(typeRefHandle, out TypeDefinitionHandle localTypeDefHandle))
if (this.MetadataIndex.TryGetTypeDefHandle(typeRefHandle, out TypeDefinitionHandle localTypeDefHandle))
{
typeDefHandle = new QualifiedTypeDefinitionHandle(this, localTypeDefHandle);
return true;
Expand All @@ -1826,52 +1842,7 @@ internal bool TryGetTypeDefHandle(TypeReferenceHandle typeRefHandle, out Qualifi
return false;
}

/// <summary>
/// Attempts to translate a <see cref="TypeReferenceHandle"/> to a <see cref="TypeDefinitionHandle"/>.
/// </summary>
/// <param name="typeRefHandle">The reference handle.</param>
/// <param name="typeDefHandle">Receives the type def handle, if one was discovered.</param>
/// <returns><see langword="true"/> if a TypeDefinition was found; otherwise <see langword="false"/>.</returns>
internal bool TryGetTypeDefHandle(TypeReferenceHandle typeRefHandle, out TypeDefinitionHandle typeDefHandle)
{
if (this.refToDefCache.TryGetValue(typeRefHandle, out typeDefHandle))
{
return !typeDefHandle.IsNil;
}

TypeReference typeRef = this.Reader.GetTypeReference(typeRefHandle);
if (typeRef.ResolutionScope.Kind != HandleKind.AssemblyReference)
{
foreach (TypeDefinitionHandle tdh in this.Reader.TypeDefinitions)
{
TypeDefinition typeDef = this.Reader.GetTypeDefinition(tdh);
if (typeDef.Name == typeRef.Name && typeDef.Namespace == typeRef.Namespace)
{
if (typeRef.ResolutionScope.Kind == HandleKind.TypeReference)
{
// The ref is nested. Verify that the type we found is nested in the same type as well.
if (this.TryGetTypeDefHandle((TypeReferenceHandle)typeRef.ResolutionScope, out TypeDefinitionHandle nestingTypeDef) && nestingTypeDef == typeDef.GetDeclaringType())
{
typeDefHandle = tdh;
break;
}
}
else if (typeRef.ResolutionScope.Kind == HandleKind.ModuleDefinition && typeDef.GetDeclaringType().IsNil)
{
typeDefHandle = tdh;
break;
}
else
{
throw new NotSupportedException("Unrecognized ResolutionScope: " + typeRef.ResolutionScope);
}
}
}
}

this.refToDefCache.Add(typeRefHandle, typeDefHandle);
return !typeDefHandle.IsNil;
}
internal bool TryGetTypeDefHandle(TypeReferenceHandle typeRefHandle, out TypeDefinitionHandle typeDefHandle) => this.MetadataIndex.TryGetTypeDefHandle(typeRefHandle, out typeDefHandle);

internal bool TryGetTypeDefHandle(TypeReference typeRef, out TypeDefinitionHandle typeDefHandle) => this.TryGetTypeDefHandle(typeRef.Namespace, typeRef.Name, out typeDefHandle);

Expand Down Expand Up @@ -2662,7 +2633,7 @@ private bool HasObsoleteAttribute(CustomAttributeHandleCollection attributes)
return false;
}

private ISymbol? FindSymbolIfAlreadyAvailable(string fullyQualifiedMetadataName)
private ISymbol? FindTypeSymbolIfAlreadyAvailable(string fullyQualifiedMetadataName)
{
if (this.compilation is object)
{
Expand Down Expand Up @@ -2706,7 +2677,7 @@ private bool HasObsoleteAttribute(CustomAttributeHandleCollection attributes)
string name = this.Reader.GetString(typeDef.Name);
string ns = this.Reader.GetString(typeDef.Namespace);
string fullyQualifiedName = ns + "." + name;
if (this.FindSymbolIfAlreadyAvailable(fullyQualifiedName) is object)
if (this.FindTypeSymbolIfAlreadyAvailable(fullyQualifiedName) is object)
{
// The type already exists either in this project or a referenced one.
return null;
Expand All @@ -2728,7 +2699,7 @@ private bool HasObsoleteAttribute(CustomAttributeHandleCollection attributes)
// Is this a special typedef struct?
if (this.IsTypeDefStruct(typeDef))
{
typeDeclaration = this.DeclareTypeDefStruct(typeDef);
typeDeclaration = this.DeclareTypeDefStruct(typeDef, typeDefHandle);
}
else if (this.IsEmptyStructWithGuid(typeDef))
{
Expand Down Expand Up @@ -2970,9 +2941,8 @@ private void TryGenerateConstantOrThrow(string possiblyQualifiedName)
}
}

private FieldDeclarationSyntax DeclareConstant(FieldDefinitionHandle fieldDefHandle)
private FieldDeclarationSyntax DeclareConstant(FieldDefinition fieldDef)
{
FieldDefinition fieldDef = this.Reader.GetFieldDefinition(fieldDefHandle);
string name = this.Reader.GetString(fieldDef.Name);
try
{
Expand Down Expand Up @@ -3027,6 +2997,8 @@ private FieldDeclarationSyntax DeclareConstant(FieldDefinitionHandle fieldDefHan
VariableDeclarator(Identifier(name)).WithInitializer(EqualsValueClause(value))))
.WithModifiers(modifiers);
result = fieldType.AddMarshalAs(result);
result = this.AddApiDocumentation(result.Declaration.Variables[0].Identifier.ValueText, result);

return result;
}
catch (Exception ex)
Expand All @@ -3041,7 +3013,7 @@ private FieldDeclarationSyntax DeclareConstant(FieldDefinitionHandle fieldDefHan
private ClassDeclarationSyntax DeclareConstantDefiningClass()
{
return ClassDeclaration(this.methodsAndConstantsClassName.Identifier)
.AddMembers(this.committedCode.Fields.ToArray())
.AddMembers(this.committedCode.TopLevelFields.ToArray())
.WithModifiers(TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword)));
}

Expand Down Expand Up @@ -3611,7 +3583,7 @@ private ClassDeclarationSyntax DeclareCocreatableClass(TypeDefinition typeDef)
/// <summary>
/// Creates a struct that emulates a typedef in the C language headers.
/// </summary>
private StructDeclarationSyntax DeclareTypeDefStruct(TypeDefinition typeDef)
private StructDeclarationSyntax DeclareTypeDefStruct(TypeDefinition typeDef, TypeDefinitionHandle typeDefHandle)
{
IdentifierNameSyntax name = IdentifierName(this.Reader.GetString(typeDef.Name));
if (name.Identifier.ValueText == "BOOL")
Expand Down Expand Up @@ -5586,7 +5558,7 @@ private class GeneratedCode
/// </summary>
private readonly Dictionary<TypeDefinitionHandle, MemberDeclarationSyntax> types = new();

private readonly Dictionary<FieldDefinitionHandle, FieldDeclarationSyntax> fieldsToSyntax = new();
private readonly Dictionary<FieldDefinitionHandle, (FieldDeclarationSyntax FieldDeclaration, TypeDefinitionHandle? FieldType)> fieldsToSyntax = new();

private readonly List<ClassDeclarationSyntax> safeHandleTypes = new();

Expand Down Expand Up @@ -5624,13 +5596,15 @@ internal GeneratedCode(GeneratedCode parent)
this.parent = parent;
}

internal IEnumerable<MemberDeclarationSyntax> GeneratedTypes => this.types.Values.Concat(this.specialTypes.Values).Concat(this.safeHandleTypes);
internal IEnumerable<MemberDeclarationSyntax> GeneratedTypes => this.GetTypesWithInjectedFields().Concat(this.specialTypes.Values).Concat(this.safeHandleTypes);

internal IEnumerable<MethodDeclarationSyntax> ComInterfaceExtensions => this.comInterfaceFriendlyExtensionsMembers;

internal IEnumerable<MethodDeclarationSyntax> InlineArrayIndexerExtensions => this.inlineArrayIndexerExtensionsMembers;

internal IEnumerable<FieldDeclarationSyntax> Fields => this.fieldsToSyntax.Values;
internal IEnumerable<FieldDeclarationSyntax> TopLevelFields => from field in this.fieldsToSyntax.Values
where field.FieldType is null || !this.types.ContainsKey(field.FieldType.Value)
select field.FieldDeclaration;

internal IEnumerable<IGrouping<string, MemberDeclarationSyntax>> MembersByModule
{
Expand Down Expand Up @@ -5674,10 +5648,10 @@ internal void AddMemberToModule(string moduleName, IEnumerable<MemberDeclaration
methodsList.AddRange(members);
}

internal void AddConstant(FieldDefinitionHandle fieldDefHandle, FieldDeclarationSyntax constantDeclaration)
internal void AddConstant(FieldDefinitionHandle fieldDefHandle, FieldDeclarationSyntax constantDeclaration, TypeDefinitionHandle? fieldType)
{
this.ThrowIfNotGenerating();
this.fieldsToSyntax.Add(fieldDefHandle, constantDeclaration);
this.fieldsToSyntax.Add(fieldDefHandle, (constantDeclaration, fieldType));
}

internal void AddInlineArrayIndexerExtension(MethodDeclarationSyntax inlineIndexer)
Expand Down Expand Up @@ -5901,6 +5875,32 @@ private void Commit(GeneratedCode? parent)
Commit(this.comInterfaceFriendlyExtensionsMembers, parent?.comInterfaceFriendlyExtensionsMembers);
}

private IEnumerable<MemberDeclarationSyntax> GetTypesWithInjectedFields()
{
var fieldsByType =
(from field in this.fieldsToSyntax
where field.Value.FieldType is not null
group field.Value.FieldDeclaration by field.Value.FieldType into typeGroup
select typeGroup).ToDictionary(k => k.Key!, k => k.ToArray());
foreach (KeyValuePair<TypeDefinitionHandle, MemberDeclarationSyntax> pair in this.types)
{
MemberDeclarationSyntax type = pair.Value;
if (fieldsByType.TryGetValue(pair.Key, out var extraFields))
{
switch (type)
{
case StructDeclarationSyntax structType:
type = structType.AddMembers(extraFields);
break;
default:
throw new NotSupportedException();
}
}

yield return type;
}
}

private void ThrowIfNotGenerating()
{
if (!this.generating)
Expand Down
49 changes: 49 additions & 0 deletions src/Microsoft.Windows.CsWin32/MetadataIndex.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ internal class MetadataIndex : IDisposable

private readonly HashSet<string> releaseMethods = new HashSet<string>(StringComparer.Ordinal);

private readonly Dictionary<TypeReferenceHandle, TypeDefinitionHandle> refToDefCache = new();

/// <summary>
/// The set of names of typedef structs that represent handles where the handle has length of <see cref="IntPtr"/>
/// and is therefore appropriate to wrap in a <see cref="SafeHandle"/>.
Expand Down Expand Up @@ -274,6 +276,53 @@ internal static void Return(MetadataIndex index)
}
}

/// <summary>
/// Attempts to translate a <see cref="TypeReferenceHandle"/> to a <see cref="TypeDefinitionHandle"/>.
/// </summary>
/// <param name="typeRefHandle">The reference handle.</param>
/// <param name="typeDefHandle">Receives the type def handle, if one was discovered.</param>
/// <returns><see langword="true"/> if a TypeDefinition was found; otherwise <see langword="false"/>.</returns>
internal bool TryGetTypeDefHandle(TypeReferenceHandle typeRefHandle, out TypeDefinitionHandle typeDefHandle)
{
if (this.refToDefCache.TryGetValue(typeRefHandle, out typeDefHandle))
{
return !typeDefHandle.IsNil;
}

TypeReference typeRef = this.Reader.GetTypeReference(typeRefHandle);
if (typeRef.ResolutionScope.Kind != HandleKind.AssemblyReference)
{
foreach (TypeDefinitionHandle tdh in this.Reader.TypeDefinitions)
{
TypeDefinition typeDef = this.Reader.GetTypeDefinition(tdh);
if (typeDef.Name == typeRef.Name && typeDef.Namespace == typeRef.Namespace)
{
if (typeRef.ResolutionScope.Kind == HandleKind.TypeReference)
{
// The ref is nested. Verify that the type we found is nested in the same type as well.
if (this.TryGetTypeDefHandle((TypeReferenceHandle)typeRef.ResolutionScope, out TypeDefinitionHandle nestingTypeDef) && nestingTypeDef == typeDef.GetDeclaringType())
{
typeDefHandle = tdh;
break;
}
}
else if (typeRef.ResolutionScope.Kind == HandleKind.ModuleDefinition && typeDef.GetDeclaringType().IsNil)
{
typeDefHandle = tdh;
break;
}
else
{
throw new NotSupportedException("Unrecognized ResolutionScope: " + typeRef.ResolutionScope);
}
}
}
}

this.refToDefCache.Add(typeRefHandle, typeDefHandle);
return !typeDefHandle.IsNil;
}

private static string CommonPrefix(IReadOnlyList<string> ss)
{
if (ss.Count == 0)
Expand Down