diff --git a/src/Microsoft.Windows.CsWin32/Generator.Com.cs b/src/Microsoft.Windows.CsWin32/Generator.Com.cs index 173dcc91..c4cf8436 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Com.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Com.cs @@ -99,8 +99,9 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type var members = new List(); var vtblMembers = new List(); TypeSyntaxSettings typeSettings = context.Filter(this.comSignatureTypeSettings); - IdentifierNameSyntax pThisLocal = IdentifierName("pThis"); - ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" && !this.IsNonCOMInterface(typeDef) ? Parameter(pThisLocal.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null; + IdentifierNameSyntax pThisParameterName = IdentifierName("pThis"); + ExpressionSyntax pThis = ThisPointer(PointerType(ifaceName)); + ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" && !this.IsNonCOMInterface(typeDef) ? Parameter(pThisParameterName.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null; List ccwMethodsToSkip = new(); List ccwEntrypointMethods = new(); IdentifierNameSyntax vtblParamName = IdentifierName("vtable"); @@ -177,7 +178,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type CastExpression(unmanagedDelegateType, ElementAccessExpression(vtblFieldName).AddArgumentListArguments(Argument(methodOffset)))); InvocationExpressionSyntax vtblInvocation = InvocationExpression(vtblIndexingExpression) .WithArgumentList(FixTrivia(ArgumentList() - .AddArguments(Argument(pThisLocal)) + .AddArguments(Argument(pThis)) .AddArguments(parameterList.Parameters.Select(p => Argument(IdentifierName(p.Identifier.ValueText)).WithRefKindKeyword(p.Modifiers.Count > 0 ? p.Modifiers[0] : default)).ToArray()))); MemberDeclarationSyntax? propertyOrMethod; @@ -205,28 +206,20 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type ArgumentSyntax resultArgument = funcPtrParameters.Parameters[1].Modifiers.Any(SyntaxKind.OutKeyword) ? Argument(resultLocal).WithRefKindKeyword(Token(SyntaxKind.OutKeyword)) : Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, resultLocal)); - StatementSyntax vtblInvocationStatement = ThrowOnHRFailure(vtblInvocation.WithArgumentList(ArgumentList().AddArguments(Argument(pThisLocal), resultArgument))); + StatementSyntax vtblInvocationStatement = ThrowOnHRFailure(vtblInvocation.WithArgumentList(ArgumentList().AddArguments(Argument(pThis), resultArgument))); // return __result; StatementSyntax returnStatement = ReturnStatement(resultLocal); body = Block().AddStatements( - FixedStatement( - VariableDeclaration(PointerType(ifaceName)).AddVariables( - VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))), - Block().AddStatements( - resultLocalDeclaration, - vtblInvocationStatement, - returnStatement)).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword))); + resultLocalDeclaration, + vtblInvocationStatement, + returnStatement); break; case SyntaxKind.SetAccessorDeclaration: // vtblInvoke(pThis, value).ThrowOnFailure(); - vtblInvocationStatement = ThrowOnHRFailure(vtblInvocation.WithArgumentList(ArgumentList().AddArguments(Argument(pThisLocal), Argument(IdentifierName("value"))))); - body = Block().AddStatements( - FixedStatement( - VariableDeclaration(PointerType(ifaceName)).AddVariables( - VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))), - vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword))); + vtblInvocationStatement = ThrowOnHRFailure(vtblInvocation.WithArgumentList(ArgumentList().AddArguments(Argument(pThis), Argument(IdentifierName("value"))))); + body = Block().AddStatements(vtblInvocationStatement); break; default: throw new NotSupportedException("Unsupported accessor kind: " + accessorKind); @@ -260,14 +253,15 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type } else { - StatementSyntax fixedBody; + BlockSyntax body; bool preserveSig = this.UsePreserveSigForComMethod(methodDefinition.Method, signature, ifaceName.Identifier.ValueText, methodName); if (preserveSig) { // return ... - fixedBody = IsVoid(returnType) - ? ExpressionStatement(vtblInvocation) - : ReturnStatement(vtblInvocation); + body = Block().AddStatements( + IsVoid(returnType) + ? ExpressionStatement(vtblInvocation) + : ReturnStatement(vtblInvocation)); } else { @@ -298,24 +292,17 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type // return __retVal; ReturnStatementSyntax returnStatement = ReturnStatement(retValLocalName); - fixedBody = Block().AddStatements(localRetValDecl, InvokeVtblAndThrow(), returnStatement); + body = Block().AddStatements(localRetValDecl, InvokeVtblAndThrow(), returnStatement); } else { // Remove the return type returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); - fixedBody = InvokeVtblAndThrow(); + body = Block().AddStatements(InvokeVtblAndThrow()); } } - // fixed (IPersist* pThis = &this) - FixedStatementSyntax fixedStatement = FixedStatement( - VariableDeclaration(PointerType(ifaceName)).AddVariables( - VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))), - fixedBody).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)); - BlockSyntax body = Block().AddStatements(fixedStatement); - methodDeclaration = MethodDeclaration( List(), modifiers: TokenList(TokenWithSpace(SyntaxKind.PublicKeyword)), // always use public so struct can implement the COM interface @@ -420,7 +407,7 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn) InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("ComHelpers"), IdentifierName("UnwrapCCW")), ArgumentList().AddArguments( - Argument(pThisLocal), + Argument(pThisParameterName), Argument(DeclarationExpression(NestedCOMInterfaceName.WithTrailingTrivia(Space), SingleVariableDesignation(objectLocal.Identifier))).WithRefKindKeyword(Token(SyntaxKind.OutKeyword)))))))); StatementSyntax ifNullReturnStatement = hrReturnType @@ -477,6 +464,15 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn) } } + static ExpressionSyntax ThisPointer(PointerTypeSyntax? typedPointer = null) + { + // (type*)Unsafe.AsPointer(ref this) + InvocationExpressionSyntax invocation = InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(Unsafe)), IdentifierName(nameof(Unsafe.AsPointer))), + ArgumentList().AddArguments(Argument(RefExpression(ThisExpression())))); + return typedPointer is not null ? CastExpression(typedPointer, invocation) : invocation; + } + // We expose the vtbl struct to support CCWs. IdentifierNameSyntax vtblStructName = IdentifierName("Vtbl"); StructDeclarationSyntax? vtblStruct = StructDeclaration(Identifier("Vtbl")).WithTrailingTrivia(Space)