From d45e2da42722720ac16e547566a020d545e25994 Mon Sep 17 00:00:00 2001 From: Abhinav Tripathi Date: Sun, 2 Aug 2015 00:25:38 +0530 Subject: [PATCH] Added marshalling of primitive ptrs to ref type. Signed-off-by: Dimitar Dobrev --- src/Generator/Driver.cs | 1 + .../Generators/CSharp/CSharpTextTemplate.cs | 57 ++++++++++++++++--- src/Generator/Generators/ExtensionMethods.cs | 25 +++++++- .../Passes/HandleDefaultParamValuesPass.cs | 12 +++- .../MarshalPrimitivePointersAsRefTypePass.cs | 17 ++++++ tests/CSharpTemp/CSharpTemp.Tests.cs | 36 +++++++++++- tests/CSharpTemp/CSharpTemp.cpp | 29 ++++++++++ tests/CSharpTemp/CSharpTemp.h | 13 +++++ 8 files changed, 176 insertions(+), 14 deletions(-) create mode 100644 src/Generator/Passes/MarshalPrimitivePointersAsRefTypePass.cs diff --git a/src/Generator/Driver.cs b/src/Generator/Driver.cs index b820159812..2e41ea5f0e 100644 --- a/src/Generator/Driver.cs +++ b/src/Generator/Driver.cs @@ -242,6 +242,7 @@ public void SetupPasses(ILibrary library) TranslationUnitPasses.AddPass(new SortDeclarationsPass()); TranslationUnitPasses.AddPass(new ResolveIncompleteDeclsPass()); TranslationUnitPasses.AddPass(new CheckIgnoredDeclsPass()); + TranslationUnitPasses.AddPass(new MarshalPrimitivePointersAsRefTypePass()); if (Options.IsCSharpGenerator && Options.GenerateInlines) TranslationUnitPasses.AddPass(new GenerateInlinesCodePass()); diff --git a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs index 7e47bba4b3..afdfda106c 100644 --- a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs +++ b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs @@ -2146,16 +2146,39 @@ public void GenerateMethod(Method method, Class @class) PopBlock(NewLineKind.BeforeNextBlock); } + private static string OverloadParamNameWithDefValue(Parameter p, ref int index) + { + return (p.Type.IsPointerToPrimitiveType() && p.Usage == ParameterUsage.InOut && p.HasDefaultValue) + ? "ref param" + index++ + : p.DefaultArgument.String; + } + private void GenerateOverloadCall(Function function) { - Type type = function.OriginalReturnType.Type; + for (int i = 0, j = 0; i < function.Parameters.Count; i++) + { + var parameter = function.Parameters[i]; + PrimitiveType primitiveType; + if (parameter.Kind == ParameterKind.Regular && parameter.Ignore && + parameter.Type.IsPointerToPrimitiveType(out primitiveType) && + parameter.Usage == ParameterUsage.InOut && parameter.HasDefaultValue) + { + var pointeeType = ((PointerType) parameter.Type).Pointee.ToString(); + WriteLine("{0} param{1} = {2};", pointeeType, j++, + primitiveType == PrimitiveType.Bool ? "false" : "0"); + } + } + + var index = 0; + var type = function.OriginalReturnType.Type; WriteLine("{0}{1}({2});", type.IsPrimitiveType(PrimitiveType.Void) ? string.Empty : "return ", function.Name, string.Join(", ", function.Parameters.Where( p => p.Kind == ParameterKind.Regular).Select( - p => p.Ignore ? p.DefaultArgument.String : p.Name))); + p => p.Ignore ? OverloadParamNameWithDefValue(p, ref index) : + (p.Usage == ParameterUsage.InOut ? "ref " : string.Empty) + p.Name))); } private void GenerateEquals(Function method, Class @class) @@ -2498,6 +2521,10 @@ into context if (needsFixedThis && operatorParam == null) WriteCloseBraceIndent(); + + var numFixedBlocks = @params.Count(param => param.HasFixedBlock); + for(var i = 0; i < numFixedBlocks; ++i) + WriteCloseBraceIndent(); } private int GetInstanceParamIndex(Method method) @@ -2519,6 +2546,8 @@ private int GetInstanceParamIndex(Method method) { var param = paramInfo.Param; if (!(param.IsOut || param.IsInOut)) continue; + if (param.IsPrimitiveParameterConvertibleToRef()) + continue; var nativeVarName = paramInfo.Name; @@ -2548,6 +2577,7 @@ public struct ParamMarshal public string Name; public Parameter Param; public CSharpMarshalContext Context; + public bool HasFixedBlock; } public List GenerateFunctionParamsMarshal(IEnumerable @params, @@ -2601,16 +2631,25 @@ public struct ParamMarshal paramMarshal.Context = ctx; - var marshal = new CSharpMarshalManagedToNativePrinter(ctx); - param.CSharpMarshalToNative(marshal); + if (param.IsPrimitiveParameterConvertibleToRef()) + { + WriteLine("fixed({0} {1} = &{2})", param.Type.CSharpType(TypePrinter), argName, param.Name); + paramMarshal.HasFixedBlock = true; + WriteStartBraceIndent(); + } + else + { + var marshal = new CSharpMarshalManagedToNativePrinter(ctx); + param.CSharpMarshalToNative(marshal); - if (string.IsNullOrEmpty(marshal.Context.Return)) - throw new Exception("Cannot marshal argument of function"); + if (string.IsNullOrEmpty(marshal.Context.Return)) + throw new Exception("Cannot marshal argument of function"); - if (!string.IsNullOrWhiteSpace(marshal.Context.SupportBefore)) - Write(marshal.Context.SupportBefore); + if (!string.IsNullOrWhiteSpace(marshal.Context.SupportBefore)) + Write(marshal.Context.SupportBefore); - WriteLine("var {0} = {1};", argName, marshal.Context.Return); + WriteLine("var {0} = {1};", argName, marshal.Context.Return); + } return paramMarshal; } diff --git a/src/Generator/Generators/ExtensionMethods.cs b/src/Generator/Generators/ExtensionMethods.cs index 9edcda4262..433d38d779 100644 --- a/src/Generator/Generators/ExtensionMethods.cs +++ b/src/Generator/Generators/ExtensionMethods.cs @@ -1,4 +1,7 @@ -using CppSharp.AST; +using System.Linq; +using System.Collections.Generic; +using CppSharp.AST; +using CppSharp.AST.Extensions; using Interop = System.Runtime.InteropServices; namespace CppSharp.Generators @@ -23,5 +26,25 @@ public static Interop.CallingConvention ToInteropCallConv(this CallingConvention return Interop.CallingConvention.Winapi; } + + public static bool IsPrimitiveParameterConvertibleToRef(this Parameter param) + { + var allowedToHaveDefaultPtrVals = new List + { + PrimitiveType.Bool, + PrimitiveType.Double, + PrimitiveType.Float, + PrimitiveType.Int, + PrimitiveType.Long, + PrimitiveType.LongLong, + PrimitiveType.Short, + PrimitiveType.UInt, + PrimitiveType.ULong, + PrimitiveType.ULongLong, + PrimitiveType.UShort + }; + return param.Type.IsPointerToPrimitiveType() + && allowedToHaveDefaultPtrVals.Any(primType => param.Type.IsPointerToPrimitiveType(primType)); + } } } diff --git a/src/Generator/Passes/HandleDefaultParamValuesPass.cs b/src/Generator/Passes/HandleDefaultParamValuesPass.cs index 83eb02b6fd..12d934bcc3 100644 --- a/src/Generator/Passes/HandleDefaultParamValuesPass.cs +++ b/src/Generator/Passes/HandleDefaultParamValuesPass.cs @@ -1,8 +1,10 @@ using System.Collections.Generic; using System.Linq; using System.Text.RegularExpressions; + using CppSharp.AST; using CppSharp.AST.Extensions; +using CppSharp.Generators; using CppSharp.Generators.CSharp; using CppSharp.Types; @@ -52,7 +54,8 @@ public override bool VisitFunctionDecl(Function function) parameter.QualifiedType.Qualifiers); if (defaultConstruct == null || (!Driver.Options.MarshalCharAsManagedChar && - parameter.Type.Desugar().IsPrimitiveType(PrimitiveType.UChar))) + parameter.Type.Desugar().IsPrimitiveType(PrimitiveType.UChar)) || + parameter.IsPrimitiveParameterConvertibleToRef()) { overloadIndices.Add(function.Parameters.IndexOf(parameter)); continue; @@ -103,6 +106,8 @@ private bool CheckForDefaultPointer(Type desugared, Parameter parameter) parameter.DefaultArgument.String = "new global::System.IntPtr()"; return true; } + if (parameter.IsPrimitiveParameterConvertibleToRef()) + return false; Class @class; if (desugared.GetFinalPointee().TryGetClass(out @class) && @class.IsValueType) { @@ -243,9 +248,10 @@ private void GenerateOverloads(Function function, List overloadIndices) Function overload = method != null ? new Method(method) : new Function(function); overload.OriginalFunction = function; overload.SynthKind = FunctionSynthKind.DefaultValueOverload; - overload.Parameters[overloadIndex].GenerationKind = GenerationKind.None; + for (int i = overloadIndex; i < function.Parameters.Count; ++i) + overload.Parameters[i].GenerationKind = GenerationKind.None; - var indices = overloadIndices.Where(i => i != overloadIndex).ToList(); + var indices = overloadIndices.Where(i => i < overloadIndex).ToList(); if (indices.Any()) for (int i = 0; i <= indices.Last(); i++) if (i != overloadIndex) diff --git a/src/Generator/Passes/MarshalPrimitivePointersAsRefTypePass.cs b/src/Generator/Passes/MarshalPrimitivePointersAsRefTypePass.cs new file mode 100644 index 0000000000..f0e80d9133 --- /dev/null +++ b/src/Generator/Passes/MarshalPrimitivePointersAsRefTypePass.cs @@ -0,0 +1,17 @@ +using System.Linq; +using CppSharp.AST; +using CppSharp.Generators; + +namespace CppSharp.Passes +{ + public class MarshalPrimitivePointersAsRefTypePass : TranslationUnitPass + { + public override bool VisitFunctionDecl(Function function) + { + foreach (var param in function.Parameters.Where( + p => !p.IsOut && p.IsPrimitiveParameterConvertibleToRef())) + param.Usage = ParameterUsage.InOut; + return base.VisitFunctionDecl(function); + } + } +} \ No newline at end of file diff --git a/tests/CSharpTemp/CSharpTemp.Tests.cs b/tests/CSharpTemp/CSharpTemp.Tests.cs index fe3d23686e..481f61f494 100644 --- a/tests/CSharpTemp/CSharpTemp.Tests.cs +++ b/tests/CSharpTemp/CSharpTemp.Tests.cs @@ -345,4 +345,38 @@ public void TestNullAttributedFunctionPtr() foo.AttributedFunctionPtr = null; } } -} \ No newline at end of file + + [Test] + public unsafe void TestMultiOverLoadPtrToRef() + { + var obj = new MultiOverloadPtrToRef(); + var p = obj.ReturnPrimTypePtr(); + Assert.AreEqual(0, p[0]); + Assert.AreEqual(0, p[1]); + Assert.AreEqual(0, p[2]); + + obj.TakePrimTypePtr(ref *p); + Assert.AreEqual(100, p[0]); + Assert.AreEqual(200, p[1]); + Assert.AreEqual(300, p[2]); + + int[] array = { 1, 2, 3 }; + fixed (int* p1 = array) + { + obj.TakePrimTypePtr(ref *p1); + Assert.AreEqual(100, p1[0]); + Assert.AreEqual(200, p1[1]); + Assert.AreEqual(300, p1[2]); + } + + Assert.AreEqual(100, array[0]); + Assert.AreEqual(200, array[1]); + Assert.AreEqual(300, array[2]); + + float pThree = 0; + var refInt = 0; + obj.FuncPrimitivePtrToRef(ref refInt, null, ref pThree); + obj.FuncPrimitivePtrToRefWithDefVal(ref refInt, null, null, ref refInt); + obj.FuncPrimitivePtrToRefWithMultiOverload(ref refInt, null, null, ref refInt); + } +} diff --git a/tests/CSharpTemp/CSharpTemp.cpp b/tests/CSharpTemp/CSharpTemp.cpp index ace57f3254..adb8ec63f9 100644 --- a/tests/CSharpTemp/CSharpTemp.cpp +++ b/tests/CSharpTemp/CSharpTemp.cpp @@ -688,3 +688,32 @@ HasOverrideOfHasPropertyWithDerivedType::HasOverrideOfHasPropertyWithDerivedType void HasOverrideOfHasPropertyWithDerivedType::causeRenamingError() { } + +void MultiOverloadPtrToRef::funcPrimitivePtrToRef(int* pOne, char* pTwo, float* pThree, bool* pFour) +{ +} + +void MultiOverloadPtrToRef::funcPrimitivePtrToRefWithDefVal(int* pOne, char* pTwo, Foo* pThree, int* pFour) +{ +} + +void MultiOverloadPtrToRef::funcPrimitivePtrToRefWithMultiOverload(int* pOne, char* pTwo, Foo* pThree, int* pFour, long* pFive) +{ +} + +MultiOverloadPtrToRef::MultiOverloadPtrToRef() +{ + arr = new int[3]{0}; +} + +int* MultiOverloadPtrToRef::ReturnPrimTypePtr() +{ + return arr; +} + +void MultiOverloadPtrToRef::TakePrimTypePtr(int* ptr) +{ + ptr[0] = 100; + ptr[1] = 200; + ptr[2] = 300; +} diff --git a/tests/CSharpTemp/CSharpTemp.h b/tests/CSharpTemp/CSharpTemp.h index d4a74db912..bc6f1a0b5b 100644 --- a/tests/CSharpTemp/CSharpTemp.h +++ b/tests/CSharpTemp/CSharpTemp.h @@ -617,3 +617,16 @@ class DLL_API HasOverrideOfHasPropertyWithDerivedType : public HasPropertyWithDe HasOverrideOfHasPropertyWithDerivedType(); virtual void causeRenamingError(); }; + +class DLL_API MultiOverloadPtrToRef +{ + int * arr; +public: + void funcPrimitivePtrToRef(int *pOne, char* pTwo, float* pThree, bool* pFour = 0); + void funcPrimitivePtrToRefWithDefVal(int* pOne, char* pTwo, Foo* pThree, int* pFour = 0); + void funcPrimitivePtrToRefWithMultiOverload(int* pOne, char* pTwo, Foo* pThree, int* pFour = 0, long* pFive = 0); + + MultiOverloadPtrToRef(); + int* ReturnPrimTypePtr(); + void TakePrimTypePtr(int* ptr); +};