Skip to content

Commit

Permalink
Consolidate extern method and constants into the same class
Browse files Browse the repository at this point in the history
  • Loading branch information
AArnott committed Nov 9, 2021
1 parent d365c4d commit 3eb26b4
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 71 deletions.
53 changes: 18 additions & 35 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ public class Generator : IDisposable
private readonly bool generateDefaultDllImportSearchPathsAttribute;
private readonly GeneratedCode committedCode = new();
private readonly GeneratedCode volatileCode;
private readonly IdentifierNameSyntax constantsClassName;
private readonly IdentifierNameSyntax methodsAndConstantsClassName;
private bool needsWinRTCustomMarshaler;

/// <summary>
Expand Down Expand Up @@ -356,7 +356,7 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option
this.functionPointerTypeSettings = this.generalTypeSettings with { QualifyNames = true };
this.errorMessageTypeSettings = this.generalTypeSettings with { QualifyNames = true };

this.constantsClassName = IdentifierName(options.ConstantsClassName);
this.methodsAndConstantsClassName = IdentifierName(options.ClassName);
}

private enum FriendlyOverloadOf
Expand Down Expand Up @@ -393,37 +393,23 @@ private enum FriendlyOverloadOf

private bool WideCharOnly => this.options.WideCharOnly;

private bool GroupByModule => string.IsNullOrEmpty(this.options.MethodsClassName);

private string Namespace => this.InputAssemblyName;

private string SingleClassName => this.options.MethodsClassName ?? throw new InvalidOperationException("Not in one-class mode.");

private SyntaxKind Visibility => this.options.Public ? SyntaxKind.PublicKeyword : SyntaxKind.InternalKeyword;

private IEnumerable<MemberDeclarationSyntax> NamespaceMembers
{
get
{
IEnumerable<MemberDeclarationSyntax> result = this.GroupByModule
? this.ExternMethodsByModuleClassName.Select(kv =>
ClassDeclaration(Identifier(GetClassNameForModule(kv.Key)))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword))
.AddMembers(kv.ToArray()))
: from entry in this.committedCode.MembersByModule
select ClassDeclaration(Identifier(this.SingleClassName))
IEnumerable<MemberDeclarationSyntax> result =
from entry in this.committedCode.MembersByModule
select ClassDeclaration(Identifier(this.options.ClassName))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword))
.AddMembers(entry.ToArray())
.WithLeadingTrivia(ParseLeadingTrivia(string.Format(CultureInfo.InvariantCulture, PartialPInvokeContentComment, entry.Key)))
.WithAdditionalAnnotations(new SyntaxAnnotation(SimpleFileNameAnnotation, $"{this.SingleClassName}.{entry.Key}"));
.WithAdditionalAnnotations(new SyntaxAnnotation(SimpleFileNameAnnotation, $"{this.options.ClassName}.{entry.Key}"));
result = result.Concat(this.committedCode.GeneratedTypes);

ClassDeclarationSyntax constantClass = this.DeclareConstantDefiningClass();
if (constantClass.Members.Count > 0)
{
result = result.Concat(new MemberDeclarationSyntax[] { constantClass });
}

ClassDeclarationSyntax inlineArrayIndexerExtensionsClass = this.DeclareInlineArrayIndexerExtensionsClass();
if (inlineArrayIndexerExtensionsClass.Members.Count > 0)
{
Expand All @@ -436,6 +422,11 @@ select ClassDeclaration(Identifier(this.SingleClassName))
result = result.Concat(new MemberDeclarationSyntax[] { comInterfaceFriendlyExtensionsClass });
}

if (this.committedCode.Fields.Any())
{
result = result.Concat(new MemberDeclarationSyntax[] { this.DeclareConstantDefiningClass() });
}

return result;
}
}
Expand Down Expand Up @@ -1391,9 +1382,7 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
string releaseMethodModule = this.GetNormalizedModuleName(releaseMethodDef.GetImport());

var safeHandleTypeIdentifier = IdentifierName(safeHandleClassName);
safeHandleType = this.GroupByModule
? QualifiedName(IdentifierName(releaseMethodModule), safeHandleTypeIdentifier)
: safeHandleTypeIdentifier;
safeHandleType = safeHandleTypeIdentifier;

