diff --git a/sources/LLVMSharp.Interop/Extensions/LLVMValueRef.cs b/sources/LLVMSharp.Interop/Extensions/LLVMValueRef.cs index 87db789..0c7af37 100644 --- a/sources/LLVMSharp.Interop/Extensions/LLVMValueRef.cs +++ b/sources/LLVMSharp.Interop/Extensions/LLVMValueRef.cs @@ -1161,4 +1161,8 @@ public void SetAlignment(uint Bytes) public readonly void ViewFunctionCFG() => LLVM.ViewFunctionCFG(this); public readonly void ViewFunctionCFGOnly() => LLVM.ViewFunctionCFGOnly(this); + + public readonly LLVMTypeRef GetFunctionType() => LLVM.TypeOf(this); + + public readonly LLVMTypeRef GetReturnType() => LLVM.GetReturnType(LLVM.TypeOf(this)); } diff --git a/tests/LLVMSharp.UnitTests/Functions.cs b/tests/LLVMSharp.UnitTests/Functions.cs index 4235b81..2de1819 100644 --- a/tests/LLVMSharp.UnitTests/Functions.cs +++ b/tests/LLVMSharp.UnitTests/Functions.cs @@ -28,4 +28,36 @@ public void AddsAttributeAtIndex() var attrs = functionValue.GetAttributesAtIndex((LLVMAttributeIndex)1); Assert.That((AttributeKind)attrs[0].KindAsEnum, Is.EqualTo(AttributeKind.ByVal)); } + + [Test] + public void LibLLVMSharp_GetFunctionType() + { + var module = LLVMModuleRef.CreateWithName("Test Module"); + var returnType = LLVMTypeRef.Int32; + var paramTypes = new[] { LLVMTypeRef.Int32, LLVMTypeRef.Int32 }; + var functionType = LLVMTypeRef.CreateFunction(returnType, paramTypes); + var functionValue = module.AddFunction("add", functionType); + + var retrievedFunctionType = functionValue.GetFunctionType(); + + Assert.That(retrievedFunctionType.Kind, Is.EqualTo(LLVMTypeKind.LLVMFunctionTypeKind)); + + Assert.That(retrievedFunctionType.GetReturnType().Kind, Is.EqualTo(LLVMTypeKind.LLVMIntegerTypeKind)); + + Assert.That(retrievedFunctionType.ParamTypesCount, Is.EqualTo(2)); + } + + [Test] + public void LibLLVMSharp_GetReturnType() + { + var module = LLVMModuleRef.CreateWithName("Test Module"); + var returnType = LLVMTypeRef.Int32; + var functionType = LLVMTypeRef.CreateFunction(returnType, [LLVMTypeRef.Int32, LLVMTypeRef.Int32]); + var functionValue = module.AddFunction("add", functionType); + + var retrievedReturnType = functionValue.GetReturnType(); + + Assert.That(retrievedReturnType.Kind, Is.EqualTo(LLVMTypeKind.LLVMIntegerTypeKind)); + Assert.That(retrievedReturnType.Handle, Is.EqualTo(returnType.Handle)); + } }