From 6ee7dbef9be643dfb79b87cc9eaa35fc735cd6da Mon Sep 17 00:00:00 2001 From: Andrew Arnott Date: Tue, 18 Oct 2022 08:07:15 -0600 Subject: [PATCH] Add `CharSet = Unicode` to `extern` methods with `char` parameters Fixes #734 --- src/Microsoft.Windows.CsWin32/Generator.cs | 11 +++++++++-- .../GeneratorTests.cs | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index 04681130..d38b6f57 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -2300,7 +2300,7 @@ private static AttributeSyntax InterfaceType(ComInterfaceType interfaceType) IdentifierName(Enum.GetName(typeof(ComInterfaceType), interfaceType)!)))); } - private static AttributeSyntax DllImport(MethodImport import, string moduleName, string? entrypoint) + private static AttributeSyntax DllImport(MethodImport import, string moduleName, string? entrypoint, CharSet charSet = CharSet.Ansi) { List args = new(); AttributeSyntax? dllImportAttribute = Attribute(IdentifierName("DllImport")); @@ -2319,6 +2319,12 @@ private static AttributeSyntax DllImport(MethodImport import, string moduleName, .WithNameEquals(NameEquals(nameof(DllImportAttribute.SetLastError)))); } + if (charSet != CharSet.Ansi) + { + args.Add(AttributeArgument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(CharSet)), IdentifierName(Enum.GetName(typeof(CharSet), charSet)!))) + .WithNameEquals(NameEquals(IdentifierName(nameof(DllImportAttribute.CharSet))))); + } + dllImportAttribute = dllImportAttribute.WithArgumentList(FixTrivia(AttributeArgumentList().AddArguments(args.ToArray()))); return dllImportAttribute; } @@ -3123,6 +3129,7 @@ private void DeclareExternMethod(MethodDefinitionHandle methodDefinitionHandle) // If this method releases a handle, recreate the method signature such that we take the struct rather than the SafeHandle as a parameter. TypeSyntaxSettings typeSettings = this.MetadataIndex.ReleaseMethods.Contains(entrypoint ?? methodName) ? this.externReleaseSignatureTypeSettings : this.externSignatureTypeSettings; MethodSignature signature = methodDefinition.DecodeSignature(SignatureHandleProvider.Instance, null); + bool requiresUnicodeCharSet = signature.ParameterTypes.Any(p => p is PrimitiveTypeHandleInfo { PrimitiveTypeCode: PrimitiveTypeCode.Char }); CustomAttributeHandleCollection? returnTypeAttributes = this.GetReturnTypeCustomAttributes(methodDefinition); TypeSyntaxAndMarshaling returnType = signature.ReturnType.ToTypeSyntax(typeSettings, returnTypeAttributes, ParameterAttributes.Out); @@ -3131,7 +3138,7 @@ private void DeclareExternMethod(MethodDefinitionHandle methodDefinitionHandle) List() .Add(AttributeList() .WithCloseBracketToken(TokenWithLineFeed(SyntaxKind.CloseBracketToken)) - .AddAttributes(DllImport(import, moduleName, entrypoint))), + .AddAttributes(DllImport(import, moduleName, entrypoint, requiresUnicodeCharSet ? CharSet.Unicode : CharSet.Ansi))), modifiers: TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ExternKeyword)), returnType.Type.WithTrailingTrivia(TriviaList(Space)), explicitInterfaceSpecifier: null!, diff --git a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs index 5bcb903c..2c0cd78e 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs @@ -1307,6 +1307,23 @@ public void InOutPWSTRGetsRefSpanCharFriendlyOverload() Assert.Contains(generatedMethods, m => m.ParameterList.Parameters.Count == 1 && m.ParameterList.Parameters[0].Modifiers.Any(SyntaxKind.RefKeyword) && m.ParameterList.Parameters[0].Type?.ToString() == "Span"); } + [Fact] + public void UnicodeExtenMethodsGetCharSet() + { + const string MethodName = "VkKeyScan"; + this.generator = this.CreateGenerator(); + Assert.True(this.generator.TryGenerate(MethodName, CancellationToken.None)); + this.CollectGeneratedCode(this.generator); + this.AssertNoDiagnostics(); + MethodDeclarationSyntax generatedMethod = this.FindGeneratedMethod(MethodName).Single(); + Assert.Contains( + generatedMethod.AttributeLists.SelectMany(al => al.Attributes), + a => a.Name.ToString() == "DllImport" && + a.ArgumentList?.Arguments.Any(arg => arg is { + NameEquals.Name.Identifier.ValueText: nameof(DllImportAttribute.CharSet), + Expression: MemberAccessExpressionSyntax { Name: IdentifierNameSyntax { Identifier.ValueText: nameof(CharSet.Unicode) } } }) is true); + } + [Fact] public void NullMethodsClass() {