Commit

Fix semantics of block reduce and grid reduce; support block reduce in loop (pytorch#712)

* fix reduce semantics; add sync to block in loop

* comment

* format

* clang-tidy

* merge the syncthread into block kernel

* reverted commented test check

* add syncthread at the end of blockbroadcast
shmsong committed Mar 5, 2021
1 parent 91ecbaf commit d2cb5fb
Showing 8 changed files with 141 additions and 41 deletions.
82 changes: 72 additions & 10 deletions test/cpp/jit/test_gpu.cpp
@@ -11178,9 +11178,9 @@ __global__ void kernel1(
__shared__ long mem_N[512];
float in=inp[threadIdx.x*inp.stride[0]+
threadIdx.y*inp.stride[1]];
- float tmp_M2;
- float tmp_avg;
- long tmp_N;
+ float tmp_M2=0;
+ float tmp_avg=0;
+ long tmp_N=0;
blockWelford<false,true,false>(
tmp_M2,
tmp_avg,
@@ -11265,9 +11265,9 @@ __global__ void kernel1(
float in=inp[threadIdx.x*inp.stride[0]+
threadIdx.y*inp.stride[1]+
threadIdx.z*inp.stride[2]];
- float tmp_M2;
- float tmp_avg;
- long tmp_N;
+ float tmp_M2=0;
+ float tmp_avg=0;
+ long tmp_N=0;
blockWelford<false,true,true>(
tmp_M2,
tmp_avg,
@@ -11328,9 +11328,9 @@ __global__ void kernel1(
__shared__ float shared_buf_M2[512];
__shared__ float shared_buf_avg[512];
__shared__ long shared_buf_N[512];
- float tmp_M2;
- float tmp_avg;
- long tmp_N;
+ float tmp_M2=0;
+ float tmp_avg=0;
+ long tmp_N=0;
float in = inp[ blockIdx.x * inp.stride[0]+
blockIdx.y * inp.stride[1]+
threadIdx.x * inp.stride[2]];
@@ -11750,7 +11750,6 @@ TEST(NVFuserTest, FusionWelfordShmoo_CUDA) {
if (rdim > 32768 && dtype == DataType::Half) {
continue;
}

testWelford(dtype, axis, odim, rdim);
}
}
@@ -13381,6 +13380,69 @@ TEST(NVFuserTest, FusionValidateParallelize5_CUDA) {
fe.compileFusion(&fusion);
}

TEST(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int M = 10;
constexpr int N = 20;
constexpr int K = 20;

auto tv0 = makeSymbolicTensor(3);
auto tv1 = sum(tv0, {{1, 2}});
fusion.addInput(tv0);
fusion.addOutput(tv1);

tv1->axis(-1)->parallelize(ParallelType::TIDx);
tv1->axis(0)->parallelize(ParallelType::BIDx);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::manual_seed(0);
at::Tensor t0 = at::randn({M, N, K}, options);
std::vector<IValue> aten_inputs = {t0};

FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion(aten_inputs);
at::Tensor aten_output = t0.sum({1, 2});
testValidate(
&fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}

TEST(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int M = 10;
constexpr int N = 20;
constexpr int K = 20;

auto tv0 = makeSymbolicTensor(3);
auto tvs = Welford(tv0, {{1, 2}});
fusion.addInput(tv0);
auto tv_M2 = tvs.var;
auto tv_avg = tvs.avg;
auto tv_N = tvs.n;
fusion.addOutput(tv_M2);
fusion.addOutput(tv_avg);

tv_avg->axis(-1)->parallelize(ParallelType::TIDx);
tv_avg->axis(0)->parallelize(ParallelType::BIDx);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::manual_seed(0);
at::Tensor t0 = at::randn({M, N, K}, options);
std::vector<IValue> aten_inputs = {t0};

FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion(aten_inputs);
at::Tensor aten_M2 = t0.var({1, 2}, false) * N * K;
at::Tensor aten_avg = t0.mean({1, 2});
testValidate(
&fusion, outputs, aten_inputs, {aten_M2, aten_avg}, __LINE__, __FILE__);
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)
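For context on what the two new tests exercise: the last axis is block-parallel (TIDx) and the first is grid-parallel (BIDx), but the middle axis stays serial, so the generated kernel invokes the block reduction inside an ordinary for-loop and must fold each iteration's result into one accumulator. A minimal hand-written sketch of that pattern for a plain sum (illustrative only, not the generated kernel; it assumes a power-of-two blockDim.x of at most 256 and a contiguous [grid, serial, block] layout):

__global__ void serial_loop_block_sum(
    const float* in,
    float* out,
    int serial_extent) {
  __shared__ float smem[256];
  float acc = 0.f; // initialized once, before the serial loop
  for (int i = 0; i < serial_extent; ++i) {
    float v = in[(blockIdx.x * serial_extent + i) * blockDim.x + threadIdx.x];
    smem[threadIdx.x] = v;
    __syncthreads();
    // plain tree reduction over the block
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
      if (threadIdx.x < s) {
        smem[threadIdx.x] += smem[threadIdx.x + s];
      }
      __syncthreads();
    }
    acc += smem[0]; // fold this iteration's block-wide sum into the accumulator
    __syncthreads(); // smem is reused by the next iteration
  }
  if (threadIdx.x == 0) {
    out[blockIdx.x] = acc;
  }
}

With the accumulate-into-output semantics introduced by this commit, blockReduce behaves like the acc += smem[0] step above, which is why the emitted initialization (the float acc = 0.f here) has to happen before the serial loop rather than inside it.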
12 changes: 10 additions & 2 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -733,6 +733,9 @@ WelfordResult Welford(
// Initial values for welford op are tensors, so their dims have to match the
// output dim,
// i.e. original_dims - dims_to_be_reduced
Val* init_var_val = nullptr;
Val* init_avg_val = nullptr;

if (!init_N->isZeroInt()) {
TORCH_CHECK(
init_avg != nullptr && init_N != nullptr && init_var != nullptr,
@@ -745,6 +748,11 @@
(axes.size() + init_avg->getRootDomain().size()) ==
tv->getRootDomain().size(),
"welford op: initial tensor mismatch");
init_var_val = init_var;
init_avg_val = init_avg;
} else {
init_var_val = new Double(0);
init_avg_val = new Double(0);
}

// Check and collect reduction axes
@@ -773,8 +781,8 @@
out_var,
out_avg,
out_N, /*out var/avg/count */
- init_var,
- init_avg,
+ init_var_val,
+ init_avg_val,
init_N, /*init var/avg/count */
nullptr,
tv,
12 changes: 4 additions & 8 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -589,8 +589,7 @@ class CudaKernelGenerator : private kir::IrVisitor {
if (has_block_reduce) {
if (has_grid_reduce) {
indent() << data_type << " "
<< "block_result"
<< ";\n";
<< "block_result=" << gen(node->init()) << ";\n";
}
indent() << "blockReduce<" << (tidx ? "true" : "false") << ", "
<< (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
@@ -662,14 +661,11 @@
if (has_grid_reduce) {
// allocate block result
indent() << data_type << " "
<< "block_result_var"
<< ";\n";
<< "block_result_var = " << gen(node->initVar()) << ";\n";
indent() << data_type << " "
<< "block_result_avg"
<< ";\n";
<< "block_result_avg = " << gen(node->initAvg()) << ";\n";
indent() << DataType::Int << " "
<< "block_result_n"
<< ";\n";
<< "block_result_n = " << gen(node->initN()) << ";\n";
}
indent() << "blockWelford<" << (tidx ? "true" : "false") << ", "
<< (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -323,8 +323,8 @@ class GpuLower::KernelIrMapper : private OptInConstDispatch {
lowerValue(node->outVar()),
lowerValue(node->outAvg()),
lowerValue(node->outN()),
- lowerOptional(node->initVar()),
- lowerOptional(node->initAvg()),
+ lowerValue(node->initVar()),
+ lowerValue(node->initAvg()),
lowerValue(node->initN()),
lowerOptional(node->inVar()),
lowerValue(node->inAvg()),
15 changes: 11 additions & 4 deletions torch/csrc/jit/codegen/cuda/runtime/block_reduction.cu
@@ -89,8 +89,8 @@ __device__ void blockReduce(
}
}
__syncthreads();
- // for (int factor = np2/2; factor > contig_threads / 2; factor>>=1) {
- for (int factor = np2 / 2; factor > 0; factor >>= 1) {
+ // loop peel the final iteration to save one syncthread for the end
+ for (int factor = np2 / 2; factor > 1; factor >>= 1) {
if (reduction_tid < factor) {
reduction_op(
shared_mem[linear_tid],
@@ -99,6 +99,13 @@
__syncthreads();
}

- if (should_write && read_write_pred)
- out = shared_mem[linear_tid];
+ if (should_write && read_write_pred) {
+ T result = out;
+ reduction_op(result, shared_mem[linear_tid]);
+ if (reduction_size > 1) {
+ reduction_op(result, shared_mem[linear_tid + 1 * reduction_stride]);
+ }
+ out = result;
+ }
+ __syncthreads();
}
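Taken together, the two changes above are: the tree loop now stops at factor > 1 so the final combine can be folded into the guarded write, and that write combines into out via reduction_op instead of overwriting it, with a single trailing __syncthreads() protecting the shared buffer before it is reused (for example by the next iteration of a serial loop). A simplified sketch of the resulting shape, under narrower assumptions than the real helper (1-D block, power-of-two blockDim.x, every thread participates, no predicate or init handling):

template <typename T, typename Func>
__device__ void block_reduce_sketch(T& out, T inp, Func reduction_op, T* shared_mem) {
  const unsigned tid = threadIdx.x;
  shared_mem[tid] = inp;
  __syncthreads();
  // stop at factor > 1: the last combine is peeled into the write below
  for (unsigned factor = blockDim.x / 2; factor > 1; factor >>= 1) {
    if (tid < factor) {
      reduction_op(shared_mem[tid], shared_mem[tid + factor]);
    }
    __syncthreads();
  }
  if (tid == 0) {
    T result = out; // start from the caller's running value, not from scratch
    reduction_op(result, shared_mem[0]);
    if (blockDim.x > 1) {
      reduction_op(result, shared_mem[1]); // the peeled final combine
    }
    out = result;
  }
  __syncthreads(); // the caller may immediately reuse shared_mem
}

Because out is only ever combined with, a kernel that calls this inside a serial loop gets a running reduction for free, provided out was initialized to the reduction identity beforehand.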
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/broadcast.cu
@@ -37,6 +37,8 @@ __device__ void blockBroadcast(T& out, T inp_val, T* shared_mem) {
__syncthreads();

out = shared_mem[shared_offset];

__syncthreads();
}

} // namespace broadcast
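The added __syncthreads() closes a write-after-read hazard: when blockBroadcast is called again with the same shared buffer, for instance from inside a serial loop, the broadcasting thread of the next call could overwrite the slot before a slower thread has read the current value. A stripped-down sketch of the fixed shape, assuming a single broadcast slot rather than the indexed shared_offset used above:

template <typename T>
__device__ void block_broadcast_sketch(T& out, T inp_val, T* shared_mem, bool is_src) {
  if (is_src) {
    shared_mem[0] = inp_val; // one thread publishes the value
  }
  __syncthreads(); // make the write visible to every reader
  out = shared_mem[0];
  __syncthreads(); // keep this call's readers ahead of the next call's writer
}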
6 changes: 4 additions & 2 deletions torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu
@@ -221,8 +221,9 @@ __device__ void gridReduceLastBlock(
if (rem_size > 1) {
const int rblock_offset = tid % rblock_size;
const int rblock_idx = tid / rblock_size;
+ T inp_tmp = init_val;
blockReduce<false, true, false>(
- inp,
+ inp_tmp,
inp,
reduction_op,
dim3{(unsigned)rblock_offset, (unsigned)rblock_idx, 0},
Expand All @@ -231,6 +232,7 @@ __device__ void gridReduceLastBlock(
true,
init_val);
__syncthreads();
inp = inp_tmp;
if (tid < rblock_size) {
shared_buf[tid] = inp;
}
Expand All @@ -242,7 +244,7 @@ __device__ void gridReduceLastBlock(
}

if (should_write && read_write_pred) {
- out = inp;
+ reduction_op(out, inp);
}
}

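Two consequences of the new blockReduce contract show up in this file. First, since blockReduce now folds its result into the output variable, passing inp as both the output and the input would count its own value twice, hence the fresh inp_tmp accumulator initialized to init_val. Second, the final store folds into out with reduction_op(out, inp) instead of overwriting it, so grid reductions compose with surrounding serial loops the same way block reductions now do. A toy illustration of the double-counting hazard, using a plain sum in place of reduction_op:

__device__ void sum_into(float& acc, float v) {
  acc = acc + v;
}

// Wrong: using the same variable as accumulator and input adds its value twice.
__device__ float fold_wrong(float inp) {
  sum_into(inp, inp); // inp becomes 2 * inp
  return inp;
}

// Right: start from the identity, fold the input in, then fold into the caller's out.
__device__ void fold_right(float& out, float inp, float init_val) {
  float tmp = init_val; // plays the role of inp_tmp above
  sum_into(tmp, inp);
  sum_into(out, tmp); // mirrors reduction_op(out, inp) at the end
}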
49 changes: 36 additions & 13 deletions torch/csrc/jit/codegen/cuda/runtime/welford.cu
@@ -103,7 +103,9 @@ __inline__ __device__ void blockWelford(
}
}
__syncthreads();
- for (int factor = np2 / 2; factor > 0; factor >>= 1) {
+
+ // loop peel the final iteration to save one syncthread for the end
+ for (int factor = np2 / 2; factor > 1; factor >>= 1) {
if (reduction_tid < factor) {
welfordCombine(
shared_mem_M2[linear_tid],
Expand All @@ -116,10 +118,30 @@ __inline__ __device__ void blockWelford(
__syncthreads();
}
if (should_write && read_write_pred) {
- out_M2 = shared_mem_M2[linear_tid];
- out_avg = shared_mem_avg[linear_tid];
- out_N = shared_mem_N[linear_tid];
+ T res_M2 = out_M2;
+ T res_avg = out_avg;
+ TN res_N = out_N;
+ welfordCombine(
+ res_M2,
+ res_avg,
+ res_N,
+ shared_mem_M2[linear_tid],
+ shared_mem_avg[linear_tid],
+ shared_mem_N[linear_tid]);
+ if (reduction_size > 1) {
+ welfordCombine(
+ res_M2,
+ res_avg,
+ res_N,
+ shared_mem_M2[linear_tid + reduction_stride],
+ shared_mem_avg[linear_tid + reduction_stride],
+ shared_mem_N[linear_tid + reduction_stride]);
+ }
+ out_M2 = res_M2;
+ out_avg = res_avg;
+ out_N = res_N;
}
+ __syncthreads();
}
// -----------------------------------------------------------------------------------------------
// Grid Welford Prototype
@@ -278,10 +300,13 @@ __device__ void gridWelfordLastBlock(
if (rem_size > 1) {
const int rblock_offset = tid % rblock_size;
const int rblock_idx = tid / rblock_size;
+ T inp_M2_tmp = init_val;
+ T inp_avg_tmp = init_val;
+ TN inp_N_tmp = 0;
blockWelford<false, true, false>(
- inp_M2,
- inp_avg,
- inp_N,
+ inp_M2_tmp,
+ inp_avg_tmp,
+ inp_N_tmp,
inp_M2,
inp_avg,
inp_N,
Expand All @@ -294,9 +319,9 @@ __device__ void gridWelfordLastBlock(
init_val);
__syncthreads();
if (tid < rblock_size) {
- shared_buf_M2[tid] = inp_M2;
- shared_buf_avg[tid] = inp_avg;
- shared_buf_N[tid] = inp_N;
+ shared_buf_M2[tid] = inp_M2_tmp;
+ shared_buf_avg[tid] = inp_avg_tmp;
+ shared_buf_N[tid] = inp_N_tmp;
}
__syncthreads();
if (should_write) {
Expand All @@ -310,9 +335,7 @@ __device__ void gridWelfordLastBlock(
}

if (should_write && read_write_pred) {
- out_M2 = inp_M2;
- out_avg = inp_avg;
- out_N = inp_N;
+ welfordCombine(out_M2, out_avg, out_N, inp_M2, inp_avg, inp_N);
}
}

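All of these paths funnel through welfordCombine, whose definition is not part of this diff. For reference, a sketch of the standard parallel Welford update (Chan et al.) that such a combine implements; the real helper may differ in guards and type handling:

template <typename T, typename TN>
__device__ void welford_combine_sketch(
    T& a_M2, T& a_avg, TN& a_N, // running (M2, avg, count), updated in place
    T b_M2, T b_avg, TN b_N) { // incoming partial results
  if (b_N == 0) {
    return; // merging an empty partial is a no-op
  }
  TN ab_N = a_N + b_N;
  T delta = b_avg - a_avg;
  a_avg = a_avg + delta * b_N / ab_N; // merged mean
  a_M2 = a_M2 + b_M2 + delta * delta * a_N * b_N / ab_N; // merged sum of squared deviations
  a_N = ab_N;
}

Because merging a zero-count partial is the identity, initializing (M2, avg, N) to (0, 0, 0), exactly what the arith.cpp and codegen.cpp changes above now guarantee, is safe for block and grid Welford alike, including when they run inside a serial loop.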
