Skip to content

Commit

Permalink
Merge pull request #783 from microsoft/perfWork
Browse files Browse the repository at this point in the history
Distribute COM interface friendly overloads across many classes
  • Loading branch information
AArnott committed Nov 16, 2022
2 parents 67c41f7 + ba6b8c5 commit 52dde8b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
43 changes: 21 additions & 22 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ public class Generator : IDisposable

private static readonly SyntaxToken SemicolonWithLineFeed = TokenWithLineFeed(SyntaxKind.SemicolonToken);
private static readonly IdentifierNameSyntax InlineArrayIndexerExtensionsClassName = IdentifierName("InlineArrayIndexerExtensions");
private static readonly IdentifierNameSyntax ComInterfaceFriendlyExtensionsClassName = IdentifierName("FriendlyOverloadExtensions");
private static readonly TypeSyntax SafeHandleTypeSyntax = IdentifierName("SafeHandle");
private static readonly IdentifierNameSyntax IntPtrTypeSyntax = IdentifierName(nameof(IntPtr));
private static readonly IdentifierNameSyntax UIntPtrTypeSyntax = IdentifierName(nameof(UIntPtr));
Expand Down Expand Up @@ -607,11 +606,7 @@ ClassDeclarationSyntax DeclarePInvokeClass(string fileNameKey) => ClassDeclarati
result = result.Concat(new MemberDeclarationSyntax[] { inlineArrayIndexerExtensionsClass });
}

ClassDeclarationSyntax comInterfaceFriendlyExtensionsClass = this.DeclareComInterfaceFriendlyExtensionsClass();
if (comInterfaceFriendlyExtensionsClass.Members.Count > 0)
{
result = result.Concat(new MemberDeclarationSyntax[] { comInterfaceFriendlyExtensionsClass });
}
result = result.Concat(this.committedCode.ComInterfaceExtensions);

if (this.committedCode.TopLevelFields.Any())
{
Expand Down Expand Up @@ -3429,14 +3424,6 @@ private ClassDeclarationSyntax DeclareInlineArrayIndexerExtensionsClass()
.AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute));
}

private ClassDeclarationSyntax DeclareComInterfaceFriendlyExtensionsClass()
{
return ClassDeclaration(ComInterfaceFriendlyExtensionsClassName.Identifier)
.AddMembers(this.committedCode.ComInterfaceExtensions.ToArray())
.WithModifiers(TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword)))
.AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute));
}

/// <summary>
/// Generates a type to represent a COM interface.
/// </summary>
Expand Down Expand Up @@ -3724,9 +3711,8 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
return null;
}

IdentifierNameSyntax ifaceName = interfaceAsSubtype
? NestedCOMInterfaceName
: IdentifierName(this.Reader.GetString(typeDef.Name));
string actualIfaceName = this.Reader.GetString(typeDef.Name);
IdentifierNameSyntax ifaceName = interfaceAsSubtype ? NestedCOMInterfaceName : IdentifierName(actualIfaceName);
TypeSyntaxSettings typeSettings = this.comSignatureTypeSettings;

// It is imperative that we generate methods for all base interfaces as well, ahead of any implemented by *this* interface.
Expand Down Expand Up @@ -3932,7 +3918,20 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type

// Only add overloads to instance collections after everything else is done,
// so we don't leave extension methods behind if we fail to generate the target interface.
this.volatileCode.AddComInterfaceExtension(friendlyOverloads);
if (friendlyOverloads.Count > 0)
{
string ns = this.Reader.GetString(typeDef.Namespace);
if (this.TryStripCommonNamespace(ns, out string? strippedNamespace))
{
ns = strippedNamespace;
}

ClassDeclarationSyntax friendlyOverloadClass = ClassDeclaration(Identifier($"{ns.Replace('.', '_')}_{actualIfaceName}_Extensions"))
.WithMembers(List<MemberDeclarationSyntax>(friendlyOverloads))
.WithModifiers(TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword)))
.AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute));
this.volatileCode.AddComInterfaceExtension(friendlyOverloadClass);
}