MethodSignature<TypeHandleInfo> releaseMethodSignature = releaseMethodDef.DecodeSignature(SignatureHandleProvider.Instance, null);
var releaseMethodParameterType = releaseMethodSignature.ParameterTypes[0].ToTypeSyntax(this.externSignatureTypeSettings, default);
Expand Down Expand Up @@ -1478,7 +1467,7 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
ExpressionSyntax releaseInvocation = InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(this.GroupByModule ? releaseMethodModule : this.SingleClassName),
IdentifierName(this.options.ClassName),
IdentifierName(renamedReleaseMethod ?? releaseMethod)),
ArgumentList().AddArguments(Argument(CastExpression(releaseMethodParameterType.Type, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("handle"))))));
BlockSyntax? releaseBlock = null;
Expand Down Expand Up @@ -1523,12 +1512,12 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
break;
case "NTSTATUS":
this.TryGenerateConstantOrThrow("STATUS_SUCCESS");
ExpressionSyntax statusSuccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.constantsClassName, IdentifierName("STATUS_SUCCESS"));
ExpressionSyntax statusSuccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.methodsAndConstantsClassName, IdentifierName("STATUS_SUCCESS"));
releaseInvocation = BinaryExpression(SyntaxKind.EqualsExpression, releaseInvocation, statusSuccess);
break;
case "HRESULT":
this.TryGenerateConstantOrThrow("S_OK");
ExpressionSyntax ok = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.constantsClassName, IdentifierName("S_OK"));
ExpressionSyntax ok = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.methodsAndConstantsClassName, IdentifierName("S_OK"));
releaseInvocation = BinaryExpression(SyntaxKind.EqualsExpression, releaseInvocation, ok);
break;
default:
Expand All @@ -1555,16 +1544,11 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
.AddMembers(members.ToArray())
.WithLeadingTrivia(ParseLeadingTrivia($@"
/// <summary>
/// Represents a Win32 handle that can be closed with <see cref=""{(this.GroupByModule ? releaseMethodModule : this.SingleClassName)}.{renamedReleaseMethod ?? releaseMethod}""/>.
/// Represents a Win32 handle that can be closed with <see cref=""{this.options.ClassName}.{renamedReleaseMethod ?? releaseMethod}""/>.
/// </summary>
"));

this.volatileCode.AddSafeHandleType(safeHandleDeclaration);
if (this.GroupByModule)
{
this.volatileCode.AddMemberToModule(releaseMethodModule, safeHandleDeclaration);
}

return safeHandleType;
}

Expand Down Expand Up @@ -2667,8 +2651,7 @@ private void DeclareExternMethod(MethodDefinitionHandle methodDefinitionHandle)
methodDeclaration = methodDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.UnsafeKeyword));
}

NameSyntax declaringTypeName = ParseName(this.GroupByModule ? GetClassNameForModule(moduleName) : this.SingleClassName);
this.volatileCode.AddMemberToModule(moduleName, this.DeclareFriendlyOverloads(methodDefinition, methodDeclaration, declaringTypeName, FriendlyOverloadOf.ExternMethod));
this.volatileCode.AddMemberToModule(moduleName, this.DeclareFriendlyOverloads(methodDefinition, methodDeclaration, this.methodsAndConstantsClassName, FriendlyOverloadOf.ExternMethod));
this.volatileCode.AddMemberToModule(moduleName, methodDeclaration);
}
catch (Exception ex)
Expand Down Expand Up @@ -2847,7 +2830,7 @@ private FieldDeclarationSyntax DeclareConstant(FieldDefinitionHandle fieldDefHan

private ClassDeclarationSyntax DeclareConstantDefiningClass()
{
return ClassDeclaration(this.constantsClassName.Identifier)
return ClassDeclaration(this.methodsAndConstantsClassName.Identifier)
.AddMembers(this.committedCode.Fields.ToArray())
.WithModifiers(TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword)));
}
Expand Down
13 changes: 6 additions & 7 deletions src/Microsoft.Windows.CsWin32/GeneratorOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,10 @@ public record GeneratorOptions
public bool WideCharOnly { get; init; } = true;

/// <summary>
/// Gets the name of a single class under which all p/invoke methods are generated, regardless of imported module. Use <see langword="null"/> for one class per imported module.
/// Gets the name of a single class under which all p/invoke methods and constants are generated, regardless of imported module.
/// </summary>
/// <value>The default value is "PInvoke".</value>
public string? MethodsClassName { get; init; } = "PInvoke";

