Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] Consider bind(c) when lowering calls to intrinsic module procedures #70386

Merged
merged 4 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 7 additions & 3 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2122,7 +2122,12 @@ genProcedureRef(CallContext &callContext) {
mlir::Location loc = callContext.loc;
if (auto *intrinsic = callContext.procRef.proc().GetSpecificIntrinsic())
return genIntrinsicRef(intrinsic, callContext);
if (Fortran::lower::isIntrinsicModuleProcRef(callContext.procRef))
// If it is an intrinsic module procedure reference - then treat as
// intrinsic unless it is bind(c) (since implementation is external from
// module).
if (Fortran::lower::isIntrinsicModuleProcRef(callContext.procRef) &&
!Fortran::semantics::IsBindCProcedure(
*callContext.procRef.proc().GetSymbol()))
return genIntrinsicRef(nullptr, callContext);

if (callContext.isStatementFunctionCall())
Expand Down Expand Up @@ -2227,8 +2232,7 @@ bool Fortran::lower::isIntrinsicModuleProcRef(
return false;
const Fortran::semantics::Symbol *module =
symbol->GetUltimate().owner().GetSymbol();
return module && module->attrs().test(Fortran::semantics::Attr::INTRINSIC) &&
module->name().ToString().find("omp_lib") == std::string::npos;
return module && module->attrs().test(Fortran::semantics::Attr::INTRINSIC);
}

std::optional<hlfir::EntityWithAttributes> Fortran::lower::convertCallToHLFIR(
Expand Down
16 changes: 8 additions & 8 deletions flang/module/omp_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
omp_atv_blocked = 17, &
omp_atv_interleaved = 18

type :: omp_alloctrait
type, bind(c) :: omp_alloctrait
integer(kind=omp_alloctrait_key_kind) :: key, value
end type omp_alloctrait

Expand Down Expand Up @@ -264,23 +264,23 @@
integer(kind=omp_integer_kind), intent(out) :: place_nums(*)
end subroutine omp_get_partition_place_nums

subroutine omp_set_affinity_format(format)
subroutine omp_set_affinity_format(format) bind(c)
import
character(len=*), intent(in) :: format
end subroutine omp_set_affinity_format

function omp_get_affinity_format(buffer)
function omp_get_affinity_format(buffer) bind(c)
import
character(len=*), intent(out) :: buffer
integer(kind=omp_integer_kind) :: omp_get_affinity_format
end function omp_get_affinity_format

subroutine omp_display_affinity(format)
subroutine omp_display_affinity(format) bind(c)
import
character(len=*), intent(in) :: format
end subroutine omp_display_affinity

function omp_capture_affinity(buffer, format)
function omp_capture_affinity(buffer, format) bind(c)
import
character(len=*), intent(out) :: buffer
character(len=*), intent(in) :: format
Expand Down Expand Up @@ -339,7 +339,7 @@
integer(kind=omp_integer_kind) :: omp_pause_resource
end function omp_pause_resource

function omp_pause_resource_all(kind)
function omp_pause_resource_all(kind) bind(c)
import
integer(kind=omp_pause_resource_kind), value :: kind
integer(kind=omp_integer_kind) :: omp_pause_resource_all
Expand Down Expand Up @@ -428,7 +428,7 @@
! Device Memory Routines

! Memory Management Routines
function omp_init_allocator(memspace, ntraits, traits)
function omp_init_allocator(memspace, ntraits, traits) bind(c)
import
integer(kind=omp_memspace_handle_kind), value :: memspace
integer, value :: ntraits
Expand All @@ -446,7 +446,7 @@
integer(kind=omp_allocator_handle_kind), value :: allocator
end subroutine omp_set_default_allocator

function omp_get_default_allocator()
function omp_get_default_allocator() bind(c)
import
integer(kind=omp_allocator_handle_kind) :: omp_get_default_allocator
end function omp_get_default_allocator
Expand Down
21 changes: 21 additions & 0 deletions flang/test/Lower/OpenMP/omp-lib-num-threads.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - 2>&1 | FileCheck %s
! RUN: bbc -fopenmp -emit-hlfir -o - %s 2>&1 | FileCheck %s
!
! Test that the calls to omp_lib's omp_get_num_threads and omp_set_num_threads
! get lowered even though their implementation is not in the omp_lib module
! (and this matters because this is an intrinsic module - and calls to
! intrinsics are specially resolved).

program main
use omp_lib
integer(omp_integer_kind) :: num_threads
integer(omp_integer_kind), parameter :: requested_num_threads = 4
call omp_set_num_threads(requested_num_threads)
num_threads = omp_get_num_threads()
print *, num_threads
end program

!CHECK-NOT: not yet implemented: intrinsic: omp_set_num_threads
!CHECK-NOT: not yet implemented: intrinsic: omp_get_num_threads
!CHECK: fir.call @omp_set_num_threads
!CHECK: fir.call @omp_get_num_threads