diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 6a515c2ba4e87..1daf60b8659bb 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -32,6 +32,7 @@ #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/ModuleUtils.h" +#include #include #include @@ -2407,6 +2408,23 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, kernelInput.push_back(mapData.OriginalValue[i]); } + // Do some very basic handling of implicit captures that are caught + // by use in the target region. + // TODO/FIXME: Remove on addition of IsolatedFromAbove patch series + // as this will become redundant and perhaps erroneous in cases + // where more complex implicit capture semantics are required. + llvm::SetVector uses; + getUsedValuesDefinedAbove(targetRegion, uses); + + for (mlir::Value use : uses) { + llvm::Value *useValue = moduleTranslation.lookupValue(use); + if (useValue && + !std::any_of( + mapData.OriginalValue.begin(), mapData.OriginalValue.end(), + [&](llvm::Value *mapValue) { return mapValue == useValue; })) + kernelInput.push_back(useValue); + } + builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget( ompLoc, allocaIP, builder.saveIP(), entryInfo, defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB, argAccessorCB)); diff --git a/openmp/libomptarget/test/offloading/fortran/basic-target-region-1D-array-section.f90 b/openmp/libomptarget/test/offloading/fortran/basic-target-region-1D-array-section.f90 index 11d3b6936bcea..58f5379e330ec 100644 --- a/openmp/libomptarget/test/offloading/fortran/basic-target-region-1D-array-section.f90 +++ b/openmp/libomptarget/test/offloading/fortran/basic-target-region-1D-array-section.f90 @@ -14,10 +14,11 @@ program main integer :: write_arr(10) = (/0,0,0,0,0,0,0,0,0,0/) integer :: read_arr(10) = (/1,2,3,4,5,6,7,8,9,10/) integer :: i = 2 - - !$omp target map(to:read_arr(2:5)) map(from:write_arr(2:5)) map(tofrom:i) - do i = 2, 5 + integer :: j = 5 + !$omp target map(to:read_arr(2:5)) map(from:write_arr(2:5)) map(to:i,j) + do while (i <= j) write_arr(i) = read_arr(i) + i = i + 1 end do !$omp end target diff --git a/openmp/libomptarget/test/offloading/fortran/basic-target-region-3D-array-section.f90 b/openmp/libomptarget/test/offloading/fortran/basic-target-region-3D-array-section.f90 index 28b2afced4d1b..e3df7983e6b5c 100644 --- a/openmp/libomptarget/test/offloading/fortran/basic-target-region-3D-array-section.f90 +++ b/openmp/libomptarget/test/offloading/fortran/basic-target-region-3D-array-section.f90 @@ -14,6 +14,7 @@ program main integer :: inArray(3,3,3) integer :: outArray(3,3,3) integer :: i, j, k + integer :: j2 = 3, k2 = 3 do i = 1, 3 do j = 1, 3 @@ -24,11 +25,16 @@ program main end do end do -!$omp target map(tofrom:inArray(1:3, 1:3, 2:2), outArray(1:3, 1:3, 1:3), j, k) - do j = 1, 3 - do k = 1, 3 +j = 1 +k = 1 +!$omp target map(tofrom:inArray(1:3, 1:3, 2:2), outArray(1:3, 1:3, 1:3), j, k, j2, k2) + do while (j <= j2) + k = 1 + do while (k <= k2) outArray(k, j, 2) = inArray(k, j, 2) + k = k + 1 end do + j = j + 1 end do !$omp end target diff --git a/openmp/libomptarget/test/offloading/fortran/basic-target-region-3D-array.f90 b/openmp/libomptarget/test/offloading/fortran/basic-target-region-3D-array.f90 index 58f42138ad0af..44ff394dcda16 100644 --- a/openmp/libomptarget/test/offloading/fortran/basic-target-region-3D-array.f90 +++ b/openmp/libomptarget/test/offloading/fortran/basic-target-region-3D-array.f90 @@ -10,9 +10,9 @@ ! RUN: %libomptarget-compile-fortran-run-and-check-generic program main - implicit none integer :: x(2,2,2) - integer :: i = 1, j = 1, k = 1 + integer :: i, j, k + integer :: i2 = 2, j2 = 2, k2 = 2 integer :: counter = 1 do i = 1, 2 do j = 1, 2 @@ -22,14 +22,23 @@ program main end do end do -!$omp target map(tofrom:x, i, j, k, counter) - do i = 1, 2 - do j = 1, 2 - do k = 1, 2 +i = 1 +j = 1 +k = 1 + +!$omp target map(tofrom:x, counter) map(to: i, j, k, i2, j2, k2) + do while (i <= i2) + j = 1 + do while (j <= j2) + k = 1 + do while (k <= k2) x(i, j, k) = counter counter = counter + 1 + k = k + 1 end do + j = j + 1 end do + i = i + 1 end do !$omp end target