/// <summary>
/// Gets the name of the single class under which all constants are generated.
/// </summary>
public string ConstantsClassName { get; init; } = "Constants";
public string ClassName { get; init; } = "PInvoke";

/// <summary>
/// Gets a value indicating whether to emit a single source file as opposed to types spread across many files.
Expand Down Expand Up @@ -56,6 +51,10 @@ public record GeneratorOptions
/// <exception cref="InvalidOperationException">Thrown when some setting is invalid.</exception>
public void Validate()
{
if (string.IsNullOrWhiteSpace(this.ClassName))
{
throw new InvalidOperationException("The ClassName property must not be null or empty.");
}
}

/// <summary>
Expand Down
12 changes: 3 additions & 9 deletions src/Microsoft.Windows.CsWin32/settings.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,10 @@
"type": "boolean",
"default": false
},
"methodsClassName": {
"description": "The name of a single class under which all p/invoke methods are generated, regardless of imported module. Use null for one class per imported module.",
"type": [ "string", "null" ],
"default": "PInvoke",
"pattern": "^\\w+$"
},
"constantsClassName": {
"description": "The name of the single class under which all constants are generated.",
"className": {
"description": "The name of a single class under which all p/invoke methods and constants are generated, regardless of imported module.",
"type": "string",
"default": "Constants",
"default": "PInvoke",
"pattern": "^\\w+$"
},
"public": {
Expand Down
26 changes: 6 additions & 20 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,33 +1008,19 @@ internal static partial class InlineArrayIndexerExtensions
[Fact]
public void NullMethodsClass()
{
this.generator = this.CreateGenerator(new GeneratorOptions { MethodsClassName = null });
Assert.True(this.generator.TryGenerate("GetTickCount", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
Assert.Single(this.FindGeneratedType("Kernel32"));
Assert.Empty(this.FindGeneratedType("PInvoke"));
Assert.Throws<InvalidOperationException>(() => this.CreateGenerator(new GeneratorOptions { ClassName = null! }));
}

[Fact]
public void RenamedMethodsClass()
{
this.generator = this.CreateGenerator(new GeneratorOptions { MethodsClassName = "MyPInvoke" });
this.generator = this.CreateGenerator(new GeneratorOptions { ClassName = "MyPInvoke" });
Assert.True(this.generator.TryGenerate("GetTickCount", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
Assert.Single(this.FindGeneratedType("MyPInvoke"));
Assert.Empty(this.FindGeneratedType("PInvoke"));
}

[Fact]
public void RenamedConstantsClass()
{
this.generator = this.CreateGenerator(new GeneratorOptions { ConstantsClassName = "MyConstants" });
Assert.True(this.generator.TryGenerate("CDB_REPORT_BITS", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
Assert.Single(this.FindGeneratedType("MyConstants"));
Assert.Empty(this.FindGeneratedType("Constants"));
Assert.NotEmpty(this.FindGeneratedType("MyPInvoke"));
Assert.Empty(this.FindGeneratedType("PInvoke"));
}

[Theory, PairwiseData]
Expand All @@ -1059,15 +1045,15 @@ public void ProjectReferenceBetweenTwoGeneratingProjects(bool internalsVisibleTo
CSharpSyntaxTree.ParseText($@"[assembly: System.Runtime.CompilerServices.InternalsVisibleToAttribute(""{this.compilation.AssemblyName}"")]", this.parseOptions));
}

using var referencedGenerator = this.CreateGenerator(new GeneratorOptions { MethodsClassName = "P1" }, referencedProject);
using var referencedGenerator = this.CreateGenerator(new GeneratorOptions { ClassName = "P1" }, referencedProject);
Assert.True(referencedGenerator.TryGenerate("LockWorkStation", CancellationToken.None));
Assert.True(referencedGenerator.TryGenerate("CreateFile", CancellationToken.None));
referencedProject = this.AddGeneratedCode(referencedProject, referencedGenerator);
this.AssertNoDiagnostics(referencedProject);

// Now produce more code in a referencing project that includes at least one of the same types as generated in the referenced project.
this.compilation = this.compilation.AddReferences(referencedProject.ToMetadataReference());
this.generator = this.CreateGenerator(new GeneratorOptions { MethodsClassName = "P2" });
this.generator = this.CreateGenerator(new GeneratorOptions { ClassName = "P2" });
Assert.True(this.generator.TryGenerate("HidD_GetAttributes", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
Expand Down

0 comments on commit 3eb26b4

Please sign in to comment.