diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index debe3bef..8293596c 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -159,7 +159,6 @@ public class Generator : IDisposable private static readonly SyntaxToken SemicolonWithLineFeed = TokenWithLineFeed(SyntaxKind.SemicolonToken); private static readonly IdentifierNameSyntax InlineArrayIndexerExtensionsClassName = IdentifierName("InlineArrayIndexerExtensions"); - private static readonly IdentifierNameSyntax ComInterfaceFriendlyExtensionsClassName = IdentifierName("FriendlyOverloadExtensions"); private static readonly TypeSyntax SafeHandleTypeSyntax = IdentifierName("SafeHandle"); private static readonly IdentifierNameSyntax IntPtrTypeSyntax = IdentifierName(nameof(IntPtr)); private static readonly IdentifierNameSyntax UIntPtrTypeSyntax = IdentifierName(nameof(UIntPtr)); @@ -607,11 +606,7 @@ ClassDeclarationSyntax DeclarePInvokeClass(string fileNameKey) => ClassDeclarati result = result.Concat(new MemberDeclarationSyntax[] { inlineArrayIndexerExtensionsClass }); } - ClassDeclarationSyntax comInterfaceFriendlyExtensionsClass = this.DeclareComInterfaceFriendlyExtensionsClass(); - if (comInterfaceFriendlyExtensionsClass.Members.Count > 0) - { - result = result.Concat(new MemberDeclarationSyntax[] { comInterfaceFriendlyExtensionsClass }); - } + result = result.Concat(this.committedCode.ComInterfaceExtensions); if (this.committedCode.TopLevelFields.Any()) { @@ -3429,14 +3424,6 @@ private ClassDeclarationSyntax DeclareInlineArrayIndexerExtensionsClass() .AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute)); } - private ClassDeclarationSyntax DeclareComInterfaceFriendlyExtensionsClass() - { - return ClassDeclaration(ComInterfaceFriendlyExtensionsClassName.Identifier) - .AddMembers(this.committedCode.ComInterfaceExtensions.ToArray()) - .WithModifiers(TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword))) - .AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute)); - } - /// /// Generates a type to represent a COM interface. /// @@ -3724,9 +3711,8 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta return null; } - IdentifierNameSyntax ifaceName = interfaceAsSubtype - ? NestedCOMInterfaceName - : IdentifierName(this.Reader.GetString(typeDef.Name)); + string actualIfaceName = this.Reader.GetString(typeDef.Name); + IdentifierNameSyntax ifaceName = interfaceAsSubtype ? NestedCOMInterfaceName : IdentifierName(actualIfaceName); TypeSyntaxSettings typeSettings = this.comSignatureTypeSettings; // It is imperative that we generate methods for all base interfaces as well, ahead of any implemented by *this* interface. @@ -3932,7 +3918,20 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta // Only add overloads to instance collections after everything else is done, // so we don't leave extension methods behind if we fail to generate the target interface. - this.volatileCode.AddComInterfaceExtension(friendlyOverloads); + if (friendlyOverloads.Count > 0) + { + string ns = this.Reader.GetString(typeDef.Namespace); + if (this.TryStripCommonNamespace(ns, out string? strippedNamespace)) + { + ns = strippedNamespace; + } + + ClassDeclarationSyntax friendlyOverloadClass = ClassDeclaration(Identifier($"{ns.Replace('.', '_')}_{actualIfaceName}_Extensions")) + .WithMembers(List(friendlyOverloads)) + .WithModifiers(TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword))) + .AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute)); + this.volatileCode.AddComInterfaceExtension(friendlyOverloadClass); + } return ifaceDeclaration; } @@ -6761,7 +6760,7 @@ private class GeneratedCode private readonly List inlineArrayIndexerExtensionsMembers = new(); - private readonly List comInterfaceFriendlyExtensionsMembers = new(); + private readonly List comInterfaceFriendlyExtensionsMembers = new(); private bool generating; @@ -6784,7 +6783,7 @@ internal GeneratedCode(GeneratedCode parent) internal IEnumerable GeneratedTopLevelTypes => this.specialTypes.Values.Where(st => st.TopLevel).Select(st => st.Type); - internal IEnumerable ComInterfaceExtensions => this.comInterfaceFriendlyExtensionsMembers; + internal IReadOnlyCollection ComInterfaceExtensions => this.comInterfaceFriendlyExtensionsMembers; internal IEnumerable InlineArrayIndexerExtensions => this.inlineArrayIndexerExtensionsMembers; @@ -6866,13 +6865,13 @@ internal void AddInlineArrayIndexerExtension(MethodDeclarationSyntax inlineIndex } } - internal void AddComInterfaceExtension(MethodDeclarationSyntax extension) + internal void AddComInterfaceExtension(ClassDeclarationSyntax extension) { this.ThrowIfNotGenerating(); this.comInterfaceFriendlyExtensionsMembers.Add(extension); } - internal void AddComInterfaceExtension(IEnumerable extension) + internal void AddComInterfaceExtension(IEnumerable extension) { this.ThrowIfNotGenerating(); this.comInterfaceFriendlyExtensionsMembers.AddRange(extension);