diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs index 91028ff5a6ebd..25e09c52a6008 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs @@ -426,6 +426,11 @@ public static SyntaxTokenList GetManagedParameterModifiers(TypePositionInfo type } } + if (typeInfo.IsExplicitThis) + { + tokens = tokens.Add(Token(SyntaxKind.ThisKeyword)); + } + return tokens; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypePositionInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypePositionInfo.cs index 49ebf92f600ec..69654c34c4f75 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypePositionInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypePositionInfo.cs @@ -3,9 +3,10 @@ using System; using System.Collections.Generic; - +using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Microsoft.Interop @@ -77,6 +78,7 @@ public static int IncrementIndex(int index) public int ManagedIndex { get; init; } = UnsetIndex; public int NativeIndex { get; init; } = UnsetIndex; + public bool IsExplicitThis { get; init; } public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, MarshallingInfo marshallingInfo, Compilation compilation) { @@ -88,7 +90,8 @@ public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, RefKind = paramSymbol.RefKind, ByValueContentsMarshalKind = byValueContentsMarshalKind, ByValueMarshalAttributeLocations = (inLocation, outLocation), - ScopedKind = paramSymbol.ScopedKind + ScopedKind = paramSymbol.ScopedKind, + IsExplicitThis = ((ParameterSyntax)paramSymbol.DeclaringSyntaxReferences[0].GetSyntax()).Modifiers.Any(SyntaxKind.ThisKeyword) }; return typeInfo; diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.cs index e082fc4b25417..c0d7398b039c1 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.cs @@ -115,6 +115,73 @@ public static partial BoolStruct[] NegateBools( } } + public static partial class ArrayNativeExtensions + { + // The first parameter of a 'ref' extension method must be a value type or a generic type constrained to struct. + // The first 'in' or 'ref readonly' parameter of the extension method must be a concrete (non-generic) value type. + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int Sum(this int[] values, int numValues); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int Sum(this ref int values, int numValues); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "sum_char_array", StringMarshalling = StringMarshalling.Utf16)] + public static partial int SumChars(this char[] chars, int numElements); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "fill_char_array", StringMarshalling = StringMarshalling.Utf16)] + public static partial void FillChars([Out] this char[] chars, int length, ushort start); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] + public static partial int SumStringLengths([MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr)] this string[] strArray); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "reverse_strings_return")] + [return: MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr, SizeParamIndex = 1)] + public static partial string[] ReverseStrings_Return([MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr)] this string[] strArray, out int numElements); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "reverse_strings_out")] + public static partial void ReverseStrings_Out([MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr)] this string[] strArray, out int numElements, [MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr, SizeParamIndex = 1)] out string[] res); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "get_long_bytes")] + [return: MarshalAs(UnmanagedType.LPArray, SizeConst = sizeof(long))] + public static partial byte[] GetLongBytes(this long l); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "fill_range_array")] + [return: MarshalAs(UnmanagedType.U1)] + public static partial bool FillRangeArray([Out] this IntStructWrapper[] array, int length, int start); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "double_values")] + public static partial void DoubleValues([In, Out] this IntStructWrapper[] array, int length); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "and_bool_struct_array")] + [return: MarshalAs(UnmanagedType.U1)] + public static partial bool AndAllMembers(this BoolStruct[] pArray, int length); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "negate_bool_struct_array_out")] + public static partial void NegateBools( + this BoolStruct[] boolStruct, + int numValues, + [MarshalUsing(CountElementName = "numValues")] out BoolStruct[] pBoolStructOut); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "negate_bool_struct_array_return")] + [return: MarshalUsing(CountElementName = "numValues")] + public static partial BoolStruct[] NegateBools( + this BoolStruct[] boolStruct, + int numValues); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "transpose_matrix")] + [return: MarshalUsing(CountElementName = "numColumns")] + [return: MarshalUsing(CountElementName = "numRows", ElementIndirectionDepth = 1)] + public static partial int[][] TransposeMatrix(this int[][] matrix, int[] numRows, int numColumns); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "sum_int_ptr_array")] + public static unsafe partial int Sum(this int*[] values, int numValues); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "return_duplicate_int_ptr_array")] + [return: MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)] + public static unsafe partial int*[] ReturnDuplicate(this int*[] values, int numValues); + } + public class ArrayTests { private int[] GetIntArray() => new[] { 1, 5, 79, 165, 32, 3 }; @@ -126,6 +193,13 @@ public void IntArray_ByValue() Assert.Equal(array.Sum(), NativeExportsNE.Arrays.Sum(array, array.Length)); } + [Fact] + public void IntArray_ByValue_This() + { + int[] array = GetIntArray(); + Assert.Equal(array.Sum(), array.Sum(array.Length)); + } + [Fact] public void IntArray_RefToFirstElement() { @@ -133,6 +207,13 @@ public void IntArray_RefToFirstElement() Assert.Equal(array.Sum(), NativeExportsNE.Arrays.Sum(ref array[0], array.Length)); } + [Fact] + public void IntArray_RefToFirstElement_This() + { + int[] array = GetIntArray(); + Assert.Equal(array.Sum(), array[0].Sum(array.Length)); + } + [Fact] public void NullIntArray_ByValue() { @@ -147,6 +228,13 @@ public void ZeroLengthArray_MarshalledAsNonNull() Assert.Equal(0, NativeExportsNE.Arrays.Sum(array, array.Length)); } + [Fact] + public void ZeroLengthArray_MarshalledAsNonNull_This() + { + var array = new int[0]; + Assert.Equal(0, array.Sum(array.Length)); + } + [Fact] public void IntArray_In() { @@ -170,6 +258,13 @@ public void CharArray_ByValue() Assert.Equal(array.Sum(c => c), NativeExportsNE.Arrays.SumChars(array, array.Length)); } + [Fact] + public void CharArray_ByValue_This() + { + char[] array = CharacterTests.CharacterMappings().Select(o => (char)o[0]).ToArray(); + Assert.Equal(array.Sum(c => c), array.SumChars(array.Length)); + } + [Fact] public void CharArray_Ref() { @@ -219,6 +314,22 @@ public unsafe void PointerArray_ByValue() } } + [Fact] + public unsafe void PointerArray_ByValue_This() + { + int[] array = GetIntArray(); + fixed (int* arrayPointer = array) + { + int*[] pointerArray = new int*[array.Length]; + for (int i = 0; i < array.Length; i++) + { + pointerArray[i] = &arrayPointer[i]; + } + + Assert.Equal(array.Sum(), pointerArray.Sum(pointerArray.Length)); + } + } + [Fact] public unsafe void PointerArray_In() { @@ -281,6 +392,28 @@ public unsafe void PointerArray_Return() } } + [Fact] + public unsafe void PointerArray_Return_This() + { + int[] array = GetIntArray(); + fixed (int* arrayPointer = array) + { + int*[] pointerArray = new int*[array.Length]; + for (int i = 0; i < array.Length; i++) + { + pointerArray[i] = &arrayPointer[i]; + } + + int*[] res = pointerArray.ReturnDuplicate(pointerArray.Length); + Assert.Equal(pointerArray.Length, res.Length); + for (int i = 0; i < pointerArray.Length; i++) + { + Assert.Equal((IntPtr)pointerArray[i], (IntPtr)res[i]); + Assert.Equal(*pointerArray[i], *res[i]); + } + } + } + private static string[] GetStringArray() { return new[] @@ -301,12 +434,26 @@ public void ArrayWithElementMarshalling_ByValue() Assert.Equal(strings.Sum(str => str?.Length ?? 0), NativeExportsNE.Arrays.SumStringLengths(strings)); } + [Fact] + public void ArrayWithElementMarshalling_ByValue_This() + { + var strings = GetStringArray(); + Assert.Equal(strings.Sum(str => str?.Length ?? 0), strings.SumStringLengths()); + } + [Fact] public void NullArrayWithElementMarshalling_ByValue() { Assert.Equal(0, NativeExportsNE.Arrays.SumStringLengths(null)); } + [Fact] + public void NullArrayWithElementMarshalling_ByValue_This() + { + string[] strings = null; + Assert.Equal(0, strings.SumStringLengths()); + } + [Fact] public void ArrayWithElementMarshalling_Ref() { @@ -329,6 +476,18 @@ public void ArrayWithElementMarshalling_Return() Assert.Equal(expectedStrings, res); } + [Fact] + public void ArrayWithElementMarshalling_Return_This() + { + var strings = GetStringArray(); + var expectedStrings = strings.Select(s => ReverseChars(s)).ToArray(); + Assert.Equal(expectedStrings, strings.ReverseStrings_Return(out _)); + + string[] res; + strings.ReverseStrings_Out(out _, out res); + Assert.Equal(expectedStrings, res); + } + [Fact] public void NullArrayWithElementMarshalling_Ref() { @@ -349,6 +508,17 @@ public void NullArrayWithElementMarshalling_Return() Assert.Null(res); } + [Fact] + public void NullArrayWithElementMarshalling_Return_This() + { + string[] strings = null; + Assert.Null(strings.ReverseStrings_Return(out _)); + + string[] res; + strings.ReverseStrings_Out(out _, out res); + Assert.Null(res); + } + [Fact] public void ConstantSizeArray() { @@ -357,6 +527,14 @@ public void ConstantSizeArray() Assert.Equal(longVal, MemoryMarshal.Read(NativeExportsNE.Arrays.GetLongBytes(longVal))); } + [Fact] + public void ConstantSizeArray_This() + { + var longVal = 0x12345678ABCDEF10L; + + Assert.Equal(longVal, MemoryMarshal.Read(longVal.GetLongBytes())); + } + [Fact] public void DynamicSizedArrayWithConstantComponent() { @@ -400,6 +578,39 @@ public void Array_ByValueOut() } } + [Fact] + public void Array_ByValueOut_This() + { + { + var testArray = new IntStructWrapper[10]; + int start = 5; + + testArray.FillRangeArray(testArray.Length, start); + Assert.Equal(Enumerable.Range(start, testArray.Length), testArray.Select(wrapper => wrapper.Value)); + + // Any items not populated by the invoke target should be initialized to default + testArray = new IntStructWrapper[10]; + int lengthToFill = testArray.Length / 2; + testArray.FillRangeArray(lengthToFill, start); + Assert.Equal(Enumerable.Range(start, lengthToFill), testArray[..lengthToFill].Select(wrapper => wrapper.Value)); + Assert.All(testArray[lengthToFill..], wrapper => Assert.Equal(0, wrapper.Value)); + } + { + var testArray = new char[10]; + ushort start = 65; + + testArray.FillChars(testArray.Length, start); + Assert.Equal(Enumerable.Range(start, testArray.Length), testArray.Select(c => (int)c)); + + // Any items not populated by the invoke target should be initialized to default + testArray = new char[10]; + int lengthToFill = testArray.Length / 2; + testArray.FillChars(lengthToFill, start); + Assert.Equal(Enumerable.Range(start, lengthToFill), testArray[..lengthToFill].Select(c => (int)c)); + Assert.All(testArray[lengthToFill..], c => Assert.Equal(0, c)); + } + } + [Fact] public void Array_ByValueInOut() { @@ -412,6 +623,18 @@ public void Array_ByValueInOut() Assert.Equal(testValues.Select(wrapper => wrapper.Value * 2), testArray.Select(wrapper => wrapper.Value)); } + [Fact] + public void Array_ByValueInOut_This() + { + var testValues = Enumerable.Range(42, 15).Select(i => new IntStructWrapper { Value = i }); + + var testArray = testValues.ToArray(); + + testArray.DoubleValues(testArray.Length); + + Assert.Equal(testValues.Select(wrapper => wrapper.Value * 2), testArray.Select(wrapper => wrapper.Value)); + } + [Theory] [InlineData(true)] [InlineData(false)] @@ -421,6 +644,15 @@ public void NonBlittableElementArray_ByValue(bool result) Assert.Equal(result, NativeExportsNE.Arrays.AndAllMembers(array, array.Length)); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public void NonBlittableElementArray_ByValue_This(bool result) + { + BoolStruct[] array = GetBoolStructsToAnd(result); + Assert.Equal(result, array.AndAllMembers(array.Length)); + } + [Theory] [InlineData(true)] [InlineData(false)] @@ -451,6 +683,17 @@ public void NonBlittableElementArray_Out() Assert.Equal(expected, result); } + [Fact] + public void NonBlittableElementArray_Out_This() + { + BoolStruct[] array = GetBoolStructsToNegate(); + BoolStruct[] expected = GetNegatedBoolStructs(array); + + BoolStruct[] result; + array.NegateBools(array.Length, out result); + Assert.Equal(expected, result); + } + [Fact] public void NonBlittableElementArray_Return() { @@ -461,6 +704,16 @@ public void NonBlittableElementArray_Return() Assert.Equal(expected, result); } + [Fact] + public void NonBlittableElementArray_Return_This() + { + BoolStruct[] array = GetBoolStructsToNegate(); + BoolStruct[] expected = GetNegatedBoolStructs(array); + + BoolStruct[] result = array.NegateBools(array.Length); + Assert.Equal(expected, result); + } + private static BoolStruct[] GetBoolStructsToAnd(bool result) => new BoolStruct[] { new BoolStruct @@ -544,6 +797,36 @@ public void ArraysOfArrays() } } + [Fact] + public void ArraysOfArrays_This() + { + var random = new Random(42); + int numRows = random.Next(1, 5); + int numColumns = random.Next(1, 5); + int[][] matrix = new int[numRows][]; + for (int i = 0; i < numRows; i++) + { + matrix[i] = new int[numColumns]; + for (int j = 0; j < numColumns; j++) + { + matrix[i][j] = random.Next(); + } + } + + int[] numRowsArray = new int[numColumns]; + numRowsArray.AsSpan().Fill(numRows); + + int[][] transposed = matrix.TransposeMatrix(numRowsArray, numColumns); + + for (int i = 0; i < numRows; i++) + { + for (int j = 0; j < numColumns; j++) + { + Assert.Equal(matrix[i][j], transposed[j][i]); + } + } + } + private static string ReverseChars(string value) { if (value == null) diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/BlittableStructTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/BlittableStructTests.cs index f02379ddcabf2..c3a6ea8588837 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/BlittableStructTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/BlittableStructTests.cs @@ -41,6 +41,16 @@ public static partial void IncrementInvertPointerFieldsRefReturn( PointerFields input, ref PointerFields result); } + public static partial class IntStructExtensions + { + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "blittablestructs_return_instance")] + public static partial IntFields DoubleIntFields(this IntFields result); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "blittablestructs_double_intfields_refreturn")] + public static partial void DoubleIntFieldsOutReturn( + this IntFields input, + out IntFields result); + } public class BlittableStructTests { @@ -67,6 +77,11 @@ public void ValidateIntFields() Assert.Equal(initial, input); Assert.Equal(expected, result); } + { + var result = input.DoubleIntFields(); + Assert.Equal(initial, input); + Assert.Equal(expected, result); + } { var result = new IntFields(); NativeExportsNE.DoubleIntFieldsRefReturn(input, ref result); @@ -80,6 +95,12 @@ public void ValidateIntFields() Assert.Equal(initial, input); Assert.Equal(expected, result); } + { + IntFields result; + input.DoubleIntFieldsOutReturn(out result); + Assert.Equal(initial, input); + Assert.Equal(expected, result); + } { input = initial; diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs index 32a0d9a414e19..6204b100ce325 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs @@ -532,6 +532,15 @@ partial class Test } """; + public static string ExplicitThis => $$""" + using System.Runtime.InteropServices; + static partial class StringNativeExtensions + { + [LibraryImport("DoesNotExist")] + public static partial void Method(this int t); + } + """; + public static string BasicParametersAndModifiers(string preDeclaration = "") => BasicParametersAndModifiers(typeof(T).ToString(), preDeclaration); /// diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs index 4c80ac61cc460..43fdaf4930037 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs @@ -42,6 +42,7 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { ID(), CodeSnippets.DefaultParameters }; yield return new[] { ID(), CodeSnippets.UseCSharpFeaturesForConstants }; yield return new[] { ID(), CodeSnippets.LibraryImportInRefStruct }; + yield return new[] { ID(), CodeSnippets.ExplicitThis }; // Parameter / return types yield return new[] { ID(), CodeSnippets.BasicParametersAndModifiers() }; @@ -719,7 +720,7 @@ public class Basic { } [Theory] [MemberData(nameof(CodeSnippetsToVerifyNoTreesProduced))] - public async Task ValidateNoGeneratedOuptutForNoImport(string id, string source, TestTargetFramework framework) + public async Task ValidateNoGeneratedOutputForNoImport(string id, string source, TestTargetFramework framework) { TestUtils.Use(id); var test = new NoChangeTest(framework)