Replace size quick fix (pytorch#97)
* Fix replace size when a reduction dim is not innermost.

* Clang tidy.

* Remove print statement in test.
csarofeen committed Jun 18, 2020
1 parent 3ecb427 commit b1725ac
Showing 3 changed files with 48 additions and 5 deletions.
41 changes: 41 additions & 0 deletions test/cpp/jit/test_gpu.cpp
@@ -3262,6 +3262,47 @@ void testGPU_FusionGridReduction6() {
  auto aten_output = input.sum({1, 2});
  TORCH_CHECK(aten_output.allclose(cg_output));
}

void testGPU_FusionNonRedAxisBind() {
  int bid_x = 3;
  int tid_x = 2;
  int red_dim = 0;

  torch::jit::fuser::cuda::CudaKernel prog;
  Fusion& fusion = *prog.fusion_;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 =
      reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0);
  fusion.addOutput(tv1);

  tv1->split(-1, tid_x);
  tv1->axis(-2)->parallelize(ParallelType::BIDx);
  tv1->axis(-1)->parallelize(ParallelType::TIDx);

  prog.device_ = 0;
  prog.grid(bid_x);
  prog.block(tid_x);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand({16, bid_x * tid_x}, options);
  at::Tensor cg_output = at::empty({bid_x * tid_x}, options);

  torch::jit::fuser::cuda::compileKernel(&prog);
  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});

  auto aten_output = input.sum({red_dim});

  TORCH_CHECK(
      aten_output.allclose(cg_output),
      "Error of: ",
      aten_output.sub(cg_output).abs().max());
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)
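To make the intent of testGPU_FusionNonRedAxisBind concrete: the reduction axis is dim 0, i.e. not the innermost dim, so the surviving output axis has extent bid_x * tid_x = 6, which split(-1, tid_x) decomposes into a BIDx extent of 3 and a TIDx extent of 2. Below is a host-side sketch of the reference reduction the test checks against (plain standard C++ standing in for input.sum({red_dim}); illustration only, not part of the commit):

// Host-side sketch of the reference computation the new test checks:
// summing a [16, 6] input over dim 0 leaves a single axis of extent
// bid_x * tid_x = 6, which the scheduler then splits into 3 blocks
// (BIDx) of 2 threads (TIDx) each.
#include <array>
#include <cstdio>

int main() {
  constexpr int red_dim_extent = 16;        // extent of the reduced axis (dim 0)
  constexpr int bid_x = 3, tid_x = 2;       // launch shape used by the test
  constexpr int out_extent = bid_x * tid_x; // surviving non-reduction axis

  std::array<std::array<float, out_extent>, red_dim_extent> input{};
  for (int i = 0; i < red_dim_extent; ++i)
    for (int j = 0; j < out_extent; ++j)
      input[i][j] = static_cast<float>(i * out_extent + j);

  // Reduce over dim 0, mirroring what input.sum({red_dim}) produces in the test.
  std::array<float, out_extent> output{};
  for (int j = 0; j < out_extent; ++j)
    for (int i = 0; i < red_dim_extent; ++i)
      output[j] += input[i][j];

  for (int j = 0; j < out_extent; ++j)
    std::printf("output[%d] = %.1f\n", j, output[j]);
  return 0;
}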
3 changes: 2 additions & 1 deletion test/cpp/jit/tests.h
@@ -151,7 +151,8 @@ namespace jit {
_(GPU_FusionGridReduction3dim0) \
_(GPU_FusionGridReduction4) \
_(GPU_FusionGridReduction5) \
-_(GPU_FusionGridReduction6)
+_(GPU_FusionGridReduction6) \
+_(GPU_FusionNonRedAxisBind)
#else
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
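For context, TH_FORALL_TESTS_CUDA is an X-macro list, so adding _(GPU_FusionNonRedAxisBind) is what makes the harness declare and run testGPU_FusionNonRedAxisBind(). A simplified, self-contained sketch of that pattern (not the actual harness code):

// Simplified X-macro sketch of how a name added to the list becomes a test
// invocation; the real harness in test/cpp/jit defines similar expansions.
#include <cstdio>

#define TH_FORALL_TESTS_CUDA_DEMO(_) \
  _(GPU_FusionGridReduction6)        \
  _(GPU_FusionNonRedAxisBind)

// Each listed name is expected to have a matching test<Name>() function.
void testGPU_FusionGridReduction6() { std::puts("ran GridReduction6"); }
void testGPU_FusionNonRedAxisBind() { std::puts("ran NonRedAxisBind"); }

int main() {
#define RUN_TEST(name) test##name();
  TH_FORALL_TESTS_CUDA_DEMO(RUN_TEST)
#undef RUN_TEST
  return 0;
}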
9 changes: 5 additions & 4 deletions torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -101,15 +101,16 @@ void IRReplaceSizes() {
std::vector<IterDomain*> new_domain_iters;
const std::vector<IterDomain*>& root_td = tv->getRootDomain();

-for (decltype(root_td.size()) i{0}; i < root_td.size(); i++) {
+size_t dim = 0;
+for (auto id : root_td) {
  // Output sizes could have reduction axes, which isn't what gets output.
-  if (root_td[i]->isReduction())
+  if (id->isReduction())
    continue;

-  Val* orig_size = root_td[i]->extent();
+  Val* orig_size = id->extent();

  std::stringstream ss;
-  ss << "T" << tv->name() << ".size[" << i << "]";
+  ss << "T" << tv->name() << ".size[" << dim++ << "]";
  Val* new_size =
      new NamedScalar(ss.str(), orig_size->getDataType().value());
  if (!orig_size->sameAs(new_size) ||
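To spell out the fix: IRReplaceSizes builds NamedScalars like T1.size[...] for the kernel's runtime sizes, but the old loop keyed the index off the position in the root domain, reduction axes included. When the reduction dim is not innermost (dim 0 in the new test), the emitted name could refer to a dimension index the reduction-squeezed output does not actually expose; the new loop counts only non-reduction axes. A minimal standalone sketch of the old vs. new naming, using a stand-in struct rather than the fusion IR:

// Standalone sketch (not the commit's code) of the size-naming change above.
// With the reduction axis first, indexing by root position emits "T1.size[1]"
// even though the 1-D output only has size[0]; counting only non-reduction
// axes emits "T1.size[0]".
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct FakeIterDomain {
  bool is_reduction;
};

std::vector<std::string> nameOutputSizes(
    const std::vector<FakeIterDomain>& root_td,
    bool count_only_non_reduction) {
  std::vector<std::string> names;
  size_t dim = 0;
  for (size_t i = 0; i < root_td.size(); i++) {
    if (root_td[i].is_reduction)
      continue;
    std::stringstream ss;
    ss << "T1.size[" << (count_only_non_reduction ? dim++ : i) << "]";
    names.push_back(ss.str());
  }
  return names;
}

int main() {
  // Root domain of the test's output tv1: {reduction dim 0, iteration dim 1}.
  std::vector<FakeIterDomain> root_td = {{true}, {false}};
  std::cout << "old: " << nameOutputSizes(root_td, false)[0] << "\n"; // T1.size[1]
  std::cout << "new: " << nameOutputSizes(root_td, true)[0] << "\n";  // T1.size[0]
  return 0;
}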
