Skip to content

Commit

Permalink
Merge pull request #975 from microsoft/fix972
Browse files Browse the repository at this point in the history
Remove unnecessary `fixed` statement in COM struct
  • Loading branch information
AArnott committed Jun 23, 2023
2 parents 0bccff3 + ccdce47 commit 2c6ad76
Showing 1 changed file with 27 additions and 31 deletions.
58 changes: 27 additions & 31 deletions src/Microsoft.Windows.CsWin32/Generator.Com.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
var members = new List<MemberDeclarationSyntax>();
var vtblMembers = new List<MemberDeclarationSyntax>();
TypeSyntaxSettings typeSettings = context.Filter(this.comSignatureTypeSettings);
IdentifierNameSyntax pThisLocal = IdentifierName("pThis");
ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" && !this.IsNonCOMInterface(typeDef) ? Parameter(pThisLocal.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null;
IdentifierNameSyntax pThisParameterName = IdentifierName("pThis");
ExpressionSyntax pThis = ThisPointer(PointerType(ifaceName));
ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" && !this.IsNonCOMInterface(typeDef) ? Parameter(pThisParameterName.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null;
List<QualifiedMethodDefinitionHandle> ccwMethodsToSkip = new();
List<MemberDeclarationSyntax> ccwEntrypointMethods = new();
IdentifierNameSyntax vtblParamName = IdentifierName("vtable");
Expand Down Expand Up @@ -177,7 +178,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
CastExpression(unmanagedDelegateType, ElementAccessExpression(vtblFieldName).AddArgumentListArguments(Argument(methodOffset))));
InvocationExpressionSyntax vtblInvocation = InvocationExpression(vtblIndexingExpression)
.WithArgumentList(FixTrivia(ArgumentList()
.AddArguments(Argument(pThisLocal))
.AddArguments(Argument(pThis))
.AddArguments(parameterList.Parameters.Select(p => Argument(IdentifierName(p.Identifier.ValueText)).WithRefKindKeyword(p.Modifiers.Count > 0 ? p.Modifiers[0] : default)).ToArray())));

MemberDeclarationSyntax? propertyOrMethod;
Expand Down Expand Up @@ -205,28 +206,20 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
ArgumentSyntax resultArgument = funcPtrParameters.Parameters[1].Modifiers.Any(SyntaxKind.OutKeyword)
? Argument(resultLocal).WithRefKindKeyword(Token(SyntaxKind.OutKeyword))
: Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, resultLocal));
StatementSyntax vtblInvocationStatement = ThrowOnHRFailure(vtblInvocation.WithArgumentList(ArgumentList().AddArguments(Argument(pThisLocal), resultArgument)));
StatementSyntax vtblInvocationStatement = ThrowOnHRFailure(vtblInvocation.WithArgumentList(ArgumentList().AddArguments(Argument(pThis), resultArgument)));

// return __result;
StatementSyntax returnStatement = ReturnStatement(resultLocal);

body = Block().AddStatements(
FixedStatement(
VariableDeclaration(PointerType(ifaceName)).AddVariables(
VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))),
Block().AddStatements(
resultLocalDeclaration,
vtblInvocationStatement,
returnStatement)).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)));
resultLocalDeclaration,
vtblInvocationStatement,
returnStatement);
break;
case SyntaxKind.SetAccessorDeclaration:
// vtblInvoke(pThis, value).ThrowOnFailure();
vtblInvocationStatement = ThrowOnHRFailure(vtblInvocation.WithArgumentList(ArgumentList().AddArguments(Argument(pThisLocal), Argument(IdentifierName("value")))));
body = Block().AddStatements(
FixedStatement(
VariableDeclaration(PointerType(ifaceName)).AddVariables(
VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))),
vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)));
vtblInvocationStatement = ThrowOnHRFailure(vtblInvocation.WithArgumentList(ArgumentList().AddArguments(Argument(pThis), Argument(IdentifierName("value")))));
body = Block().AddStatements(vtblInvocationStatement);
break;
default:
throw new NotSupportedException("Unsupported accessor kind: " + accessorKind);
Expand Down Expand Up @@ -260,14 +253,15 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
}
else
{
StatementSyntax fixedBody;
BlockSyntax body;
bool preserveSig = this.UsePreserveSigForComMethod(methodDefinition.Method, signature, ifaceName.Identifier.ValueText, methodName);
if (preserveSig)
{
// return ...
fixedBody = IsVoid(returnType)
? ExpressionStatement(vtblInvocation)
: ReturnStatement(vtblInvocation);
body = Block().AddStatements(
IsVoid(returnType)
? ExpressionStatement(vtblInvocation)
: ReturnStatement(vtblInvocation));
}
else
{
Expand Down Expand Up @@ -298,24 +292,17 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
// return __retVal;
ReturnStatementSyntax returnStatement = ReturnStatement(retValLocalName);

fixedBody = Block().AddStatements(localRetValDecl, InvokeVtblAndThrow(), returnStatement);
body = Block().AddStatements(localRetValDecl, InvokeVtblAndThrow(), returnStatement);
}
else
{
// Remove the return type
returnType = PredefinedType(Token(SyntaxKind.VoidKeyword));

fixedBody = InvokeVtblAndThrow();
body = Block().AddStatements(InvokeVtblAndThrow());
}
}

// fixed (IPersist* pThis = &this)
FixedStatementSyntax fixedStatement = FixedStatement(
VariableDeclaration(PointerType(ifaceName)).AddVariables(
VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))),
fixedBody).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword));
BlockSyntax body = Block().AddStatements(fixedStatement);

methodDeclaration = MethodDeclaration(
List<AttributeListSyntax>(),
modifiers: TokenList(TokenWithSpace(SyntaxKind.PublicKeyword)), // always use public so struct can implement the COM interface
Expand Down Expand Up @@ -420,7 +407,7 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("ComHelpers"), IdentifierName("UnwrapCCW")),
ArgumentList().AddArguments(
Argument(pThisLocal),
Argument(pThisParameterName),
Argument(DeclarationExpression(NestedCOMInterfaceName.WithTrailingTrivia(Space), SingleVariableDesignation(objectLocal.Identifier))).WithRefKindKeyword(Token(SyntaxKind.OutKeyword))))))));

StatementSyntax ifNullReturnStatement = hrReturnType
Expand Down Expand Up @@ -477,6 +464,15 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
}
}

static ExpressionSyntax ThisPointer(PointerTypeSyntax? typedPointer = null)
{
// (type*)Unsafe.AsPointer(ref this)
InvocationExpressionSyntax invocation = InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(Unsafe)), IdentifierName(nameof(Unsafe.AsPointer))),
ArgumentList().AddArguments(Argument(RefExpression(ThisExpression()))));
return typedPointer is not null ? CastExpression(typedPointer, invocation) : invocation;
}

// We expose the vtbl struct to support CCWs.
IdentifierNameSyntax vtblStructName = IdentifierName("Vtbl");
StructDeclarationSyntax? vtblStruct = StructDeclaration(Identifier("Vtbl")).WithTrailingTrivia(Space)
Expand Down

0 comments on commit 2c6ad76

Please sign in to comment.