return ifaceDeclaration;
}
Expand Down Expand Up @@ -6761,7 +6760,7 @@ private class GeneratedCode

private readonly List<MethodDeclarationSyntax> inlineArrayIndexerExtensionsMembers = new();

private readonly List<MethodDeclarationSyntax> comInterfaceFriendlyExtensionsMembers = new();
private readonly List<ClassDeclarationSyntax> comInterfaceFriendlyExtensionsMembers = new();

private bool generating;

Expand All @@ -6784,7 +6783,7 @@ internal GeneratedCode(GeneratedCode parent)

internal IEnumerable<MemberDeclarationSyntax> GeneratedTopLevelTypes => this.specialTypes.Values.Where(st => st.TopLevel).Select(st => st.Type);

internal IEnumerable<MethodDeclarationSyntax> ComInterfaceExtensions => this.comInterfaceFriendlyExtensionsMembers;
internal IReadOnlyCollection<ClassDeclarationSyntax> ComInterfaceExtensions => this.comInterfaceFriendlyExtensionsMembers;

internal IEnumerable<MethodDeclarationSyntax> InlineArrayIndexerExtensions => this.inlineArrayIndexerExtensionsMembers;

Expand Down Expand Up @@ -6866,13 +6865,13 @@ internal void AddInlineArrayIndexerExtension(MethodDeclarationSyntax inlineIndex
}
}

internal void AddComInterfaceExtension(MethodDeclarationSyntax extension)
internal void AddComInterfaceExtension(ClassDeclarationSyntax extension)
{
this.ThrowIfNotGenerating();
this.comInterfaceFriendlyExtensionsMembers.Add(extension);
}

internal void AddComInterfaceExtension(IEnumerable<MethodDeclarationSyntax> extension)
internal void AddComInterfaceExtension(IEnumerable<ClassDeclarationSyntax> extension)
{
this.ThrowIfNotGenerating();
this.comInterfaceFriendlyExtensionsMembers.AddRange(extension);
Expand Down
7 changes: 4 additions & 3 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -896,11 +896,12 @@ public void HasGeneratedCodeAttribute()
ClassDeclarationSyntax arrayExtensions = Assert.IsType<ClassDeclarationSyntax>(this.FindGeneratedType("InlineArrayIndexerExtensions").Single());
Assert.Contains(arrayExtensions.AttributeLists, al => al.Attributes.Any(a => a.Name.ToString().Contains("GeneratedCode")));

ClassDeclarationSyntax overloadsExtensions = Assert.IsType<ClassDeclarationSyntax>(this.FindGeneratedType("FriendlyOverloadExtensions").Single());
Assert.Contains(overloadsExtensions.AttributeLists, al => al.Attributes.Any(a => a.Name.ToString().Contains("GeneratedCode")));
Assert.All(
this.compilation.SyntaxTrees.SelectMany(st => st.GetRoot().DescendantNodes().OfType<BaseTypeDeclarationSyntax>()).Where(btd => btd.Identifier.ValueText.EndsWith("_Extensions", StringComparison.Ordinal)),
e => Assert.Contains(e.AttributeLists, al => al.Attributes.Any(a => a.Name.ToString().Contains("GeneratedCode"))));

ClassDeclarationSyntax sysFreeStringSafeHandleClass = Assert.IsType<ClassDeclarationSyntax>(this.FindGeneratedType("SysFreeStringSafeHandle").Single());
Assert.Contains(overloadsExtensions.AttributeLists, al => al.Attributes.Any(a => a.Name.ToString().Contains("GeneratedCode")));
Assert.Contains(sysFreeStringSafeHandleClass.AttributeLists, al => al.Attributes.Any(a => a.Name.ToString().Contains("GeneratedCode")));
}

[Fact]
Expand Down

0 comments on commit 52dde8b

Please sign in to comment.