diff --git a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp index e13d769a16aad9..18bf67f7fc8a14 100644 --- a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp +++ b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp @@ -789,9 +789,17 @@ int32_t __tgt_rtl_init_device(int device_id) { DP("Default number of teams set according to environment %d\n", DeviceInfo.EnvNumTeams); } else { - DeviceInfo.NumTeams[device_id] = RTLDeviceInfoTy::DefaultNumTeams; - DP("Default number of teams set according to library's default %d\n", - RTLDeviceInfoTy::DefaultNumTeams); + char *TeamsPerCUEnvStr = getenv("OMP_TARGET_TEAMS_PER_PROC"); + int TeamsPerCU = 1; // default number of teams per CU is 1 + if (TeamsPerCUEnvStr) { + TeamsPerCU = std::stoi(TeamsPerCUEnvStr); + } + + DeviceInfo.NumTeams[device_id] = + TeamsPerCU * DeviceInfo.ComputeUnits[device_id]; + DP("Default number of teams = %d * number of compute units %d\n", + TeamsPerCU, + DeviceInfo.ComputeUnits[device_id]); } if (DeviceInfo.NumTeams[device_id] > DeviceInfo.GroupsPerDevice[device_id]) { @@ -1548,11 +1556,12 @@ int32_t __tgt_rtl_data_delete(int device_id, void *tgt_ptr) { // loop_tripcount. void getLaunchVals(int &threadsPerGroup, int &num_groups, int ConstWGSize, int ExecutionMode, int EnvTeamLimit, int EnvNumTeams, - int num_teams, int thread_limit, uint64_t loop_tripcount) { + int num_teams, int thread_limit, uint64_t loop_tripcount, + int32_t device_id) { int Max_Teams = DeviceInfo.EnvMaxTeamsDefault > 0 ? DeviceInfo.EnvMaxTeamsDefault - : DeviceInfo.Max_Teams; + : DeviceInfo.NumTeams[device_id]; if (Max_Teams > DeviceInfo.HardTeamLimit) Max_Teams = DeviceInfo.HardTeamLimit; @@ -1752,7 +1761,8 @@ int32_t __tgt_rtl_run_target_team_region_locked( DeviceInfo.EnvNumTeams, num_teams, // From run_region arg thread_limit, // From run_region arg - loop_tripcount // From run_region arg + loop_tripcount, // From run_region arg + KernelInfo->device_id ); if (print_kernel_trace == 4)