diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs
index debe3bef..8293596c 100644
--- a/src/Microsoft.Windows.CsWin32/Generator.cs
+++ b/src/Microsoft.Windows.CsWin32/Generator.cs
@@ -159,7 +159,6 @@ public class Generator : IDisposable
private static readonly SyntaxToken SemicolonWithLineFeed = TokenWithLineFeed(SyntaxKind.SemicolonToken);
private static readonly IdentifierNameSyntax InlineArrayIndexerExtensionsClassName = IdentifierName("InlineArrayIndexerExtensions");
- private static readonly IdentifierNameSyntax ComInterfaceFriendlyExtensionsClassName = IdentifierName("FriendlyOverloadExtensions");
private static readonly TypeSyntax SafeHandleTypeSyntax = IdentifierName("SafeHandle");
private static readonly IdentifierNameSyntax IntPtrTypeSyntax = IdentifierName(nameof(IntPtr));
private static readonly IdentifierNameSyntax UIntPtrTypeSyntax = IdentifierName(nameof(UIntPtr));
@@ -607,11 +606,7 @@ ClassDeclarationSyntax DeclarePInvokeClass(string fileNameKey) => ClassDeclarati
result = result.Concat(new MemberDeclarationSyntax[] { inlineArrayIndexerExtensionsClass });
}
- ClassDeclarationSyntax comInterfaceFriendlyExtensionsClass = this.DeclareComInterfaceFriendlyExtensionsClass();
- if (comInterfaceFriendlyExtensionsClass.Members.Count > 0)
- {
- result = result.Concat(new MemberDeclarationSyntax[] { comInterfaceFriendlyExtensionsClass });
- }
+ result = result.Concat(this.committedCode.ComInterfaceExtensions);
if (this.committedCode.TopLevelFields.Any())
{
@@ -3429,14 +3424,6 @@ private ClassDeclarationSyntax DeclareInlineArrayIndexerExtensionsClass()
.AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute));
}
- private ClassDeclarationSyntax DeclareComInterfaceFriendlyExtensionsClass()
- {
- return ClassDeclaration(ComInterfaceFriendlyExtensionsClassName.Identifier)
- .AddMembers(this.committedCode.ComInterfaceExtensions.ToArray())
- .WithModifiers(TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword)))
- .AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute));
- }
-
///
/// Generates a type to represent a COM interface.
///
@@ -3724,9 +3711,8 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta
return null;
}
- IdentifierNameSyntax ifaceName = interfaceAsSubtype
- ? NestedCOMInterfaceName
- : IdentifierName(this.Reader.GetString(typeDef.Name));
+ string actualIfaceName = this.Reader.GetString(typeDef.Name);
+ IdentifierNameSyntax ifaceName = interfaceAsSubtype ? NestedCOMInterfaceName : IdentifierName(actualIfaceName);
TypeSyntaxSettings typeSettings = this.comSignatureTypeSettings;
// It is imperative that we generate methods for all base interfaces as well, ahead of any implemented by *this* interface.
@@ -3932,7 +3918,20 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta
// Only add overloads to instance collections after everything else is done,
// so we don't leave extension methods behind if we fail to generate the target interface.
- this.volatileCode.AddComInterfaceExtension(friendlyOverloads);
+ if (friendlyOverloads.Count > 0)
+ {
+ string ns = this.Reader.GetString(typeDef.Namespace);
+ if (this.TryStripCommonNamespace(ns, out string? strippedNamespace))
+ {
+ ns = strippedNamespace;
+ }
+
+ ClassDeclarationSyntax friendlyOverloadClass = ClassDeclaration(Identifier($"{ns.Replace('.', '_')}_{actualIfaceName}_Extensions"))
+ .WithMembers(List(friendlyOverloads))
+ .WithModifiers(TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword)))
+ .AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute));
+ this.volatileCode.AddComInterfaceExtension(friendlyOverloadClass);
+ }
return ifaceDeclaration;
}
@@ -6761,7 +6760,7 @@ private class GeneratedCode
private readonly List inlineArrayIndexerExtensionsMembers = new();
- private readonly List comInterfaceFriendlyExtensionsMembers = new();
+ private readonly List comInterfaceFriendlyExtensionsMembers = new();
private bool generating;
@@ -6784,7 +6783,7 @@ internal GeneratedCode(GeneratedCode parent)
internal IEnumerable GeneratedTopLevelTypes => this.specialTypes.Values.Where(st => st.TopLevel).Select(st => st.Type);
- internal IEnumerable ComInterfaceExtensions => this.comInterfaceFriendlyExtensionsMembers;
+ internal IReadOnlyCollection ComInterfaceExtensions => this.comInterfaceFriendlyExtensionsMembers;
internal IEnumerable InlineArrayIndexerExtensions => this.inlineArrayIndexerExtensionsMembers;
@@ -6866,13 +6865,13 @@ internal void AddInlineArrayIndexerExtension(MethodDeclarationSyntax inlineIndex
}
}
- internal void AddComInterfaceExtension(MethodDeclarationSyntax extension)
+ internal void AddComInterfaceExtension(ClassDeclarationSyntax extension)
{
this.ThrowIfNotGenerating();
this.comInterfaceFriendlyExtensionsMembers.Add(extension);
}
- internal void AddComInterfaceExtension(IEnumerable extension)
+ internal void AddComInterfaceExtension(IEnumerable extension)
{
this.ThrowIfNotGenerating();
this.comInterfaceFriendlyExtensionsMembers.AddRange(extension);